// align_block_size_kernel.cu
  1. #include <torch/extension.h>
  2. #include <ATen/cuda/CUDAContext.h>
  3. #include <ATen/ATen.h>
  4. #include <THC/THCAtomics.cuh>
  5. #include "../cuda_compat.h"
  6. #include "../dispatch_utils.h"
  // Upper bound on the number of experts; it sizes the statically allocated
  // shared-memory tables used by the alignment kernel.
  static constexpr size_t NUM_MAX_EXPERTS = 64;

  // Integer division of `a` by `b`, rounded up (intended for non-negative `a`
  // and positive `b`).
  #define CEILDIV(a, b) (((a) + (b) - 1) / (b))
  namespace aphrodite {

  /**
   * Groups token indices by the expert they are routed to and pads every
   * expert's group to a multiple of block_size.
   *
   * Launch layout: written for exactly one thread block whose thread count
   * equals num_experts (blockDim.x == num_experts <= NUM_MAX_EXPERTS). Thread t
   * both owns the t-th contiguous shard of topk_ids and acts as the worker for
   * expert t. Uses only static shared memory:
   * (NUM_MAX_EXPERTS + 1) * NUM_MAX_EXPERTS + (NUM_MAX_EXPERTS + 1) int32 values.
   *
   * @param topk_ids              Read-only flat array of numel expert ids.
   * @param sorted_token_ids      Output; receives token indices grouped by
   *                              expert. Padding slots are not written here, so
   *                              the caller is expected to pre-fill the buffer.
   * @param expert_ids            Output; expert id of each block_size-sized block.
   * @param total_tokens_post_pad Output; single int32 with the padded total.
   * @param num_experts           Number of experts.
   * @param block_size            Alignment granularity; must be positive.
   * @param numel                 Number of entries in topk_ids.
   *
   * Entries of topk_ids outside [0, num_experts) are skipped instead of being
   * used to index shared memory.
   */
  template <typename scalar_t>
  __global__ void moe_align_block_size_kernel(const scalar_t *__restrict__ topk_ids,
                                              int32_t *sorted_token_ids,
                                              int32_t *expert_ids,
                                              int32_t *total_tokens_post_pad,
                                              int32_t num_experts,
                                              int32_t block_size,
                                              size_t numel) {
    // Contiguous shard [start_idx, end_idx) of topk_ids owned by this thread.
    // Kept in size_t so large inputs are neither truncated nor sign-compared.
    const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
    const size_t start_idx = threadIdx.x * tokens_per_thread;
    const size_t shard_end = start_idx + tokens_per_thread;
    const size_t end_idx = shard_end < numel ? shard_end : numel;

    __shared__ int32_t tokens_cnts[NUM_MAX_EXPERTS + 1][NUM_MAX_EXPERTS];
    __shared__ int32_t cumsum[NUM_MAX_EXPERTS + 1];

    // Shared memory is uninitialized: clear this thread's row of counters.
    for (int32_t e = 0; e < num_experts; ++e) {
      tokens_cnts[threadIdx.x + 1][e] = 0;
    }

    /**
     * Step 1: tokens_cnts[thread_index + 1][expert_index] counts how many tokens
     * in the shard of thread_index are assigned to expert expert_index.
     */
    for (size_t i = start_idx; i < end_idx; ++i) {
      const int64_t expert_id = static_cast<int64_t>(topk_ids[i]);
      // An out-of-range id would write outside the shared table; skip it.
      if (expert_id < 0 || expert_id >= num_experts) {
        continue;
      }
      ++tokens_cnts[threadIdx.x + 1][expert_id];
    }

    __syncthreads();

    // Step 2: thread e turns column e into a running sum over the thread rows,
    // so row t ends up holding the number of tokens for that expert found in
    // the shards of threads 0..t-1, and row blockDim.x holds the total.
    tokens_cnts[0][threadIdx.x] = 0;
    for (unsigned int t = 1; t <= blockDim.x; ++t) {
      tokens_cnts[t][threadIdx.x] += tokens_cnts[t - 1][threadIdx.x];
    }

    __syncthreads();

    // Step 3: thread 0 builds the padded start offset of every expert; each
    // expert's total is rounded up to a multiple of block_size.
    if (threadIdx.x == 0) {
      cumsum[0] = 0;
      for (int32_t e = 1; e <= num_experts; ++e) {
        cumsum[e] = cumsum[e - 1] +
                    CEILDIV(tokens_cnts[blockDim.x][e - 1], block_size) * block_size;
      }
      *total_tokens_post_pad = cumsum[num_experts];
    }

    __syncthreads();

    /**
     * Step 4: thread e labels every block_size-sized block inside expert e's
     * padded range with the expert id e.
     */
    for (int32_t pos = cumsum[threadIdx.x]; pos < cumsum[threadIdx.x + 1];
         pos += block_size) {
      expert_ids[pos / block_size] = threadIdx.x;
    }

    /**
     * Step 5: each thread walks its shard again and writes every token index to
     * its position after sorting by expert. For topk_ids = [0,1,2,1,2,3,0,3,4]
     * and block_size = 4 the output is
     * [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
     * where * is a padding slot this kernel leaves untouched.
     */
    for (size_t i = start_idx; i < end_idx; ++i) {
      const int64_t expert_id = static_cast<int64_t>(topk_ids[i]);
      // Must mirror the filter used while counting in step 1.
      if (expert_id < 0 || expert_id >= num_experts) {
        continue;
      }
      /**
       * cumsum[expert_id] is the first output slot of this expert;
       * tokens_cnts[threadIdx.x][expert_id] is the number of that expert's
       * tokens already placed by lower-numbered threads plus the ones this
       * thread has placed so far.
       */
      const int32_t rank_post_pad =
          tokens_cnts[threadIdx.x][expert_id] + cumsum[expert_id];
      sorted_token_ids[rank_post_pad] = static_cast<int32_t>(i);
      ++tokens_cnts[threadIdx.x][expert_id];
    }
  }

  }  // namespace aphrodite
  /**
   * Host launcher for aphrodite::moe_align_block_size_kernel.
   *
   * Launches one block of num_experts threads on the current CUDA stream,
   * dispatching on the integral dtype of topk_ids. The three output tensors are
   * accessed as int32.
   *
   * @param topk_ids            Expert id per token slot (integral dtype).
   * @param num_experts         Number of experts; must be in
   *                            [1, NUM_MAX_EXPERTS] because the kernel uses
   *                            fixed-size shared-memory tables.
   * @param block_size          Alignment granularity; must be positive.
   * @param sorted_token_ids    Output int32 tensor of token indices by expert.
   * @param experts_ids         Output int32 tensor of per-block expert ids.
   * @param num_tokens_post_pad Output int32 tensor receiving the padded total.
   *
   * Raises a c10::Error (via TORCH_CHECK) on invalid arguments or a failed
   * kernel launch.
   */
  void moe_align_block_size(
      torch::Tensor topk_ids,
      int num_experts,
      int block_size,
      torch::Tensor sorted_token_ids,
      torch::Tensor experts_ids,
      torch::Tensor num_tokens_post_pad) {
    // TORCH_CHECK rather than assert(): assert() expands to nothing when NDEBUG
    // is defined, which would let an oversized num_experts overflow the
    // kernel's shared-memory arrays.
    TORCH_CHECK(num_experts > 0 &&
                    static_cast<size_t>(num_experts) <= NUM_MAX_EXPERTS,
                "moe_align_block_size: num_experts must be in [1, ",
                NUM_MAX_EXPERTS, "], got ", num_experts);
    // The kernel divides by block_size and uses it as a loop stride.
    TORCH_CHECK(block_size > 0,
                "moe_align_block_size: block_size must be positive, got ",
                block_size);

    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    APHRODITE_DISPATCH_INTEGRAL_TYPES(
        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
          // One block; thread t handles both token shard t and expert t.
          aphrodite::moe_align_block_size_kernel<scalar_t>
              <<<1, num_experts, 0, stream>>>(
                  topk_ids.data_ptr<scalar_t>(),
                  sorted_token_ids.data_ptr<int32_t>(),
                  experts_ids.data_ptr<int32_t>(),
                  num_tokens_post_pad.data_ptr<int32_t>(),
                  num_experts,
                  block_size,
                  topk_ids.numel());
          // Kernel launches do not return errors; surface bad launch configs.
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
  }