softmax.cu 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500
  1. /*
  2. * Adapted from https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
  3. * Copyright (c) 2024, The PygmalionAI team.
  4. * Copyright (c) 2024, The vLLM team.
  5. * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  6. * SPDX-License-Identifier: Apache-2.0
  7. *
  8. * Licensed under the Apache License, Version 2.0 (the "License");
  9. * you may not use this file except in compliance with the License.
  10. * You may obtain a copy of the License at
  11. *
  12. * http://www.apache.org/licenses/LICENSE-2.0
  13. *
  14. * Unless required by applicable law or agreed to in writing, software
  15. * distributed under the License is distributed on an "AS IS" BASIS,
  16. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. * See the License for the specific language governing permissions and
  18. * limitations under the License.
  19. */
  20. #include <torch/extension.h>
  21. #include <ATen/cuda/CUDAContext.h>
  22. #include <c10/cuda/CUDAGuard.h>
  23. #include <cub/cub.cuh>
  24. #include <cub/util_type.cuh>
  25. namespace aphrodite {
  26. namespace moe {
  static constexpr int WARP_SIZE = 32;

  /// Fixed-size array of N elements of type T with a forced alignment.
  ///
  /// Kernels in this file reinterpret register chunks and global-memory pointers as this
  /// type so that one assignment moves sizeof(T) * N bytes at once; it stands in for
  /// CUTLASS's AlignedArray so the file does not depend on CUTLASS.
  ///
  /// Template parameters:
  ///   T         - element type.
  ///   N         - number of elements in the array.
  ///   Alignment - alignment requirement in bytes. Defaults to the array size; it must be
  ///               a valid `alignas` argument (a power of two).
  template <
      typename T,
      int N,
      int Alignment = sizeof(T) * N
  >
  class alignas(Alignment) AlignedArray {
      // Element storage. Must use T (not a hard-coded float) so that sizeof(AlignedArray)
      // equals sizeof(T) * N for every element type.
      T data[N];
  };
  // ====================== Softmax things ===============================
  // We have our own implementation of softmax here so we can support transposing the output
  // in the softmax kernel when we extend this module to support expert-choice routing.

  /// Row-wise softmax over a row-major float matrix with num_cols columns per row.
  ///
  /// Launch layout: one block of TPB threads per row (blockIdx.x selects the row). Each
  /// thread strides over the row's columns by TPB. The row max and the sum of exponentials
  /// are combined across the block with cub::BlockReduce; no dynamic shared memory is used.
  ///
  /// @param input     logits, num_cols consecutive values per row.
  /// @param finished  optional per-row flags (may be nullptr); rows with finished[row] set
  ///                  are skipped and their output is left untouched.
  /// @param output    softmax probabilities, same layout as input.
  /// @param num_cols  number of columns per row.
  template <int TPB>
  __launch_bounds__(TPB) __global__
  void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols)
  {
      using BlockReduce = cub::BlockReduce<float, TPB>;
      __shared__ typename BlockReduce::TempStorage tmpStorage;
      __shared__ float normalizing_factor;
      __shared__ float float_max;

      // Don't touch finished rows. The condition depends only on blockIdx.x, so the whole
      // block returns together and no thread is left waiting at a __syncthreads() below.
      if ((finished != nullptr) && finished[blockIdx.x])
      {
          return;
      }

      // 64-bit row offset: blockIdx.x * num_cols can exceed INT_MAX for large inputs.
      const int64_t thread_row_offset = static_cast<int64_t>(blockIdx.x) * num_cols;
      const float* row_in = input + thread_row_offset;
      float* row_out = output + thread_row_offset;

      // Pass 1: row maximum, subtracted before exponentiation for numerical stability.
      float threadData = -FLT_MAX;
      for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
      {
          threadData = max(row_in[ii], threadData);
      }
      const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
      if (threadIdx.x == 0)
      {
          float_max = maxElem;
      }
      // Publishes float_max to the block and makes tmpStorage safe to reuse.
      __syncthreads();

      // Pass 2: sum of exp(x - max).
      threadData = 0.f;
      for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
      {
          threadData += expf(row_in[ii] - float_max);
      }
      const float Z = BlockReduce(tmpStorage).Reduce(threadData, cub::Sum());
      if (threadIdx.x == 0)
      {
          normalizing_factor = 1.f / Z;
      }
      // Publishes normalizing_factor to the block.
      __syncthreads();

      // Pass 3: normalize and store.
      for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
      {
          row_out[ii] = expf(row_in[ii] - float_max) * normalizing_factor;
      }
  }
  /// Block-wide top-k selection over rows of precomputed softmax probabilities.
  ///
  /// Launch layout: one block of TPB threads per row; gridDim.x is the number of rows.
  /// Each of the k iterations runs a cub::BlockReduce arg-max over the experts not yet
  /// selected for the row, and thread 0 writes the winner.
  ///
  /// @param inputs_after_softmax num_experts probabilities per row (expected in [0, 1]).
  /// @param finished     optional per-row flags (may be nullptr); finished rows get
  ///                     indices == num_experts.
  /// @param output       k selected probabilities per row.
  /// @param indices      k selected expert ids per row, relative to start_expert, or
  ///                     num_experts when the row is finished or the expert lies outside
  ///                     [start_expert, end_expert).
  /// @param source_rows  k entries per row; entry (row, j) is j * num_rows + row.
  /// @param num_experts  number of experts (columns) per row.
  /// @param k            number of experts to select per row.
  /// @param start_expert first expert (inclusive) handled by this node.
  /// @param end_expert   last expert (exclusive) handled by this node.
  template <int TPB>
  __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
      int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
  {
      using cub_kvp = cub::KeyValuePair<int, float>;
      using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
      __shared__ typename BlockReduce::TempStorage tmpStorage;
      // Winner of the previous iteration, broadcast from thread 0 to the whole block.
      __shared__ int prev_key;
      __shared__ float prev_value;

      cub::ArgMax arg_max;
      const int num_rows = gridDim.x;
      const int block_row = blockIdx.x;
      const bool row_is_active = finished ? !finished[block_row] : true;
      // 64-bit offset: block_row * num_experts can exceed INT_MAX for large inputs.
      const int64_t thread_read_offset = static_cast<int64_t>(block_row) * num_experts;

      for (int k_idx = 0; k_idx < k; ++k_idx)
      {
          cub_kvp thread_kvp;
          thread_kvp.key = 0;
          thread_kvp.value = -1.f; // This is OK because inputs are probabilities

          for (int expert = threadIdx.x; expert < num_experts; expert += TPB)
          {
              cub_kvp inp_kvp;
              inp_kvp.key = expert;
              inp_kvp.value = inputs_after_softmax[thread_read_offset + expert];

              // cub::ArgMax ranks candidates by value (descending), breaking ties toward the
              // smaller key, so iteration k_idx selects the k_idx-th candidate in that order.
              // Keeping only candidates ranked strictly after the previous winner therefore
              // excludes exactly the experts already chosen. This does not depend on what was
              // written to `indices`, which holds start_expert-relative ids or the num_experts
              // sentinel and so cannot identify prior winners in general.
              // NOTE(review): assumes probabilities are not NaN.
              if (k_idx > 0)
              {
                  const bool ranks_after_prev = (inp_kvp.value < prev_value)
                      || (inp_kvp.value == prev_value && inp_kvp.key > prev_key);
                  if (!ranks_after_prev)
                  {
                      continue;
                  }
              }
              thread_kvp = arg_max(inp_kvp, thread_kvp);
          }

          const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
          // Every thread must finish reading prev_key/prev_value before thread 0 overwrites them.
          __syncthreads();
          if (threadIdx.x == 0)
          {
              // Ignore experts the node isn't responsible for with expert parallelism
              const int expert = result_kvp.key;
              const bool node_uses_expert = expert >= start_expert && expert < end_expert;
              const bool should_process_row = row_is_active && node_uses_expert;
              const int local_expert = should_process_row ? (expert - start_expert) : num_experts;
              assert(local_expert >= 0);

              const int64_t idx = static_cast<int64_t>(k) * block_row + k_idx;
              output[idx] = result_kvp.value;
              indices[idx] = local_expert;
              source_rows[idx] = k_idx * num_rows + block_row;

              prev_key = result_kvp.key;
              prev_value = result_kvp.value;
          }
          // Publishes the winner to the block and makes tmpStorage safe to reuse.
          __syncthreads();
      }
  }
  // ====================== TopK softmax things ===============================
  /*
    A Top-K gating softmax written to exploit the case where the number of experts in the MoE
    layer is a small power of 2. This allows us to cleanly share the rows among the threads in
    a single warp and eliminate communication between warps (so no need to use shared mem).
    It fuses the softmax, max and argmax into a single kernel.

    Launch layout (as set up by topkGatingSoftmaxLauncherHelper): blockDim = (WARP_SIZE,
    WARPS_PER_CTA) and threadIdx.y selects the warp. Each warp handles ROWS_PER_WARP consecutive
    rows with THREADS_PER_ROW lanes per row. Each block covers ROWS_PER_CTA rows, so the grid
    needs ceil(num_rows / ROWS_PER_CTA) blocks.
    NOTE(review): the vectorized loads assume `input` is aligned to BYTES_PER_LDG bytes; confirm
    for caller-provided pointers.

    Limitations:
    1) This implementation is intended for when the number of experts is a small power of 2.
    2) This implementation assumes k is small, but will work for any k.
  */
  template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
  __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
  void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
      int* source_rows, const int k, const int start_expert, const int end_expert)
  {
      // We begin by enforcing compile time assertions and setting up compile time constants.
      static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
      static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
      static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
      static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");

      // Number of float elements each thread pulls in per load.
      static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
      static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
      static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
      static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;

      // Restrictions based on previous section.
      static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
      static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
      static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2");
      static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size");

      // We have NUM_EXPERTS elements per row. We specialize for small #experts
      static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
      static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
      static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;

      // Restrictions for previous section.
      static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp");

      // ===================== From this point, we finally start computing run-time variables. ========================

      // Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps.
      // Thus, each block processes a chunk of rows. We start by computing the start row for each block.
      const int cta_base_row = blockIdx.x * ROWS_PER_CTA;

      // Now, using the base row per thread block, we compute the base row per warp.
      const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;

      // The threads in a warp are split into sub-groups that will work on a row.
      // We compute row offset for each thread sub-group
      const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
      const int thread_row = warp_base_row + thread_row_in_warp;

      // Threads with indices out of bounds should early exit here. All lanes of a sub-group share
      // thread_row and exit together; the shuffles below use width THREADS_PER_ROW, so they only
      // exchange values within a sub-group.
      if (thread_row >= num_rows)
      {
          return;
      }
      const bool row_is_active = finished ? !finished[thread_row] : true;

      // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
      // row it will read. The offset is computed in 64 bits since num_rows * NUM_EXPERTS can exceed INT_MAX.
      const float* thread_row_ptr = input + static_cast<int64_t>(thread_row) * ELTS_PER_ROW;

      // Now, we compute the group each thread belongs to in order to determine the first column to start loads.
      const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
      const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
      const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;

      // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory,
      // this can support all powers of 2 up to 16.
      // NOTE: The original implementation uses CUTLASS aligned array here.
      // We defined our own aligned array and use it here to avoid the dependency on CUTLASS.
      using AccessType = AlignedArray<float, ELTS_PER_LDG>;

      // Finally, we pull in the data from global mem. Lanes of a sub-group read adjacent ELTS_PER_LDG-wide chunks;
      // a thread's ii-th load is THREADS_PER_ROW chunks further along the row than its (ii-1)-th.
      float row_chunk[VPT];
      AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk);
      const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);
  #pragma unroll
      for (int ii = 0; ii < LDG_PER_THREAD; ++ii)
      {
          row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
      }

      // First, we perform a max reduce within the thread over its VPT values.
      float thread_max = row_chunk[0];
  #pragma unroll
      for (int ii = 1; ii < VPT; ++ii)
      {
          thread_max = max(thread_max, row_chunk[ii]);
      }

      // Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce.
  #pragma unroll
      for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
      {
          thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW));
      }

      // From this point, thread_max in all the threads holds the max within the row.
      // Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum.
      float row_sum = 0;
  #pragma unroll
      for (int ii = 0; ii < VPT; ++ii)
      {
          row_chunk[ii] = expf(row_chunk[ii] - thread_max);
          row_sum += row_chunk[ii];
      }

      // Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a butterfly pattern.
  #pragma unroll
      for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
      {
          row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW);
      }

      // From this point, all threads have the max and the sum for their rows in the thread_max and row_sum variables
      // respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to
      // compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row.
      // However, this kernel will likely not be a bottle neck and it seems better to closer match torch and find the
      // argmax after computing the softmax.
      const float reciprocal_row_sum = 1.f / row_sum;

  #pragma unroll
      for (int ii = 0; ii < VPT; ++ii)
      {
          row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
      }

      // Now, row_chunk contains the softmax of this thread's part of the row. Next, we find the top-k elements in each
      // row, along with their expert indices.
      int start_col = first_elt_read_by_thread;
      static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;

      for (int k_idx = 0; k_idx < k; ++k_idx)
      {
          // First, each thread does the local argmax
          float max_val = row_chunk[0];
          int expert = start_col;
  #pragma unroll
          for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG)
          {
  #pragma unroll
              for (int ii = 0; ii < ELTS_PER_LDG; ++ii)
              {
                  float val = row_chunk[ldg * ELTS_PER_LDG + ii];

                  // No check on the experts here since columns with the smallest index are processed first and only
                  // updated if > (not >=)
                  if (val > max_val)
                  {
                      max_val = val;
                      expert = col + ii;
                  }
              }
          }

          // Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max.
          // This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can
          // then blank out their max with -inf and the warp can run more iterations...
  #pragma unroll
          for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
          {
              float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW);
              int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW);

              // We want lower indices to "win" in every thread so we break ties this way
              if (other_max > max_val || (other_max == max_val && other_expert < expert))
              {
                  max_val = other_max;
                  expert = other_expert;
              }
          }

          // Write the max for this k iteration to global memory.
          if (thread_group_idx == 0)
          {
              // Add a guard to ignore experts not included by this node
              const bool node_uses_expert = expert >= start_expert && expert < end_expert;
              const bool should_process_row = row_is_active && node_uses_expert;

              // The lead thread from each sub-group (a single thread per row of the input/output matrices) writes out
              // the final results to global memory. 64-bit index for the same overflow reason as thread_row_ptr.
              const int64_t idx = static_cast<int64_t>(k) * thread_row + k_idx;
              output[idx] = max_val;
              indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
              source_rows[idx] = k_idx * num_rows + thread_row;
          }

          // Finally, we clear the value in the thread with the current max if there is another iteration to run.
          if (k_idx + 1 < k)
          {
              const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
              const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW;

              // Only the thread in the group which produced the max will reset the "winning" value to -inf.
              if (thread_group_idx == thread_to_clear_in_group)
              {
                  const int offset_for_expert = expert % ELTS_PER_LDG;
                  // Safe to set to any negative value since row_chunk values must be between 0 and 1.
                  row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f;
              }
          }
      }
  }
  namespace detail
  {
  /// Compile-time work partitioning used to instantiate topkGatingSoftmax.
  ///
  /// From the expert count and the bytes moved per vectorized load, derives how many
  /// values each thread holds (VPT), how many threads share one row, and how many rows
  /// one warp covers.
  template <int EXPERTS, int BYTES_PER_LDG>
  struct TopkConstants
  {
      // Floats moved by one vectorized load.
      static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);

      // A row is either narrower than one warp-wide load, or an exact multiple of it.
      static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");

      // Vector loads per thread: one, or more once a row spans several warp-wide loads.
      static constexpr int VECs_PER_THREAD =
          (EXPERTS / (ELTS_PER_LDG * WARP_SIZE) > 1) ? (EXPERTS / (ELTS_PER_LDG * WARP_SIZE)) : 1;

      // Values held by each thread.
      static constexpr int VPT = ELTS_PER_LDG * VECs_PER_THREAD;

      // Threads cooperating on a single row, and rows packed into one warp.
      static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
      static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
  };
  } // namespace detail
  /// Enqueues topkGatingSoftmax specialized for EXPERTS experts and WARPS_PER_TB warps per block.
  ///
  /// Picks the widest load (at most 16 bytes) that fits in one row, derives the per-thread
  /// partitioning from detail::TopkConstants, and sizes the grid so every row is covered.
  /// The kernel is launched on `stream`; this function neither synchronizes nor checks errors.
  /// An empty batch (num_rows <= 0) launches nothing.
  template <int EXPERTS, int WARPS_PER_TB>
  void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
      int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
  {
      // A zero-block grid is an invalid launch configuration, so skip empty batches.
      if (num_rows <= 0)
      {
          return;
      }

      static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
      static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
      using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
      static constexpr int VPT = Constants::VPT;
      static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;

      // Ceil-divide rows into warps, then warps into blocks.
      const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
      const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;

      // threadIdx.x indexes lanes within a warp; threadIdx.y indexes warps within the block.
      dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
      topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
          input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
  }
  // Launches the power-of-two fast path for NUM_EXPERTS experts. Used inside
  // topkGatingSoftmaxKernelLauncher and relies on its locals (gating_output, topk_weights,
  // topk_indicies, token_expert_indices, num_tokens, topk, num_experts, stream). The
  // expansion has no trailing semicolon so that `LAUNCH_SOFTMAX(...);` is exactly one statement.
  #define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB)                   \
      topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>(     \
          gating_output, nullptr, topk_weights, topk_indicies,        \
          token_expert_indices, num_tokens, topk, 0, num_experts,     \
          stream)
  /// Computes a softmax over each token's gating logits and writes its top-k experts.
  ///
  /// Expert counts 1, 2, 4, ..., 256 use the fused warp-level kernel. Any other count falls
  /// back to moeSoftmax (into softmax_workspace) followed by moeTopK.
  ///
  /// @param gating_output        num_experts logits per token.
  /// @param topk_weights         topk selected probabilities per token (output).
  /// @param topk_indicies        topk selected expert ids per token (output).
  /// @param token_expert_indices topk entries per token; entry (t, j) is j * num_tokens + t.
  /// @param softmax_workspace    scratch of num_tokens * num_experts floats; required only on
  ///                             the fallback path.
  /// @param num_tokens           number of tokens (rows); nothing is launched when <= 0.
  /// @param num_experts          number of experts (columns).
  /// @param topk                 number of experts selected per token.
  /// @param stream               stream the kernels are enqueued on.
  void topkGatingSoftmaxKernelLauncher(
      const float* gating_output,
      float* topk_weights,
      int* topk_indicies,
      int* token_expert_indices,
      float* softmax_workspace,
      const int num_tokens,
      const int num_experts,
      const int topk,
      cudaStream_t stream) {
      // Nothing to compute. Returning early also avoids launching empty (invalid) grids and a
      // spurious workspace error: an empty workspace tensor can hand us a null pointer.
      if (num_tokens <= 0) {
          return;
      }

      static constexpr int WARPS_PER_TB = 4;
      switch (num_experts) {
          case 1:
              LAUNCH_SOFTMAX(1, WARPS_PER_TB);
              break;
          case 2:
              LAUNCH_SOFTMAX(2, WARPS_PER_TB);
              break;
          case 4:
              LAUNCH_SOFTMAX(4, WARPS_PER_TB);
              break;
          case 8:
              LAUNCH_SOFTMAX(8, WARPS_PER_TB);
              break;
          case 16:
              LAUNCH_SOFTMAX(16, WARPS_PER_TB);
              break;
          case 32:
              LAUNCH_SOFTMAX(32, WARPS_PER_TB);
              break;
          case 64:
              LAUNCH_SOFTMAX(64, WARPS_PER_TB);
              break;
          case 128:
              LAUNCH_SOFTMAX(128, WARPS_PER_TB);
              break;
          case 256:
              LAUNCH_SOFTMAX(256, WARPS_PER_TB);
              break;
          default: {
              // Generic path: materialize the full softmax, then select top-k per row.
              TORCH_CHECK(softmax_workspace != nullptr,
                  "softmax_workspace must be provided for num_experts that are not a power of 2.");
              static constexpr int TPB = 256;
              moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>(
                  gating_output, nullptr, softmax_workspace, num_experts);
              moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
                  softmax_workspace, nullptr, topk_weights, topk_indicies, token_expert_indices,
                  num_experts, topk, 0, num_experts);
          }
      }
  }
  403. } // namespace moe
  404. } // namespace aphrodite
  /// PyTorch entry point: fills topk_weights / topk_indices / token_expert_indices from the
  /// softmax of gating_output over its last dimension.
  ///
  /// gating_output must hold float32 data (data_ptr<float>() throws otherwise). The kernels
  /// index raw pointers, so the tensors are presumably expected to be contiguous — TODO
  /// confirm with callers. Work is enqueued on the current CUDA stream of gating_output's device.
  void topk_softmax(
      torch::Tensor& topk_weights,          // [num_tokens, topk]
      torch::Tensor& topk_indices,          // [num_tokens, topk]
      torch::Tensor& token_expert_indices,  // [num_tokens, topk]
      torch::Tensor& gating_output)         // [num_tokens, num_experts]
  {
      const int num_experts = gating_output.size(-1);
      // Guards the division below against a zero-width last dimension.
      TORCH_CHECK(num_experts > 0, "gating_output must have a non-empty last dimension (num_experts).");
      const int num_tokens = gating_output.numel() / num_experts;
      const int topk = topk_weights.size(-1);

      // Empty batch: nothing to launch.
      if (num_tokens == 0) {
          return;
      }

      const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
      const bool needs_workspace = !is_pow_2 || num_experts > 256;
      // Widen before multiplying: num_tokens * num_experts can overflow int.
      const int64_t workspace_size = needs_workspace ? static_cast<int64_t>(num_tokens) * num_experts : 0;

      const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
      const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
      torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
      aphrodite::moe::topkGatingSoftmaxKernelLauncher(
          gating_output.data_ptr<float>(),
          topk_weights.data_ptr<float>(),
          topk_indices.data_ptr<int>(),
          token_expert_indices.data_ptr<int>(),
          softmax_workspace.data_ptr<float>(),
          num_tokens,
          num_experts,
          topk,
          stream);
  }