/*
 * Adapted from
 * https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
 * Copyright (c) 2024, The PygmalionAI team.
 * Copyright (c) 2024, The vLLM team.
 * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION &
 * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "../cuda_compat.h"

#ifndef USE_ROCM
  #include <cub/util_type.cuh>
  #include <cub/cub.cuh>
#else
  #include <hipcub/util_type.hpp>
  #include <hipcub/hipcub.hpp>
#endif

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))

namespace aphrodite {
namespace moe {

/// Aligned array type
template <typename T,
          /// Number of elements in the array
          int N,
          /// Alignment requirement in bytes
          int Alignment = sizeof(T) * N>
class alignas(Alignment) AlignedArray {
  T data[N];
};
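
// Compile-time sanity check (added illustration, not required for
// correctness): the 16-byte vector type used for the gating loads below is
// exactly 16 bytes and 16-byte aligned.
static_assert(sizeof(AlignedArray<float, 4>) == 16, "unexpected vector size");
static_assert(alignof(AlignedArray<float, 4>) == 16, "unexpected alignment");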

// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing
// the output in the softmax kernel when we extend this module to support
// expert-choice routing.
template <int TPB>
__launch_bounds__(TPB) __global__
    void moeSoftmax(const float* input, const bool* finished, float* output,
                    const int num_cols) {
  using BlockReduce = cub::BlockReduce<float, TPB>;
  __shared__ typename BlockReduce::TempStorage tmpStorage;

  __shared__ float normalizing_factor;
  __shared__ float float_max;

  const int thread_row_offset = blockIdx.x * num_cols;

  cub::Sum sum;
  float threadData(-FLT_MAX);

  // Don't touch finished rows.
  if ((finished != nullptr) && finished[blockIdx.x]) {
    return;
  }

  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
    const int idx = thread_row_offset + ii;
    threadData = max(static_cast<float>(input[idx]), threadData);
  }

  const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
  if (threadIdx.x == 0) {
    float_max = maxElem;
  }
  __syncthreads();

  threadData = 0;

  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
    const int idx = thread_row_offset + ii;
    threadData += exp((static_cast<float>(input[idx]) - float_max));
  }

  const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);

  if (threadIdx.x == 0) {
    normalizing_factor = 1.f / Z;
  }
  __syncthreads();

  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
    const int idx = thread_row_offset + ii;
    const float val =
        exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
    output[idx] = val;
  }
}
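
// For testing, the row softmax above can be reproduced on the host with a
// plain reference loop (a minimal sketch; the function name is illustrative
// and not part of this module):
//
//   void referenceRowSoftmax(const float* in, float* out, int rows, int cols) {
//     for (int r = 0; r < rows; ++r) {
//       float row_max = -FLT_MAX, row_sum = 0.f;
//       for (int c = 0; c < cols; ++c)
//         row_max = std::max(row_max, in[r * cols + c]);
//       for (int c = 0; c < cols; ++c) {
//         out[r * cols + c] = std::exp(in[r * cols + c] - row_max);
//         row_sum += out[r * cols + c];
//       }
//       for (int c = 0; c < cols; ++c) out[r * cols + c] /= row_sum;
//     }
//   }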

template <int TPB>
__launch_bounds__(TPB) __global__
    void moeTopK(const float* inputs_after_softmax, const bool* finished,
                 float* output, int* indices, int* source_rows,
                 const int num_experts, const int k, const int start_expert,
                 const int end_expert) {
  using cub_kvp = cub::KeyValuePair<int, float>;
  using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
  __shared__ typename BlockReduce::TempStorage tmpStorage;

  cub_kvp thread_kvp;
  cub::ArgMax arg_max;

  const int num_rows = gridDim.x;
  const int block_row = blockIdx.x;

  const bool row_is_active = finished ? !finished[block_row] : true;
  const int thread_read_offset = blockIdx.x * num_experts;
  for (int k_idx = 0; k_idx < k; ++k_idx) {
    thread_kvp.key = 0;
    thread_kvp.value = -1.f;  // This is OK because inputs are probabilities

    cub_kvp inp_kvp;
    for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
      const int idx = thread_read_offset + expert;
      inp_kvp.key = expert;
      inp_kvp.value = inputs_after_softmax[idx];

      for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
        const int prior_winning_expert = indices[k * block_row + prior_k];

        if (prior_winning_expert == expert) {
          inp_kvp = thread_kvp;
        }
      }

      thread_kvp = arg_max(inp_kvp, thread_kvp);
    }

    const cub_kvp result_kvp =
        BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
    if (threadIdx.x == 0) {
      // Ignore experts the node isn't responsible for with expert parallelism
      const int expert = result_kvp.key;
      const bool node_uses_expert =
          expert >= start_expert && expert < end_expert;
      const bool should_process_row = row_is_active && node_uses_expert;

      const int idx = k * block_row + k_idx;
      output[idx] = result_kvp.value;
      indices[idx] = should_process_row ? (expert - start_expert) : num_experts;
      assert(indices[idx] >= 0);
      source_rows[idx] = k_idx * num_rows + block_row;
    }
    __syncthreads();
  }
}
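
// Index layout illustration (added example): with num_rows = 2 and k = 2, the
// flat index idx = k * block_row + k_idx stores output/indices row-major as
// [num_rows, k] (row 0 -> idx 0,1; row 1 -> idx 2,3), while
// source_rows[idx] = k_idx * num_rows + block_row gives every (row, k_idx)
// pair a unique id: all first choices land in [0, num_rows) and all second
// choices in [num_rows, 2 * num_rows) (row 0 -> 0, 2; row 1 -> 1, 3).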

// ====================== TopK softmax things ===============================

/*
  A Top-K gating softmax written to exploit the case when the number of
  experts in the MoE layers is a small power of 2. This allows us to cleanly
  share the rows among the threads in a single warp and eliminate
  communication between warps (so no need to use shared mem).

  It fuses the softmax, max and argmax into a single kernel.

  Limitations:
  1) This implementation is intended for when the number of experts is a small
     power of 2.
  2) This implementation assumes k is small, but will work for any k.
*/
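
// As a concrete illustration (added example, assuming 32-lane warps): for
// NUM_EXPERTS = 8 the launcher below selects BYTES_PER_LDG = 16, so each
// thread owns VPT = 4 floats, 2 threads cooperate on one row, and a single
// warp processes 16 rows with no inter-warp communication.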

template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
__launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
    void topkGatingSoftmax(const float* input, const bool* finished,
                           float* output, const int num_rows, int* indices,
                           int* source_rows, const int k,
                           const int start_expert, const int end_expert) {
  // We begin by enforcing compile time assertions and setting up compile time
  // constants.
  static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
  static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS),
                "NUM_EXPERTS must be power of 2");
  static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG),
                "BYTES_PER_LDG must be power of 2");
  static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");

  // Number of bytes each thread pulls in per load
  static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
  static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
  static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
  static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;

  // Restrictions based on previous section.
  static_assert(
      VPT % ELTS_PER_LDG == 0,
      "The elements per thread must be a multiple of the elements per ldg");
  static_assert(WARP_SIZE % THREADS_PER_ROW == 0,
                "The threads per row must cleanly divide the threads per warp");
  static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW),
                "THREADS_PER_ROW must be power of 2");
  static_assert(THREADS_PER_ROW <= WARP_SIZE,
                "THREADS_PER_ROW can be at most warp size");

  // We have NUM_EXPERTS elements per row. We specialize for small #experts
  static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
  static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
  static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;

  // Restrictions for previous section.
  static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0,
                "The elts per row must cleanly divide the total elt per warp");

  // ===================== From this point, we finally start computing run-time
  // variables. ========================

  // Compute CTA and warp rows. We pack multiple rows into a single warp, and a
  // block contains WARPS_PER_CTA warps. Thus, each block processes a chunk of
  // rows. We start by computing the start row for each block.
  const int cta_base_row = blockIdx.x * ROWS_PER_CTA;

  // Now, using the base row per thread block, we compute the base row per
  // warp.
  const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;

  // The threads in a warp are split into sub-groups that will work on a row.
  // We compute row offset for each thread sub-group
  const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
  const int thread_row = warp_base_row + thread_row_in_warp;

  // Threads with indices out of bounds should early exit here.
  if (thread_row >= num_rows) {
    return;
  }
  const bool row_is_active = finished ? !finished[thread_row] : true;

  // We finally start setting up the read pointers for each thread. First, each
  // thread jumps to the start of the row it will read.
  const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW;

  // Now, we compute the group each thread belongs to in order to determine the
  // first column to start loads.
  const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
  const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
  const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;

  // Determine the pointer type to use to read in the data depending on the
  // BYTES_PER_LDG template param. In theory, this can support all powers of 2
  // up to 16. NOTE: The original implementation uses CUTLASS aligned array
  // here. We defined our own aligned array and use it here to avoid the
  // dependency on CUTLASS.
  using AccessType = AlignedArray<float, ELTS_PER_LDG>;

  // Finally, we pull in the data from global mem
  float row_chunk[VPT];
  AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk);
  const AccessType* vec_thread_read_ptr =
      reinterpret_cast<const AccessType*>(thread_read_ptr);
#pragma unroll
  for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
    row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
  }

  // First, we perform a max reduce within the thread. We can do the max in
  // fp16 safely (I think) and just convert to float afterwards for the exp +
  // sum reduction.
  float thread_max = row_chunk[0];
#pragma unroll
  for (int ii = 1; ii < VPT; ++ii) {
    thread_max = max(thread_max, row_chunk[ii]);
  }

  // Now, we find the max within the thread group and distribute among the
  // threads. We use a butterfly reduce.
#pragma unroll
  for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
    thread_max = max(thread_max, APHRODITE_SHFL_XOR_SYNC_WIDTH(
                                     thread_max, mask, THREADS_PER_ROW));
  }
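
  // Butterfly illustration (added example, for THREADS_PER_ROW = 4): with
  // mask = 2 lanes (0,2) and (1,3) exchange values, with mask = 1 lanes (0,1)
  // and (2,3) exchange, so after log2(4) = 2 steps every lane in the group
  // holds the group-wide max.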

  // From this point, thread_max in every thread of the group holds the max
  // within the row.
  // Now, we subtract the max from each element in the thread and take the exp.
  // We also compute the thread local sum.
  float row_sum = 0;
#pragma unroll
  for (int ii = 0; ii < VPT; ++ii) {
    row_chunk[ii] = expf(row_chunk[ii] - thread_max);
    row_sum += row_chunk[ii];
  }

  // Now, we perform the sum reduce within each thread group. Similar to the
  // max reduce, we use a butterfly pattern.
#pragma unroll
  for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
    row_sum += APHRODITE_SHFL_XOR_SYNC_WIDTH(row_sum, mask, THREADS_PER_ROW);
  }

  // From this point, all threads have the max and the sum for their rows in
  // the thread_max and row_sum variables respectively. Finally, we can scale
  // the rows for the softmax. Technically, for top-k gating we don't need to
  // compute the entire softmax row. We can likely look at the maxes and only
  // compute for the top-k values in the row. However, this kernel will likely
  // not be a bottleneck and it seems better to more closely match torch and
  // find the argmax after computing the softmax.
  const float reciprocal_row_sum = 1.f / row_sum;
#pragma unroll
  for (int ii = 0; ii < VPT; ++ii) {
    row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
  }

  // Now, row_chunk contains the softmax of this thread's slice of the row.
  // Next, we want to find the top-k elements in each row, along with the max
  // index.
  int start_col = first_elt_read_by_thread;
  static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;

  for (int k_idx = 0; k_idx < k; ++k_idx) {
    // First, each thread does the local argmax
    float max_val = row_chunk[0];
    int expert = start_col;
#pragma unroll
    for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD;
         ++ldg, col += COLS_PER_GROUP_LDG) {
#pragma unroll
      for (int ii = 0; ii < ELTS_PER_LDG; ++ii) {
        float val = row_chunk[ldg * ELTS_PER_LDG + ii];

        // No check on the experts here since columns with the smallest index
        // are processed first and only updated if > (not >=)
        if (val > max_val) {
          max_val = val;
          expert = col + ii;
        }
      }
    }

    // Now, we perform the argmax reduce. We use the butterfly pattern so
    // threads reach consensus about the max. This will be useful for K > 1 so
    // that the threads can agree on "who" had the max value. That thread can
    // then blank out their max with -inf and the warp can run more
    // iterations...
#pragma unroll
    for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
      float other_max =
          APHRODITE_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW);
      int other_expert =
          APHRODITE_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW);

      // We want lower indices to "win" in every thread so we break ties this
      // way
      if (other_max > max_val ||
          (other_max == max_val && other_expert < expert)) {
        max_val = other_max;
        expert = other_expert;
      }
    }

    // Write the max for this k iteration to global memory.
    if (thread_group_idx == 0) {
      // Add a guard to ignore experts not included by this node
      const bool node_uses_expert =
          expert >= start_expert && expert < end_expert;
      const bool should_process_row = row_is_active && node_uses_expert;

      // The lead thread from each sub-group will write out the final results
      // to global memory: a single thread per row of the input/output
      // matrices.
      const int idx = k * thread_row + k_idx;
      output[idx] = max_val;
      indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
      source_rows[idx] = k_idx * num_rows + thread_row;
    }

    // Finally, we clear the value in the thread with the current max if there
    // is another iteration to run.
    if (k_idx + 1 < k) {
      const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
      const int thread_to_clear_in_group =
          (expert / ELTS_PER_LDG) % THREADS_PER_ROW;

      // Only the thread in the group which produced the max will reset the
      // "winning" value to -inf.
      if (thread_group_idx == thread_to_clear_in_group) {
        const int offset_for_expert = expert % ELTS_PER_LDG;
        // Safe to set to any negative value since row_chunk values must be
        // between 0 and 1.
        row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] =
            -10000.f;
      }
    }
  }
}

namespace detail {
// Constructs some constants needed to partition the work across threads at
// compile time.
template <int EXPERTS, int BYTES_PER_LDG>
struct TopkConstants {
  static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
  static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 ||
                    EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0,
                "EXPERTS must fit within a single warp-wide load or be a "
                "multiple of the elements loaded per warp-wide load");
  static constexpr int VECs_PER_THREAD =
      MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
  static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
  static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
  static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
};
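
// Compile-time sanity check, added as an illustration (assumes the 32-lane,
// non-ROCm warp size): with 64 experts and 16-byte loads, each thread owns
// VPT = 4 floats, 16 threads cooperate on one row, and each warp covers 2
// rows.
#ifndef USE_ROCM
static_assert(TopkConstants<64, 16>::VPT == 4, "");
static_assert(TopkConstants<64, 16>::THREADS_PER_ROW == 16, "");
static_assert(TopkConstants<64, 16>::ROWS_PER_WARP == 2, "");
#endif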
}  // namespace detail

template <int EXPERTS, int WARPS_PER_TB>
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished,
                                     float* output, int* indices,
                                     int* source_row, const int num_rows,
                                     const int k, const int start_expert,
                                     const int end_expert,
                                     cudaStream_t stream) {
  static constexpr std::size_t MAX_BYTES_PER_LDG = 16;

  static constexpr int BYTES_PER_LDG =
      MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
  using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
  static constexpr int VPT = Constants::VPT;
  static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
  const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
  const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;

  dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
  topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG>
      <<<num_blocks, block_dim, 0, stream>>>(input, finished, output, num_rows,
                                             indices, source_row, k,
                                             start_expert, end_expert);
}
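
// Grid-sizing illustration (added example, assuming 32-lane warps): with
// num_rows = 4096 tokens and EXPERTS = 8, ROWS_PER_WARP is 16, so the helper
// above launches ceil(4096 / 16) = 256 warps, i.e. ceil(256 / 4) = 64 blocks
// of dim3(32, 4) threads when WARPS_PER_TB = 4.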

#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB)                      \
  topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>(          \
      gating_output, nullptr, topk_weights, topk_indices,              \
      token_expert_indices, num_tokens, topk, 0, num_experts, stream);

void topkGatingSoftmaxKernelLauncher(
    const float* gating_output, float* topk_weights, int* topk_indices,
    int* token_expert_indices, float* softmax_workspace, const int num_tokens,
    const int num_experts, const int topk, cudaStream_t stream) {
  static constexpr int WARPS_PER_TB = 4;
  switch (num_experts) {
    case 1:
      LAUNCH_SOFTMAX(1, WARPS_PER_TB);
      break;
    case 2:
      LAUNCH_SOFTMAX(2, WARPS_PER_TB);
      break;
    case 4:
      LAUNCH_SOFTMAX(4, WARPS_PER_TB);
      break;
    case 8:
      LAUNCH_SOFTMAX(8, WARPS_PER_TB);
      break;
    case 16:
      LAUNCH_SOFTMAX(16, WARPS_PER_TB);
      break;
    case 32:
      LAUNCH_SOFTMAX(32, WARPS_PER_TB);
      break;
    case 64:
      LAUNCH_SOFTMAX(64, WARPS_PER_TB);
      break;
    case 128:
      LAUNCH_SOFTMAX(128, WARPS_PER_TB);
      break;
    case 256:
      LAUNCH_SOFTMAX(256, WARPS_PER_TB);
      break;
    default: {
      TORCH_CHECK(softmax_workspace != nullptr,
                  "softmax_workspace must be provided for num_experts that "
                  "are not a supported power of 2 (up to 256).");
      static constexpr int TPB = 256;
      moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>(
          gating_output, nullptr, softmax_workspace, num_experts);
      moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
          softmax_workspace, nullptr, topk_weights, topk_indices,
          token_expert_indices, num_experts, topk, 0, num_experts);
    }
  }
}

}  // namespace moe
}  // namespace aphrodite

void topk_softmax(torch::Tensor& topk_weights,          // [num_tokens, topk]
                  torch::Tensor& topk_indices,          // [num_tokens, topk]
                  torch::Tensor& token_expert_indices,  // [num_tokens, topk]
                  torch::Tensor& gating_output)  // [num_tokens, num_experts]
{
  const int num_experts = gating_output.size(-1);
  const int num_tokens = gating_output.numel() / num_experts;
  const int topk = topk_weights.size(-1);

  const bool is_pow_2 =
      (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
  const bool needs_workspace = !is_pow_2 || num_experts > 256;
  const int64_t workspace_size = needs_workspace ? num_tokens * num_experts : 0;

  const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  torch::Tensor softmax_workspace =
      torch::empty({workspace_size}, gating_output.options());

  aphrodite::moe::topkGatingSoftmaxKernelLauncher(
      gating_output.data_ptr<float>(), topk_weights.data_ptr<float>(),
      topk_indices.data_ptr<int>(), token_expert_indices.data_ptr<int>(),
      softmax_workspace.data_ptr<float>(), num_tokens, num_experts, topk,
      stream);
}
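
// Example invocation (added sketch, not part of the original source; tensor
// names and sizes are illustrative). The gating logits must be contiguous
// fp32 on the GPU, and the three outputs must be pre-allocated with shape
// [num_tokens, topk]:
//
//   const int64_t num_tokens = 8, num_experts = 64, topk = 2;
//   auto f32 = torch::dtype(torch::kFloat32).device(torch::kCUDA);
//   auto i32 = torch::dtype(torch::kInt32).device(torch::kCUDA);
//   torch::Tensor gating_output = torch::randn({num_tokens, num_experts}, f32);
//   torch::Tensor topk_weights = torch::empty({num_tokens, topk}, f32);
//   torch::Tensor topk_indices = torch::empty({num_tokens, topk}, i32);
//   torch::Tensor token_expert_indices = torch::empty({num_tokens, topk}, i32);
//   topk_softmax(topk_weights, topk_indices, token_expert_indices, gating_output);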