align_block_size_kernel.cu 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. #include <torch/extension.h>
  2. #include <ATen/cuda/CUDAContext.h>
  3. #include <ATen/ATen.h>
  4. #include <THC/THCAtomics.cuh>
  5. #include "../cuda_compat.h"
  6. #include "../dispatch_utils.h"
  #define CEILDIV(x,y) (((x) + (y) - 1) / (y))

  namespace aphrodite {
  namespace {
  // Row-major offset of element (row, col) in a matrix with `total_col` columns.
  // int32 arithmetic is sufficient: the only matrix indexed this way is
  // (num_experts + 1) x num_experts.
  __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
                                           int32_t col) {
    return row * total_col + col;
  }
  }  // namespace

  /**
   * Groups the flattened top-k expert assignments by expert and pads each
   * expert's group up to a multiple of `block_size`.
   *
   * Launch contract (matches the launch in moe_align_block_size):
   *  - exactly one block, with blockDim.x == num_experts: thread t owns expert
   *    column t of the counter matrix and one contiguous token shard;
   *  - dynamic shared memory of
   *    ((num_experts + 1) * num_experts + (num_experts + 1)) int32 values;
   *  - every topk_ids value must be in [0, num_experts); it is used unchecked as
   *    a shared-memory index.
   *
   * Outputs:
   *  - sorted_token_ids[p]: flat index into topk_ids of the entry placed at padded
   *    position p. Padding positions are never written; the caller presets them.
   *  - expert_ids[b]: expert that owns the b-th block of `block_size` positions.
   *  - *total_tokens_post_pad: total padded length over all experts.
   */
  template <typename scalar_t>
  __global__ void moe_align_block_size_kernel(const scalar_t *__restrict__ topk_ids,
                                              int32_t *sorted_token_ids,
                                              int32_t *expert_ids,
                                              int32_t *total_tokens_post_pad,
                                              int32_t num_experts,
                                              int32_t block_size,
                                              size_t numel) {
    // This thread's token shard is [shard_begin, shard_end). size_t indices avoid
    // the signed/unsigned mixing of an `int` counter compared against `numel`.
    const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
    const size_t shard_begin = threadIdx.x * tokens_per_thread;
    const size_t shard_limit = shard_begin + tokens_per_thread;
    const size_t shard_end = shard_limit < numel ? shard_limit : numel;

    extern __shared__ int32_t shared_mem[];
    // tokens_cnts: row-major (num_experts + 1) x num_experts counter matrix.
    int32_t *tokens_cnts = shared_mem;
    // cumsum: num_experts + 1 entries stored directly after tokens_cnts.
    int32_t *cumsum = shared_mem + (num_experts + 1) * num_experts;

    // Step 0: clear this thread's counter row (row threadIdx.x + 1).
    for (int32_t e = 0; e < num_experts; ++e) {
      tokens_cnts[index(num_experts, threadIdx.x + 1, e)] = 0;
    }

    // Step 1: tokens_cnts[threadIdx.x + 1][e] = number of entries in this thread's
    // shard that are assigned to expert e.
    for (size_t i = shard_begin; i < shard_end; ++i) {
      ++tokens_cnts[index(num_experts, threadIdx.x + 1,
                          static_cast<int32_t>(topk_ids[i]))];
    }
    __syncthreads();

    // Step 2: thread t turns expert column t into an exclusive prefix sum over
    // shards. Afterwards tokens_cnts[s][t] counts expert-t entries in shards
    // 0..s-1, and row blockDim.x holds the total per expert.
    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
    for (unsigned int s = 1; s <= blockDim.x; ++s) {
      tokens_cnts[index(num_experts, s, threadIdx.x)] +=
          tokens_cnts[index(num_experts, s - 1, threadIdx.x)];
    }
    __syncthreads();

    // Step 3: thread 0 rounds each expert's total up to a multiple of block_size
    // and scans them, so cumsum[e] is the first padded position of expert e.
    if (threadIdx.x == 0) {
      cumsum[0] = 0;
      for (int32_t e = 1; e <= num_experts; ++e) {
        const int32_t expert_total =
            tokens_cnts[index(num_experts, blockDim.x, e - 1)];
        cumsum[e] = cumsum[e - 1] + CEILDIV(expert_total, block_size) * block_size;
      }
      *total_tokens_post_pad = cumsum[num_experts];
    }
    __syncthreads();

    // Step 4: thread t labels every block_size-wide block inside expert t's
    // padded range with its expert id.
    for (int32_t p = cumsum[threadIdx.x]; p < cumsum[threadIdx.x + 1];
         p += block_size) {
      expert_ids[p / block_size] = threadIdx.x;
    }

    // Step 5: scatter each entry of this shard to its padded, expert-sorted slot.
    // Example: topk_ids = [0,1,2,1,2,3,0,3,4] with block_size = 4 yields
    // sorted_token_ids = [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
    // where * is the padding value preset by the caller (never written here).
    for (size_t i = shard_begin; i < shard_end; ++i) {
      const int32_t expert_id = static_cast<int32_t>(topk_ids[i]);
      // cumsum[expert_id] is where expert_id's range starts. Row threadIdx.x of
      // tokens_cnts is this thread's running offset inside that range: it starts
      // at the count contributed by earlier shards and grows by one per placement.
      // Only this thread touches row threadIdx.x here, so no further sync is needed.
      int32_t &offset = tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
      sorted_token_ids[cumsum[expert_id] + offset] = static_cast<int32_t>(i);
      ++offset;
    }
  }
  }  // namespace aphrodite
  /**
   * Host entry point: validates arguments and launches
   * aphrodite::moe_align_block_size_kernel on the current CUDA stream.
   *
   * @param topk_ids            integer CUDA tensor of expert ids. It is read as a
   *                            flat array of numel() entries and made contiguous
   *                            first if needed.
   * @param num_experts         number of experts. It is also the launch block size,
   *                            so it must lie in [1, maxThreadsPerBlock].
   * @param block_size          padding granularity of each expert's group (> 0).
   * @param sorted_token_ids    contiguous int32 output. Padding slots are left
   *                            untouched, so the caller is expected to preset them.
   * @param experts_ids         contiguous int32 output: expert id of each
   *                            block_size-wide block.
   * @param num_tokens_post_pad int32 output; its first element receives the total
   *                            padded length.
   * Throws (via TORCH_CHECK / AT_CUDA_CHECK) on invalid arguments or CUDA errors.
   */
  void moe_align_block_size(
      torch::Tensor topk_ids,
      int num_experts,
      int block_size,
      torch::Tensor sorted_token_ids,
      torch::Tensor experts_ids,
      torch::Tensor num_tokens_post_pad) {
    TORCH_CHECK(topk_ids.is_cuda(),
                "moe_align_block_size: topk_ids must be a CUDA tensor");
    // The kernel runs as a single block with one thread per expert, so an invalid
    // expert count would otherwise surface only as a silently failed launch.
    const int max_threads =
        at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
    TORCH_CHECK(num_experts > 0 && num_experts <= max_threads,
                "moe_align_block_size: num_experts (", num_experts,
                ") must be in [1, ", max_threads, "]");
    // block_size is a divisor inside the kernel.
    TORCH_CHECK(block_size > 0,
                "moe_align_block_size: block_size must be positive, got ",
                block_size);
    // Outputs are written through flat indices.
    TORCH_CHECK(sorted_token_ids.is_contiguous() && experts_ids.is_contiguous(),
                "moe_align_block_size: output tensors must be contiguous");
    // The kernel walks topk_ids as a flat buffer; this is a no-op when the tensor
    // is already contiguous.
    topk_ids = topk_ids.contiguous();

    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    // Dynamic shared memory: tokens_cnts ((num_experts + 1) x num_experts) plus
    // cumsum (num_experts + 1), all int32. This fits in int32 given the bound
    // on num_experts checked above.
    const int32_t shared_mem =
        ((num_experts + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);

    APHRODITE_DISPATCH_INTEGRAL_TYPES(
        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
          auto kernel = aphrodite::moe_align_block_size_kernel<scalar_t>;
          // Raise the kernel's dynamic shared-memory limit to what this launch needs.
          AT_CUDA_CHECK(APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
              (void *)kernel, shared_mem));
          kernel<<<1, num_experts, shared_mem, stream>>>(
              topk_ids.data_ptr<scalar_t>(),
              sorted_token_ids.data_ptr<int32_t>(),
              experts_ids.data_ptr<int32_t>(),
              num_tokens_post_pad.data_ptr<int32_t>(),
              num_experts,
              block_size,
              topk_ids.numel());
          // Kernel launches do not return errors; surface bad launch configs here.
          AT_CUDA_CHECK(cudaGetLastError());
        });
  }