// pos_encoding_kernels.cu

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <algorithm>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace aphrodite {

template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding(
    scalar_t* __restrict__ arr,
    const scalar_t* __restrict__ cos_ptr,
    const scalar_t* __restrict__ sin_ptr,
    int rot_offset,
    int embed_dim) {
  int x_index, y_index;
  scalar_t cos, sin;
  if (IS_NEOX) {
    // GPT-NeoX style rotary embedding: the rotated pair is split across the
    // two halves of the rotary dimensions, i.e. (i, i + embed_dim).
    x_index = rot_offset;
    y_index = embed_dim + rot_offset;
    cos = APHRODITE_LDG(cos_ptr + x_index);
    sin = APHRODITE_LDG(sin_ptr + x_index);
  } else {
    // GPT-J style rotary embedding: the rotated pair is adjacent,
    // i.e. (2i, 2i + 1), and both elements share the same cos/sin entry.
    x_index = 2 * rot_offset;
    y_index = 2 * rot_offset + 1;
    cos = APHRODITE_LDG(cos_ptr + x_index / 2);
    sin = APHRODITE_LDG(sin_ptr + x_index / 2);
  }

  // Standard 2D rotation of the (x, y) pair by the cached angle.
  const scalar_t x = arr[x_index];
  const scalar_t y = arr[y_index];
  arr[x_index] = x * cos - y * sin;
  arr[y_index] = y * cos + x * sin;
}
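
// Worked example (illustrative only): with embed_dim = 4 and rot_offset = 1,
// the NeoX path rotates the split pair (arr[1], arr[5]) using cos_ptr[1] and
// sin_ptr[1], while the GPT-J path rotates the adjacent pair (arr[2], arr[3])
// using the same cos_ptr[1]/sin_ptr[1] entry. Both paths apply the rotation
//   x' = x * cos(theta) - y * sin(theta)
//   y' = y * cos(theta) + x * sin(theta)
// which corresponds to the rotate-half (NeoX) and interleaved (GPT-J) RoPE
// conventions.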

template <typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or [num_tokens]
    scalar_t* __restrict__ query,           // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
    scalar_t* __restrict__ key,             // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim // 2]
    const int rot_dim,
    const int64_t query_stride,
    const int64_t key_stride,
    const int num_heads,
    const int num_kv_heads,
    const int head_size) {
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;
  int64_t pos = positions[token_idx];
  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

  // The cache row for this position stores embed_dim cos values followed by
  // embed_dim sin values.
  const int embed_dim = rot_dim / 2;
  const scalar_t* cos_ptr = cache_ptr;
  const scalar_t* sin_ptr = cache_ptr + embed_dim;

  // Each thread rotates one (head, offset) pair; the strided loop covers all
  // num_heads * embed_dim pairs when the block has fewer threads than pairs.
  const int nq = num_heads * embed_dim;
  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int64_t token_head = token_idx * query_stride + head_idx * head_size;
    const int rot_offset = i % embed_dim;
    apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
                                              sin_ptr, rot_offset, embed_dim);
  }

  const int nk = num_kv_heads * embed_dim;
  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int64_t token_head = token_idx * key_stride + head_idx * head_size;
    const int rot_offset = i % embed_dim;
    apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
                                              sin_ptr, rot_offset, embed_dim);
  }
}

}  // namespace aphrodite
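
// Host-side entry point. Launches one thread block per token and dispatches
// on the tensors' floating-point dtype and on the NeoX/GPT-J layout flag.
// query and key are rotated in place.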
void rotary_embedding(
    torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
    torch::Tensor& query,      // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
    torch::Tensor& key,        // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
    int head_size,
    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
    bool is_neox) {
  int64_t num_tokens = query.numel() / query.size(-1);
  int rot_dim = cos_sin_cache.size(1);
  int num_heads = query.size(-1) / head_size;
  int num_kv_heads = key.size(-1) / head_size;
  int64_t query_stride = query.stride(-2);
  int64_t key_stride = key.stride(-2);

  // One block per token; at most 512 threads per block, with the kernel's
  // strided loops covering any remaining (head, offset) pairs.
  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * rot_dim / 2, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
      query.scalar_type(), "rotary_embedding", [&] {
        if (is_neox) {
          aphrodite::rotary_embedding_kernel<scalar_t, true>
              <<<grid, block, 0, stream>>>(
                  positions.data_ptr<int64_t>(),
                  query.data_ptr<scalar_t>(),
                  key.data_ptr<scalar_t>(),
                  cos_sin_cache.data_ptr<scalar_t>(),
                  rot_dim,
                  query_stride,
                  key_stride,
                  num_heads,
                  num_kv_heads,
                  head_size);
        } else {
          aphrodite::rotary_embedding_kernel<scalar_t, false>
              <<<grid, block, 0, stream>>>(
                  positions.data_ptr<int64_t>(),
                  query.data_ptr<scalar_t>(),
                  key.data_ptr<scalar_t>(),
                  cos_sin_cache.data_ptr<scalar_t>(),
                  rot_dim,
                  query_stride,
                  key_stride,
                  num_heads,
                  num_kv_heads,
                  head_size);
        }
      });
}
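
// A minimal binding sketch, assuming this file were built as a standalone
// PyTorch extension. In the actual repo the op is registered elsewhere
// alongside the other kernels, so the guard macro below is hypothetical and
// off by default; treat this as illustrative, not the project's binding.
#ifdef APHRODITE_STANDALONE_POS_ENCODING_EXAMPLE
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("rotary_embedding", &rotary_embedding,
        "Apply rotary positional embedding to query and key in place (CUDA)");
}
#endif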