// pos_encoding_kernels.cu
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

#include <algorithm>
#include <cstdint>

#include "dispatch_utils.h"
  4. namespace aphrodite {
  5. template<typename scalar_t, bool IS_NEOX>
  6. inline __device__ void apply_rotary_embedding(
  7. scalar_t* __restrict__ arr,
  8. const scalar_t* __restrict__ cos_ptr,
  9. const scalar_t* __restrict__ sin_ptr,
  10. int rot_offset,
  11. int embed_dim)
  12. {
  13. int x_index, y_index;
  14. scalar_t cos, sin;
  15. if (IS_NEOX) {
  16. // GPT-NeoX style rotary embedding.
  17. x_index = rot_offset;
  18. y_index = embed_dim + rot_offset;
  19. cos = __ldg(cos_ptr + x_index);
  20. sin = __ldg(sin_ptr + x_index);
  21. } else {
  22. // GPT-J style rotary embedding.
  23. x_index = 2 * rot_offset;
  24. y_index = 2 * rot_offset + 1;
  25. cos = __ldg(cos_ptr + x_index / 2);
  26. sin = __ldg(sin_ptr + x_index / 2);
  27. }
  28. const scalar_t x = arr[x_index];
  29. const scalar_t y = arr[y_index];
  30. arr[x_index] = x * cos - y * sin;
  31. arr[y_index] = y * cos + x * sin;
  32. }
  33. template<typename scalar_t, bool IS_NEOX>
  34. __global__ void rotary_embedding_kernel(
  35. const int64_t* __restrict__ positions, // [num_tokens]
  36. scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
  37. scalar_t* __restrict__ key, // [num_tokens, num_kv_heads, head_size]
  38. const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
  39. const int rot_dim,
  40. const int query_stride,
  41. const int key_stride,
  42. const int num_heads,
  43. const int num_kv_heads,
  44. const int head_size) {
  45. // Each thread block is responsible for one token.
  46. const int token_idx = blockIdx.x;
  47. int64_t pos = positions[token_idx];
  48. const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
  49. const int embed_dim = rot_dim / 2;
  50. const scalar_t* cos_ptr = cache_ptr;
  51. const scalar_t* sin_ptr = cache_ptr + embed_dim;
  52. const int nq = num_heads * embed_dim;
  53. for (int i = threadIdx.x; i < nq; i += blockDim.x) {
  54. const int head_idx = i / embed_dim;
  55. const int token_head = token_idx * query_stride + head_idx * head_size;
  56. const int rot_offset = i % embed_dim;
  57. apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
  58. sin_ptr, rot_offset, embed_dim);
  59. }
  60. const int nk = num_kv_heads * embed_dim;
  61. for (int i = threadIdx.x; i < nk; i += blockDim.x) {
  62. const int head_idx = i / embed_dim;
  63. const int token_head = token_idx * key_stride + head_idx * head_size;
  64. const int rot_offset = i % embed_dim;
  65. apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
  66. sin_ptr, rot_offset, embed_dim);
  67. }
  68. }
  69. } // namespace aphrodite
  70. void rotary_embedding(
  71. torch::Tensor& positions, // [num_tokens]
  72. torch::Tensor& query, // [num_tokens, num_heads * head_size]
  73. torch::Tensor& key, // [num_tokens, num_kv_heads * head_size]
  74. int head_size,
  75. torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
  76. bool is_neox) {
  77. int num_tokens = query.size(0);
  78. int rot_dim = cos_sin_cache.size(1);
  79. int num_heads = query.size(1) / head_size;
  80. int num_kv_heads = key.size(1) / head_size;
  81. int query_stride = query.stride(0);
  82. int key_stride = key.stride(0);
  83. dim3 grid(num_tokens);
  84. dim3 block(std::min(num_heads * rot_dim / 2, 512));
  85. const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  86. APHRODITE_DISPATCH_FLOATING_TYPES(
  87. query.scalar_type(),
  88. "rotary_embedding",
  89. [&] {
  90. if (is_neox) {
  91. aphrodite::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
  92. positions.data_ptr<int64_t>(),
  93. query.data_ptr<scalar_t>(),
  94. key.data_ptr<scalar_t>(),
  95. cos_sin_cache.data_ptr<scalar_t>(),
  96. rot_dim,
  97. query_stride,
  98. key_stride,
  99. num_heads,
  100. num_kv_heads,
  101. head_size);
  102. } else {
  103. aphrodite::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
  104. positions.data_ptr<int64_t>(),
  105. query.data_ptr<scalar_t>(),
  106. key.data_ptr<scalar_t>(),
  107. cos_sin_cache.data_ptr<scalar_t>(),
  108. rot_dim,
  109. query_stride,
  110. key_stride,
  111. num_heads,
  112. num_kv_heads,
  113. head_size);
  114. }
  115. });
  116. }