topk_softmax_kernels.cu 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518
  1. /*
  2. * Adapted from
  3. * https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
  4. * Copyright (c) 2024, The vLLM team.
  5. * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION &
  6. * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
  7. *
  8. * Licensed under the Apache License, Version 2.0 (the "License");
  9. * you may not use this file except in compliance with the License.
  10. * You may obtain a copy of the License at
  11. *
  12. * http://www.apache.org/licenses/LICENSE-2.0
  13. *
  14. * Unless required by applicable law or agreed to in writing, software
  15. * distributed under the License is distributed on an "AS IS" BASIS,
  16. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. * See the License for the specific language governing permissions and
  18. * limitations under the License.
  19. */
  20. #include <torch/all.h>
  21. #include <ATen/cuda/CUDAContext.h>
  22. #include <c10/cuda/CUDAGuard.h>
  23. #include "../cuda_compat.h"
  24. #ifndef USE_ROCM
  25. #include <cub/util_type.cuh>
  26. #include <cub/cub.cuh>
  27. #else
  28. #include <hipcub/util_type.hpp>
  29. #include <hipcub/hipcub.hpp>
  30. #endif
  31. #define MAX(a, b) ((a) > (b) ? (a) : (b))
  32. #define MIN(a, b) ((a) < (b) ? (a) : (b))
  33. namespace aphrodite {
  34. namespace moe {
  35. /// Aligned array type
  36. template <typename T,
  37. /// Number of elements in the array
  38. int N,
  39. /// Alignment requirement in bytes
  40. int Alignment = sizeof(T) * N>
  41. class alignas(Alignment) AlignedArray {
  42. float data[N];
  43. };
  44. // ====================== Softmax things ===============================
  45. // We have our own implementation of softmax here so we can support transposing
  46. // the output in the softmax kernel when we extend this module to support
  47. // expert-choice routing.
  48. template <int TPB>
  49. __launch_bounds__(TPB) __global__
  50. void moeSoftmax(const float* input, const bool* finished, float* output,
  51. const int num_cols) {
  52. using BlockReduce = cub::BlockReduce<float, TPB>;
  53. __shared__ typename BlockReduce::TempStorage tmpStorage;
  54. __shared__ float normalizing_factor;
  55. __shared__ float float_max;
  56. const int thread_row_offset = blockIdx.x * num_cols;
  57. cub::Sum sum;
  58. float threadData(-FLT_MAX);
  59. // Don't touch finished rows.
  60. if ((finished != nullptr) && finished[blockIdx.x]) {
  61. return;
  62. }
  63. for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
  64. const int idx = thread_row_offset + ii;
  65. threadData = max(static_cast<float>(input[idx]), threadData);
  66. }
  67. const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
  68. if (threadIdx.x == 0) {
  69. float_max = maxElem;
  70. }
  71. __syncthreads();
  72. threadData = 0;
  73. for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
  74. const int idx = thread_row_offset + ii;
  75. threadData += exp((static_cast<float>(input[idx]) - float_max));
  76. }
  77. const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
  78. if (threadIdx.x == 0) {
  79. normalizing_factor = 1.f / Z;
  80. }
  81. __syncthreads();
  82. for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
  83. const int idx = thread_row_offset + ii;
  84. const float val =
  85. exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
  86. output[idx] = val;
  87. }
  88. }
  89. template <int TPB>
  90. __launch_bounds__(TPB) __global__
  91. void moeTopK(const float* inputs_after_softmax, const bool* finished,
  92. float* output, int* indices, int* source_rows,
  93. const int num_experts, const int k, const int start_expert,
  94. const int end_expert) {
  95. using cub_kvp = cub::KeyValuePair<int, float>;
  96. using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
  97. __shared__ typename BlockReduce::TempStorage tmpStorage;
  98. cub_kvp thread_kvp;
  99. cub::ArgMax arg_max;
  100. const int num_rows = gridDim.x;
  101. const int block_row = blockIdx.x;
  102. const bool row_is_active = finished ? !finished[block_row] : true;
  103. const int thread_read_offset = blockIdx.x * num_experts;
  104. for (int k_idx = 0; k_idx < k; ++k_idx) {
  105. thread_kvp.key = 0;
  106. thread_kvp.value = -1.f; // This is OK because inputs are probabilities
  107. cub_kvp inp_kvp;
  108. for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
  109. const int idx = thread_read_offset + expert;
  110. inp_kvp.key = expert;
  111. inp_kvp.value = inputs_after_softmax[idx];
  112. for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
  113. const int prior_winning_expert = indices[k * block_row + prior_k];
  114. if (prior_winning_expert == expert) {
  115. inp_kvp = thread_kvp;
  116. }
  117. }
  118. thread_kvp = arg_max(inp_kvp, thread_kvp);
  119. }
  120. const cub_kvp result_kvp =
  121. BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
  122. if (threadIdx.x == 0) {
  123. // Ignore experts the node isn't responsible for with expert parallelism
  124. const int expert = result_kvp.key;
  125. const bool node_uses_expert =
  126. expert >= start_expert && expert < end_expert;
  127. const bool should_process_row = row_is_active && node_uses_expert;
  128. const int idx = k * block_row + k_idx;
  129. output[idx] = result_kvp.value;
  130. indices[idx] = should_process_row ? (expert - start_expert) : num_experts;
  131. assert(indices[idx] >= 0);
  132. source_rows[idx] = k_idx * num_rows + block_row;
  133. }
  134. __syncthreads();
  135. }
  136. }
  137. // ====================== TopK softmax things ===============================
  138. /*
  139. A Top-K gating softmax written to exploit when the number of experts in the
  140. MoE layers are a small power of 2. This allows us to cleanly share the rows
  141. among the threads in a single warp and eliminate communication between warps
  142. (so no need to use shared mem).
  143. It fuses the softmax, max and argmax into a single kernel.
  144. Limitations:
  145. 1) This implementation is intended for when the number of experts is a small
  146. power of 2. 2) This implementation assumes k is small, but will work for any
  147. k.
  148. */
  149. template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
  150. __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
  151. void topkGatingSoftmax(const float* input, const bool* finished,
  152. float* output, const int num_rows, int* indices,
  153. int* source_rows, const int k,
  154. const int start_expert, const int end_expert) {
  155. // We begin by enforcing compile time assertions and setting up compile time
  156. // constants.
  157. static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
  158. static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS),
  159. "NUM_EXPERTS must be power of 2");
  160. static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG),
  161. "BYTES_PER_LDG must be power of 2");
  162. static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
  163. // Number of bytes each thread pulls in per load
  164. static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
  165. static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
  166. static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
  167. static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;
  168. // Restrictions based on previous section.
  169. static_assert(
  170. VPT % ELTS_PER_LDG == 0,
  171. "The elements per thread must be a multiple of the elements per ldg");
  172. static_assert(WARP_SIZE % THREADS_PER_ROW == 0,
  173. "The threads per row must cleanly divide the threads per warp");
  174. static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW),
  175. "THREADS_PER_ROW must be power of 2");
  176. static_assert(THREADS_PER_ROW <= WARP_SIZE,
  177. "THREADS_PER_ROW can be at most warp size");
  178. // We have NUM_EXPERTS elements per row. We specialize for small #experts
  179. static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
  180. static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
  181. static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
  182. // Restrictions for previous section.
  183. static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0,
  184. "The elts per row must cleanly divide the total elt per warp");
  185. // ===================== From this point, we finally start computing run-time
  186. // variables. ========================
  187. // Compute CTA and warp rows. We pack multiple rows into a single warp, and a
  188. // block contains WARPS_PER_CTA warps. This, each block processes a chunk of
  189. // rows. We start by computing the start row for each block.
  190. const int cta_base_row = blockIdx.x * ROWS_PER_CTA;
  191. // Now, using the base row per thread block, we compute the base row per warp.
  192. const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;
  193. // The threads in a warp are split into sub-groups that will work on a row.
  194. // We compute row offset for each thread sub-group
  195. const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
  196. const int thread_row = warp_base_row + thread_row_in_warp;
  197. // Threads with indices out of bounds should early exit here.
  198. if (thread_row >= num_rows) {
  199. return;
  200. }
  201. const bool row_is_active = finished ? !finished[thread_row] : true;
  202. // We finally start setting up the read pointers for each thread. First, each
  203. // thread jumps to the start of the row it will read.
  204. const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
  205. // Now, we compute the group each thread belong to in order to determine the
  206. // first column to start loads.
  207. const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
  208. const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
  209. const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
  210. // Determine the pointer type to use to read in the data depending on the
  211. // BYTES_PER_LDG template param. In theory, this can support all powers of 2
  212. // up to 16. NOTE(woosuk): The original implementation uses CUTLASS aligned
  213. // array here. We defined our own aligned array and use it here to avoid the
  214. // dependency on CUTLASS.
  215. using AccessType = AlignedArray<float, ELTS_PER_LDG>;
  216. // Finally, we pull in the data from global mem
  217. float row_chunk[VPT];
  218. AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk);
  219. const AccessType* vec_thread_read_ptr =
  220. reinterpret_cast<const AccessType*>(thread_read_ptr);
  221. #pragma unroll
  222. for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
  223. row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
  224. }
  225. // First, we perform a max reduce within the thread. We can do the max in fp16
  226. // safely (I think) and just convert to float afterwards for the exp + sum
  227. // reduction.
  228. float thread_max = row_chunk[0];
  229. #pragma unroll
  230. for (int ii = 1; ii < VPT; ++ii) {
  231. thread_max = max(thread_max, row_chunk[ii]);
  232. }
  233. // Now, we find the max within the thread group and distribute among the
  234. // threads. We use a butterfly reduce.
  235. #pragma unroll
  236. for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
  237. thread_max = max(thread_max, APHRODITE_SHFL_XOR_SYNC_WIDTH(
  238. thread_max, mask, THREADS_PER_ROW));
  239. }
  240. // From this point, thread max in all the threads have the max within the row.
  241. // Now, we subtract the max from each element in the thread and take the exp.
  242. // We also compute the thread local sum.
  243. float row_sum = 0;
  244. #pragma unroll
  245. for (int ii = 0; ii < VPT; ++ii) {
  246. row_chunk[ii] = expf(row_chunk[ii] - thread_max);
  247. row_sum += row_chunk[ii];
  248. }
  249. // Now, we perform the sum reduce within each thread group. Similar to the max
  250. // reduce, we use a bufferfly pattern.
  251. #pragma unroll
  252. for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
  253. row_sum += APHRODITE_SHFL_XOR_SYNC_WIDTH(row_sum, mask, THREADS_PER_ROW);
  254. }
  255. // From this point, all threads have the max and the sum for their rows in the
  256. // thread_max and thread_sum variables respectively. Finally, we can scale the
  257. // rows for the softmax. Technically, for top-k gating we don't need to
  258. // compute the entire softmax row. We can likely look at the maxes and only
  259. // compute for the top-k values in the row. However, this kernel will likely
  260. // not be a bottle neck and it seems better to closer match torch and find the
  261. // argmax after computing the softmax.
  262. const float reciprocal_row_sum = 1.f / row_sum;
  263. #pragma unroll
  264. for (int ii = 0; ii < VPT; ++ii) {
  265. row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
  266. }
  267. // Now, softmax_res contains the softmax of the row chunk. Now, I want to find
  268. // the topk elements in each row, along with the max index.
  269. int start_col = first_elt_read_by_thread;
  270. static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
  271. for (int k_idx = 0; k_idx < k; ++k_idx) {
  272. // First, each thread does the local argmax
  273. float max_val = row_chunk[0];
  274. int expert = start_col;
  275. #pragma unroll
  276. for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD;
  277. ++ldg, col += COLS_PER_GROUP_LDG) {
  278. #pragma unroll
  279. for (int ii = 0; ii < ELTS_PER_LDG; ++ii) {
  280. float val = row_chunk[ldg * ELTS_PER_LDG + ii];
  281. // No check on the experts here since columns with the smallest index
  282. // are processed first and only updated if > (not >=)
  283. if (val > max_val) {
  284. max_val = val;
  285. expert = col + ii;
  286. }
  287. }
  288. }
  289. // Now, we perform the argmax reduce. We use the butterfly pattern so threads
  290. // reach consensus about the max. This will be useful for K > 1 so that the
  291. // threads can agree on "who" had the max value. That thread can then blank out
  292. // their max with -inf and the warp can run more iterations...
  293. #pragma unroll
  294. for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
  295. float other_max =
  296. APHRODITE_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW);
  297. int other_expert =
  298. APHRODITE_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW);
  299. // We want lower indices to "win" in every thread so we break ties this
  300. // way
  301. if (other_max > max_val ||
  302. (other_max == max_val && other_expert < expert)) {
  303. max_val = other_max;
  304. expert = other_expert;
  305. }
  306. }
  307. // Write the max for this k iteration to global memory.
  308. if (thread_group_idx == 0) {
  309. // Add a guard to ignore experts not included by this node
  310. const bool node_uses_expert =
  311. expert >= start_expert && expert < end_expert;
  312. const bool should_process_row = row_is_active && node_uses_expert;
  313. // The lead thread from each sub-group will write out the final results to
  314. // global memory. (This will be a single) thread per row of the
  315. // input/output matrices.
  316. const int idx = k * thread_row + k_idx;
  317. output[idx] = max_val;
  318. indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
  319. source_rows[idx] = k_idx * num_rows + thread_row;
  320. }
  321. // Finally, we clear the value in the thread with the current max if there
  322. // is another iteration to run.
  323. if (k_idx + 1 < k) {
  324. const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
  325. const int thread_to_clear_in_group =
  326. (expert / ELTS_PER_LDG) % THREADS_PER_ROW;
  327. // Only the thread in the group which produced the max will reset the
  328. // "winning" value to -inf.
  329. if (thread_group_idx == thread_to_clear_in_group) {
  330. const int offset_for_expert = expert % ELTS_PER_LDG;
  331. // Safe to set to any negative value since row_chunk values must be
  332. // between 0 and 1.
  333. row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] =
  334. -10000.f;
  335. }
  336. }
  337. }
  338. }
  339. namespace detail {
  340. // Constructs some constants needed to partition the work across threads at
  341. // compile time.
  342. template <int EXPERTS, int BYTES_PER_LDG>
  343. struct TopkConstants {
  344. static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
  345. static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 ||
  346. EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0,
  347. "");
  348. static constexpr int VECs_PER_THREAD =
  349. MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
  350. static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
  351. static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
  352. static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
  353. };
  354. } // namespace detail
  355. template <int EXPERTS, int WARPS_PER_TB>
  356. void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished,
  357. float* output, int* indices,
  358. int* source_row, const int num_rows,
  359. const int k, const int start_expert,
  360. const int end_expert,
  361. cudaStream_t stream) {
  362. static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
  363. static constexpr int BYTES_PER_LDG =
  364. MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
  365. using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
  366. static constexpr int VPT = Constants::VPT;
  367. static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
  368. const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
  369. const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
  370. dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
  371. topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG>
  372. <<<num_blocks, block_dim, 0, stream>>>(input, finished, output, num_rows,
  373. indices, source_row, k,
  374. start_expert, end_expert);
  375. }
  376. #define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \
  377. topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>( \
  378. gating_output, nullptr, topk_weights, topk_indicies, \
  379. token_expert_indices, num_tokens, topk, 0, num_experts, stream);
  380. void topkGatingSoftmaxKernelLauncher(
  381. const float* gating_output, float* topk_weights, int* topk_indicies,
  382. int* token_expert_indices, float* softmax_workspace, const int num_tokens,
  383. const int num_experts, const int topk, cudaStream_t stream) {
  384. static constexpr int WARPS_PER_TB = 4;
  385. switch (num_experts) {
  386. case 1:
  387. LAUNCH_SOFTMAX(1, WARPS_PER_TB);
  388. break;
  389. case 2:
  390. LAUNCH_SOFTMAX(2, WARPS_PER_TB);
  391. break;
  392. case 4:
  393. LAUNCH_SOFTMAX(4, WARPS_PER_TB);
  394. break;
  395. case 8:
  396. LAUNCH_SOFTMAX(8, WARPS_PER_TB);
  397. break;
  398. case 16:
  399. LAUNCH_SOFTMAX(16, WARPS_PER_TB);
  400. break;
  401. case 32:
  402. LAUNCH_SOFTMAX(32, WARPS_PER_TB);
  403. break;
  404. case 64:
  405. LAUNCH_SOFTMAX(64, WARPS_PER_TB);
  406. break;
  407. case 128:
  408. LAUNCH_SOFTMAX(128, WARPS_PER_TB);
  409. break;
  410. case 256:
  411. LAUNCH_SOFTMAX(256, WARPS_PER_TB);
  412. break;
  413. default: {
  414. TORCH_CHECK(softmax_workspace != nullptr,
  415. "softmax_workspace must be provided for num_experts that are "
  416. "not a power of 2.");
  417. static constexpr int TPB = 256;
  418. moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>(
  419. gating_output, nullptr, softmax_workspace, num_experts);
  420. moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
  421. softmax_workspace, nullptr, topk_weights, topk_indicies,
  422. token_expert_indices, num_experts, topk, 0, num_experts);
  423. }
  424. }
  425. }
  426. } // namespace moe
  427. } // namespace aphrodite
  428. void topk_softmax(torch::Tensor& topk_weights, // [num_tokens, topk]
  429. torch::Tensor& topk_indices, // [num_tokens, topk]
  430. torch::Tensor& token_expert_indices, // [num_tokens, topk]
  431. torch::Tensor& gating_output) // [num_tokens, num_experts]
  432. {
  433. const int num_experts = gating_output.size(-1);
  434. const int num_tokens = gating_output.numel() / num_experts;
  435. const int topk = topk_weights.size(-1);
  436. const bool is_pow_2 =
  437. (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
  438. const bool needs_workspace = !is_pow_2 || num_experts > 256;
  439. const int64_t workspace_size = needs_workspace ? num_tokens * num_experts : 0;
  440. const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
  441. const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  442. torch::Tensor softmax_workspace =
  443. torch::empty({workspace_size}, gating_output.options());
  444. aphrodite::moe::topkGatingSoftmaxKernelLauncher(
  445. gating_output.data_ptr<float>(), topk_weights.data_ptr<float>(),
  446. topk_indices.data_ptr<int>(), token_expert_indices.data_ptr<int>(),
  447. softmax_workspace.data_ptr<float>(), num_tokens, num_experts, topk,
  448. stream);
  449. }