align_block_size_kernel.cu 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. #include <torch/all.h>
  2. #include <ATen/cuda/CUDAContext.h>
  3. #include <ATen/ATen.h>
  4. #include <THC/THCAtomics.cuh>
  5. #include "../cuda_compat.h"
  6. #include "../dispatch_utils.h"
  7. #define CEILDIV(x, y) (((x) + (y) - 1) / (y))
  8. namespace aphrodite {
  9. namespace {
  10. __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
  11. int32_t col) {
  12. // don't worry about overflow because num_experts is relatively small
  13. return row * total_col + col;
  14. }
  15. } // namespace
  16. template <typename scalar_t>
  17. __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
  18. int32_t* sorted_token_ids,
  19. int32_t* expert_ids,
  20. int32_t* total_tokens_post_pad,
  21. int32_t num_experts,
  22. int32_t block_size, size_t numel) {
  23. const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
  24. const size_t start_idx = threadIdx.x * tokens_per_thread;
  25. extern __shared__ int32_t shared_mem[];
  26. int32_t* tokens_cnts =
  27. shared_mem; // 2d tensor with shape (num_experts + 1, num_experts)
  28. int32_t* cumsum =
  29. shared_mem + (num_experts + 1) *
  30. num_experts; // 1d tensor with shape (num_experts + 1)
  31. for (int i = 0; i < num_experts; ++i) {
  32. tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
  33. }
  34. /**
  35. * In the first step we compute token_cnts[thread_index + 1][expert_index],
  36. * which counts how many tokens in the token shard of thread_index are
  37. * assigned to expert expert_index.
  38. */
  39. for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
  40. ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
  41. }
  42. __syncthreads();
  43. // For each expert we accumulate the token counts from the different threads.
  44. tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
  45. for (int i = 1; i <= blockDim.x; ++i) {
  46. tokens_cnts[index(num_experts, i, threadIdx.x)] +=
  47. tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
  48. }
  49. __syncthreads();
  50. // We accumulate the token counts of all experts in thread 0.
  51. if (threadIdx.x == 0) {
  52. cumsum[0] = 0;
  53. for (int i = 1; i <= num_experts; ++i) {
  54. cumsum[i] = cumsum[i - 1] +
  55. CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
  56. block_size) *
  57. block_size;
  58. }
  59. *total_tokens_post_pad = cumsum[num_experts];
  60. }
  61. __syncthreads();
  62. /**
  63. * For each expert, each thread processes the tokens of the corresponding
  64. * blocks and stores the corresponding expert_id for each block.
  65. */
  66. for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
  67. i += block_size) {
  68. expert_ids[i / block_size] = threadIdx.x;
  69. }
  70. /**
  71. * Each thread processes a token shard, calculating the index of each token
  72. * after sorting by expert number. Given the example topk_ids =
  73. * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
  74. * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
  75. * padding value(preset in python).
  76. */
  77. for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
  78. int32_t expert_id = topk_ids[i];
  79. /** The cumsum[expert_id] stores the starting index of the tokens that the
  80. * expert with expert_id needs to process, and
  81. * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
  82. * processed by the expert with expert_id within the current thread's token
  83. * shard.
  84. */
  85. int32_t rank_post_pad =
  86. tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
  87. cumsum[expert_id];
  88. sorted_token_ids[rank_post_pad] = i;
  89. ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
  90. }
  91. }
  92. // TODO: this is temporarily adapted from
  93. // https://github.com/sgl-project/sglang/commit/31548116a8dc8c6df7e146e0587335a59fc5b9d7
  94. // we did this to unblock Deepseek V3 but there should be a better
  95. // implementation to manage shared memory.
  96. template <typename scalar_t>
  97. __global__ void moe_align_block_size_global_mem_kernel(
  98. scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
  99. int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
  100. int32_t block_size, size_t numel, int32_t* tokens_cnts, int32_t* cumsum) {
  101. const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
  102. const size_t start_idx = threadIdx.x * tokens_per_thread;
  103. for (int i = 0; i < num_experts; ++i) {
  104. tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
  105. }
  106. /**
  107. * In the first step we compute token_cnts[thread_index + 1][expert_index],
  108. * which counts how many tokens in the token shard of thread_index are
  109. * assigned to expert expert_index.
  110. */
  111. for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
  112. ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
  113. }
  114. __syncthreads();
  115. // For each expert we accumulate the token counts from the different threads.
  116. if (threadIdx.x < num_experts) {
  117. tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
  118. for (int i = 1; i <= blockDim.x; ++i) {
  119. tokens_cnts[index(num_experts, i, threadIdx.x)] +=
  120. tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
  121. }
  122. }
  123. __syncthreads();
  124. // We accumulate the token counts of all experts in thread 0.
  125. if (threadIdx.x == 0) {
  126. cumsum[0] = 0;
  127. for (int i = 1; i <= num_experts; ++i) {
  128. cumsum[i] = cumsum[i - 1] +
  129. CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
  130. block_size) *
  131. block_size;
  132. }
  133. *total_tokens_post_pad = cumsum[num_experts];
  134. }
  135. __syncthreads();
  136. /**
  137. * For each expert, each thread processes the tokens of the corresponding
  138. * blocks and stores the corresponding expert_id for each block.
  139. */
  140. if (threadIdx.x < num_experts) {
  141. for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
  142. i += block_size) {
  143. expert_ids[i / block_size] = threadIdx.x;
  144. }
  145. }
  146. /**
  147. * Each thread processes a token shard, calculating the index of each token
  148. * after sorting by expert number. Given the example topk_ids =
  149. * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
  150. * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
  151. * padding value(preset in python).
  152. */
  153. for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
  154. int32_t expert_id = topk_ids[i];
  155. /** The cumsum[expert_id] stores the starting index of the tokens that the
  156. * expert with expert_id needs to process, and
  157. * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
  158. * processed by the expert with expert_id within the current thread's token
  159. * shard.
  160. */
  161. int32_t rank_post_pad =
  162. tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
  163. cumsum[expert_id];
  164. sorted_token_ids[rank_post_pad] = i;
  165. ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
  166. }
  167. }
  168. } // namespace aphrodite
  169. void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
  170. int64_t block_size, torch::Tensor sorted_token_ids,
  171. torch::Tensor experts_ids,
  172. torch::Tensor num_tokens_post_pad) {
  173. const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  174. // If we have very large number of experts, we can no longer use shared
  175. // memory.
  176. // TODO: the right solution should be calculating the exact right
  177. // amount of shared memory and use that. The num_experts >= 256 is just a
  178. // temporary solution to unblock Deepseek V3.
  179. if (num_experts >= 256) {
  180. APHRODITE_DISPATCH_INTEGRAL_TYPES(
  181. topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
  182. // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
  183. // tensors
  184. const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
  185. const int32_t mem_tokens_cnts =
  186. ((num_experts + 1) * num_experts) * sizeof(int32_t);
  187. const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t);
  188. // allocate global memory
  189. int32_t* tokens_cnts;
  190. int32_t* cumsum;
  191. cudaMalloc(&tokens_cnts, mem_tokens_cnts);
  192. cudaMalloc(&cumsum, mem_cumsum);
  193. auto kernel =
  194. aphrodite::moe::moe_align_block_size_global_mem_kernel<scalar_t>;
  195. kernel<<<1, num_thread, 0, stream>>>(
  196. topk_ids.data_ptr<scalar_t>(),
  197. sorted_token_ids.data_ptr<int32_t>(),
  198. experts_ids.data_ptr<int32_t>(),
  199. num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
  200. topk_ids.numel(), tokens_cnts, cumsum);
  201. cudaFree(tokens_cnts);
  202. cudaFree(cumsum);
  203. });
  204. } else {
  205. APHRODITE_DISPATCH_INTEGRAL_TYPES(
  206. topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
  207. // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
  208. // tensors
  209. const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
  210. const int32_t shared_mem =
  211. ((num_thread + 1) * num_experts + (num_experts + 1)) *
  212. sizeof(int32_t);
  213. // set dynamic shared mem
  214. auto kernel = aphrodite::moe::moe_align_block_size_kernel<scalar_t>;
  215. AT_CUDA_CHECK(
  216. APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
  217. (void*)kernel, shared_mem));
  218. kernel<<<1, num_thread, shared_mem, stream>>>(
  219. topk_ids.data_ptr<scalar_t>(),
  220. sorted_token_ids.data_ptr<int32_t>(),
  221. experts_ids.data_ptr<int32_t>(),
  222. num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
  223. topk_ids.numel());
  224. });
  225. }
  226. }