1
0

softmax.cu 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491
  1. /*
  2. * Adapted from https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
  3. * Copyright (c) 2024, The PygmalionAI team.
  4. * Copyright (c) 2024, The vLLM team.
  5. * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  6. * SPDX-License-Identifier: Apache-2.0
  7. *
  8. * Licensed under the Apache License, Version 2.0 (the "License");
  9. * you may not use this file except in compliance with the License.
  10. * You may obtain a copy of the License at
  11. *
  12. * http://www.apache.org/licenses/LICENSE-2.0
  13. *
  14. * Unless required by applicable law or agreed to in writing, software
  15. * distributed under the License is distributed on an "AS IS" BASIS,
  16. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. * See the License for the specific language governing permissions and
  18. * limitations under the License.
  19. */
  20. #include <torch/extension.h>
  21. #include <ATen/cuda/CUDAContext.h>
  22. #include <c10/cuda/CUDAGuard.h>
  23. #include <cub/cub.cuh>
  24. #include <cub/util_type.cuh>
  25. #include "../cuda_compat.h"
  26. namespace aphrodite {
  27. namespace moe {
  // Lanes per warp assumed by every kernel in this file (their warp shuffles use the full
  // 0xFFFFFFFF lane mask).
  constexpr int WARP_SIZE{32};
  // Aligned array type: N elements of T whose storage is aligned to `Alignment` bytes
  // (by default the array's total size, sizeof(T) * N, which must be a power of two).
  // Used as the element type for vectorized global-memory loads so that this file does not
  // depend on CUTLASS's aligned arrays.
  template <typename T, int N, int Alignment = sizeof(T) * N>
  class alignas(Alignment) AlignedArray {
      // Element type follows the template parameter (previously hard-coded to float).
      T data[N];
  };
  // We have our own implementation of softmax here so we can support transposing the output
  // in the softmax kernel when we extend this module to support expert-choice routing.
  //
  // Row-wise softmax over a row-major [gridDim.x, num_cols] float matrix.
  // Launch layout: 1-D blocks of exactly TPB threads, one block per row (row = blockIdx.x);
  // the block's threads stride over the row's columns. Only static shared memory is used.
  // When `finished` is non-null, rows with finished[row] == true are skipped and their
  // output is left untouched.
  template <int TPB>
  __launch_bounds__(TPB) __global__
  void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols)
  {
      using BlockReduce = cub::BlockReduce<float, TPB>;
      __shared__ typename BlockReduce::TempStorage tmpStorage;
      __shared__ float normalizing_factor;
      __shared__ float float_max;

      // Don't touch finished rows. The condition depends only on blockIdx.x, so the whole block
      // exits together and no thread is left waiting at the barriers below.
      if ((finished != nullptr) && finished[blockIdx.x])
      {
          return;
      }

      // 64-bit row offset: rows * num_cols can exceed INT_MAX for large batches.
      const int64_t thread_row_offset = static_cast<int64_t>(blockIdx.x) * num_cols;
      const float* row_in = input + thread_row_offset;
      float* row_out = output + thread_row_offset;

      // Pass 1: row maximum (subtracted before exponentiating for numerical stability).
      float threadData = -FLT_MAX;
      for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
      {
          threadData = max(row_in[ii], threadData);
      }
      const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
      if (threadIdx.x == 0)
      {
          float_max = maxElem;
      }
      // Publishes float_max and makes tmpStorage safe to reuse for the next reduction.
      __syncthreads();

      // Pass 2: normalizer Z = sum(exp(x - max)).
      threadData = 0.f;
      for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
      {
          threadData += expf(row_in[ii] - float_max);
      }
      const float Z = BlockReduce(tmpStorage).Reduce(threadData, cub::Sum());
      if (threadIdx.x == 0)
      {
          normalizing_factor = 1.f / Z;
      }
      __syncthreads();

      // Pass 3: write the normalized probabilities.
      for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
      {
          row_out[ii] = expf(row_in[ii] - float_max) * normalizing_factor;
      }
  }
  // Block-wide top-k selection over a row-major [gridDim.x, num_experts] matrix of
  // probabilities (e.g. the output of moeSoftmax). Launch layout: 1-D blocks of exactly TPB
  // threads, one block per row (row = blockIdx.x). For each k_idx in [0, k), the block reduces
  // to the k_idx-th largest entry and thread 0 writes, at position k * row + k_idx:
  //   output      - the selected probability,
  //   indices     - expert - start_expert when the row is active and the expert lies in
  //                 [start_expert, end_expert), otherwise num_experts as a sentinel,
  //   source_rows - k_idx * gridDim.x + row.
  // Ties are broken toward the lower expert id (cub::ArgMax semantics).
  template <int TPB>
  __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
      int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
  {
      using cub_kvp = cub::KeyValuePair<int, float>;
      using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
      __shared__ typename BlockReduce::TempStorage tmpStorage;
      // Raw expert id and value of each iteration's winner, broadcast from thread 0. The slots
      // are double-buffered by k_idx parity, so thread 0 never overwrites the slot that other
      // threads are still reading during the current iteration.
      __shared__ int winner_key[2];
      __shared__ float winner_value[2];

      cub::ArgMax arg_max;
      const int num_rows = gridDim.x;
      const int block_row = blockIdx.x;
      const bool row_is_active = finished ? !finished[block_row] : true;
      // 64-bit row offset: rows * num_experts can exceed INT_MAX for large batches.
      const float* row_ptr = inputs_after_softmax + static_cast<int64_t>(block_row) * num_experts;

      for (int k_idx = 0; k_idx < k; ++k_idx)
      {
          // Earlier winners are excluded by ordering rather than by reading back `indices`.
          // That buffer stores (expert - start_expert) or the num_experts sentinel, so it cannot
          // identify raw expert ids when start_expert != 0 or the row is finished. Every earlier
          // winner ranks at or above the previous winner in the (value desc, key asc) order used
          // by cub::ArgMax, so keeping only entries that rank strictly below it excludes exactly
          // the experts already selected.
          const bool has_prev = k_idx > 0;
          const int prev_key = has_prev ? winner_key[(k_idx - 1) & 1] : 0;
          const float prev_value = has_prev ? winner_value[(k_idx - 1) & 1] : 0.f;

          cub_kvp thread_kvp;
          thread_kvp.key = 0;
          thread_kvp.value = -1.f; // This is OK because inputs are probabilities
          for (int expert = threadIdx.x; expert < num_experts; expert += TPB)
          {
              const float value = row_ptr[expert];
              const bool ranks_below_prev = value < prev_value || (value == prev_value && expert > prev_key);
              if (has_prev && !ranks_below_prev)
              {
                  continue;
              }
              cub_kvp inp_kvp;
              inp_kvp.key = expert;
              inp_kvp.value = value;
              thread_kvp = arg_max(inp_kvp, thread_kvp);
          }

          const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
          if (threadIdx.x == 0)
          {
              // Ignore experts the node isn't responsible for with expert parallelism
              const int expert = result_kvp.key;
              const bool node_uses_expert = expert >= start_expert && expert < end_expert;
              const bool should_process_row = row_is_active && node_uses_expert;
              const int local_expert = should_process_row ? (expert - start_expert) : num_experts;
              assert(local_expert >= 0);
              const int idx = k * block_row + k_idx;
              output[idx] = result_kvp.value;
              indices[idx] = local_expert;
              source_rows[idx] = k_idx * num_rows + block_row;
              winner_key[k_idx & 1] = expert;
              winner_value[k_idx & 1] = result_kvp.value;
          }
          // Publishes this iteration's winner slot and makes tmpStorage safe to reuse.
          __syncthreads();
      }
  }
  // Top-K
  //
  // Fused softmax + top-k for rows of exactly NUM_EXPERTS floats (row-major, [num_rows, NUM_EXPERTS]).
  // Launch layout: blockDim = (WARP_SIZE, WARPS_PER_CTA). Each warp (threadIdx.y) handles
  // ROWS_PER_WARP consecutive rows, and each row is owned by a group of THREADS_PER_ROW adjacent
  // lanes that each keep VPT values in registers. No shared memory is used; all cross-thread
  // communication is through warp shuffles inside a row's lane group.
  // Precondition: rows are read with BYTES_PER_LDG-wide vector loads, so `input` (and thus every
  // row start) must be BYTES_PER_LDG-aligned.
  // For each k_idx in [0, k), the group's first lane writes at k * row + k_idx:
  //   output      - the selected softmax probability,
  //   indices     - expert - start_expert if the row is active and the expert lies in
  //                 [start_expert, end_expert), otherwise NUM_EXPERTS as a sentinel,
  //   source_rows - k_idx * num_rows + row.
  template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
  __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
  void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
      int* source_rows, const int k, const int start_expert, const int end_expert)
  {
      // Compile-time assertions and constants.
      static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
      static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
      static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
      static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");

      // Number of floats each thread pulls in per vector load.
      static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
      static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
      static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
      static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;

      static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
      static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
      static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2");
      static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size");

      static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
      static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
      static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
      static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elts per warp");

      // Row owned by this thread: each block covers ROWS_PER_CTA rows, each warp covers
      // ROWS_PER_WARP of those, and the warp's lanes split into groups of THREADS_PER_ROW per row.
      const int cta_base_row = blockIdx.x * ROWS_PER_CTA;
      const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;
      const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
      const int thread_row = warp_base_row + thread_row_in_warp;

      // Threads past the last row exit here. All lanes of a row group share thread_row, so a
      // group exits as a whole; the full-mask shuffles below only require the non-exited lanes
      // named in the mask to participate.
      if (thread_row >= num_rows)
      {
          return;
      }
      const bool row_is_active = finished ? !finished[thread_row] : true;

      // 64-bit row offset: thread_row * NUM_EXPERTS can exceed INT_MAX for large batches.
      const float* thread_row_ptr = input + static_cast<int64_t>(thread_row) * ELTS_PER_ROW;

      // Lane index inside the row group. Each lane starts at its own ELTS_PER_LDG-wide chunk,
      // and subsequent loads advance by THREADS_PER_ROW chunks.
      const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
      const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
      const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;

      // BYTES_PER_LDG-wide vector type for the loads. The original TensorRT-LLM implementation
      // uses CUTLASS aligned arrays here.
      using AccessType = AlignedArray<float, ELTS_PER_LDG>;

      // The local buffer is accessed through AccessType*, so it needs AccessType's alignment.
      alignas(BYTES_PER_LDG) float row_chunk[VPT];
      AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(row_chunk);
      const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);
  #pragma unroll
      for (int ii = 0; ii < LDG_PER_THREAD; ++ii)
      {
          row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
      }

      // Thread-local max over the VPT values this lane holds.
      float thread_max = row_chunk[0];
  #pragma unroll
      for (int ii = 1; ii < VPT; ++ii)
      {
          thread_max = max(thread_max, row_chunk[ii]);
      }

      // Butterfly (xor) all-reduce of the max across the row group; afterwards every lane in
      // the group holds the row max.
  #pragma unroll
      for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
      {
          thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW));
      }

      // Exponentiate relative to the row max and accumulate the lane-local sum.
      float row_sum = 0;
  #pragma unroll
      for (int ii = 0; ii < VPT; ++ii)
      {
          row_chunk[ii] = expf(row_chunk[ii] - thread_max);
          row_sum += row_chunk[ii];
      }

      // Butterfly all-reduce of the sum across the row group; every lane now holds the row sum.
  #pragma unroll
      for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
      {
          row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW);
      }

      // Normalize. Strictly, only the top-k values need normalizing, but this kernel is not
      // expected to be a bottleneck.
      const float reciprocal_row_sum = 1.f / row_sum;
  #pragma unroll
      for (int ii = 0; ii < VPT; ++ii)
      {
          row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
      }

      // row_chunk now holds this lane's slice of the softmax row; select the top-k entries.
      const int start_col = first_elt_read_by_thread;
      static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
      for (int k_idx = 0; k_idx < k; ++k_idx)
      {
          // Lane-local argmax. Columns are visited in increasing order and replaced only on a
          // strict '>', so the lowest column wins ties without an explicit check.
          float max_val = row_chunk[0];
          int expert = start_col;
  #pragma unroll
          for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG)
          {
  #pragma unroll
              for (int ii = 0; ii < ELTS_PER_LDG; ++ii)
              {
                  const float val = row_chunk[ldg * ELTS_PER_LDG + ii];
                  if (val > max_val)
                  {
                      max_val = val;
                      expert = col + ii;
                  }
              }
          }

          // Butterfly argmax across the row group so every lane agrees on the winner. For k > 1
          // this lets the owning lane clear the winner before the next iteration. Lower expert
          // ids win ties.
  #pragma unroll
          for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
          {
              const float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW);
              const int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW);
              if (other_max > max_val || (other_max == max_val && other_expert < expert))
              {
                  max_val = other_max;
                  expert = other_expert;
              }
          }

          // The group's first lane writes this iteration's result (one writer per row).
          if (thread_group_idx == 0)
          {
              // Ignore experts not owned by this node under expert parallelism.
              const bool node_uses_expert = expert >= start_expert && expert < end_expert;
              const bool should_process_row = row_is_active && node_uses_expert;
              const int idx = k * thread_row + k_idx;
              output[idx] = max_val;
              indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
              source_rows[idx] = k_idx * num_rows + thread_row;
          }

          // If another iteration follows, the lane holding the winner knocks it out.
          if (k_idx + 1 < k)
          {
              const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
              const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW;
              if (thread_group_idx == thread_to_clear_in_group)
              {
                  const int offset_for_expert = expert % ELTS_PER_LDG;
                  // Any negative value works because softmax outputs lie in [0, 1].
                  row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f;
              }
          }
      }
  }
  namespace detail
  {
  // Compile-time work partition for topkGatingSoftmax. Given EXPERTS floats per row loaded in
  // BYTES_PER_LDG-wide vectors, derives values per thread (VPT), threads per row, and rows
  // per warp.
  template <int EXPERTS, int BYTES_PER_LDG>
  struct TopkConstants
  {
      static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
      static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0,
          "EXPERTS must be either smaller than, or a multiple of, ELTS_PER_LDG * WARP_SIZE");
      // Vector loads per thread: one unless a row is wider than a whole warp's single load.
      static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
      static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
      static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
      static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
  };
  } // namespace detail
  // Host-side launcher for topkGatingSoftmax with the expert count fixed at compile time.
  // It chooses the widest vector load (at most 16 bytes) that a row of EXPERTS floats allows,
  // derives the per-thread partition from detail::TopkConstants, and enqueues enough
  // (WARP_SIZE x WARPS_PER_TB) blocks on `stream` to cover num_rows rows.
  // Launch errors are not checked here.
  template <int EXPERTS, int WARPS_PER_TB>
  void topkGatingSoftmaxLauncherHelper(
      const float* input, const bool* finished, float* output, int* indices, int* source_row,
      const int num_rows, const int k, const int start_expert, const int end_expert,
      cudaStream_t stream)
  {
      static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
      static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
      using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
      static constexpr int VPT = Constants::VPT;
      static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;

      // A zero-block grid is an invalid launch configuration, and there is nothing to do.
      if (num_rows <= 0)
      {
          return;
      }

      const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
      const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
      const dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
      topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
          input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
  }
  // Dispatches the fused kernel for one compile-time NUM_EXPERTS. It expands inside
  // topkGatingSoftmaxKernelLauncher and uses that function's locals (gating_output, topk_weights,
  // topk_indices, token_expert_indices, num_tokens, topk, num_experts, stream). It passes no
  // `finished` mask and the full expert range [0, num_experts). The do/while(0) wrapper makes
  // the macro a single statement, so the call site supplies the semicolon.
  #define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB)                      \
      do {                                                              \
          topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>(   \
              gating_output, nullptr, topk_weights, topk_indices,       \
              token_expert_indices, num_tokens, topk, 0, num_experts,   \
              stream);                                                  \
      } while (0)
  // Computes a softmax over each of num_tokens rows of gating_output (num_experts floats per row)
  // and writes the top `topk` probabilities, expert ids, and source-row ids for each token.
  // Power-of-two expert counts up to 256 use the fused warp-level kernel. Any other count runs
  // moeSoftmax into softmax_workspace (num_tokens * num_experts floats; must be non-null on that
  // path) followed by moeTopK. All work is enqueued on `stream`.
  void topkGatingSoftmaxKernelLauncher(
      const float* gating_output,
      float* topk_weights,
      int* topk_indices,
      int* token_expert_indices,
      float* softmax_workspace,
      const int num_tokens,
      const int num_experts,
      const int topk,
      cudaStream_t stream) {
      static constexpr int WARPS_PER_TB = 4;

      // An empty batch has nothing to compute. Returning early also avoids zero-block launches,
      // which are invalid configurations.
      if (num_tokens <= 0) {
          return;
      }

      switch (num_experts) {
      case 1:
          LAUNCH_SOFTMAX(1, WARPS_PER_TB);
          break;
      case 2:
          LAUNCH_SOFTMAX(2, WARPS_PER_TB);
          break;
      case 4:
          LAUNCH_SOFTMAX(4, WARPS_PER_TB);
          break;
      case 8:
          LAUNCH_SOFTMAX(8, WARPS_PER_TB);
          break;
      case 16:
          LAUNCH_SOFTMAX(16, WARPS_PER_TB);
          break;
      case 32:
          LAUNCH_SOFTMAX(32, WARPS_PER_TB);
          break;
      case 64:
          LAUNCH_SOFTMAX(64, WARPS_PER_TB);
          break;
      case 128:
          LAUNCH_SOFTMAX(128, WARPS_PER_TB);
          break;
      case 256:
          LAUNCH_SOFTMAX(256, WARPS_PER_TB);
          break;
      default: {
          // Generic path: full softmax into the workspace, then a block-wide top-k.
          TORCH_CHECK(softmax_workspace != nullptr,
              "softmax_workspace must be provided for num_experts that aren't a power of 2.");
          static constexpr int TPB = 256;
          moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>(
              gating_output, nullptr, softmax_workspace, num_experts);
          moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
              softmax_workspace, nullptr, topk_weights, topk_indices, token_expert_indices,
              num_experts, topk, 0, num_experts);
      }
      }
  }
  397. } // namespace moe
  398. } // namespace aphrodite
  // PyTorch entry point: top-k softmax gating over the last dimension of gating_output.
  // gating_output is viewed as [num_tokens, num_experts] floats, with num_experts = size(-1)
  // and num_tokens = numel / num_experts; topk = topk_weights.size(-1). Results go into
  // topk_weights (float), topk_indices (int), and token_expert_indices (int) through raw
  // data_ptr<T>() pointers, which reject mismatched dtypes.
  // NOTE(review): the kernels index rows as contiguous, row-major storage; confirm that callers
  // always pass contiguous tensors.
  void topk_softmax(
      torch::Tensor& topk_weights,
      torch::Tensor& topk_indices,
      torch::Tensor& token_expert_indices,
      torch::Tensor& gating_output)
  {
      const int num_experts = gating_output.size(-1);
      // Guards the division below; a zero-width expert dimension has no softmax.
      TORCH_CHECK(num_experts > 0, "gating_output must have a non-empty last dimension");
      const int num_tokens = gating_output.numel() / num_experts;
      const int topk = topk_weights.size(-1);

      // Empty batch: nothing to launch. This also sidesteps the workspace check, because an
      // empty workspace tensor may report a null data pointer.
      if (num_tokens == 0)
      {
          return;
      }

      const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
      const bool needs_workspace = !is_pow_2 || num_experts > 256;
      // Widen before multiplying so that large batches cannot overflow int.
      const int64_t workspace_size = needs_workspace ? static_cast<int64_t>(num_tokens) * num_experts : 0;
      const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
      const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
      torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
      aphrodite::moe::topkGatingSoftmaxKernelLauncher(
          gating_output.data_ptr<float>(),
          topk_weights.data_ptr<float>(),
          topk_indices.data_ptr<int>(),
          token_expert_indices.data_ptr<int>(),
          softmax_workspace.data_ptr<float>(),
          num_tokens,
          num_experts,
          topk,
          stream);
  }