// pos_encoding_kernels.cu

#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace aphrodite {

template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding(
    scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr,
    const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) {
  int x_index, y_index;
  scalar_t cos, sin;
  if (IS_NEOX) {
    // GPT-NeoX style rotary embedding.
    x_index = rot_offset;
    y_index = embed_dim + rot_offset;
    cos = APHRODITE_LDG(cos_ptr + x_index);
    sin = APHRODITE_LDG(sin_ptr + x_index);
  } else {
    // GPT-J style rotary embedding.
    x_index = 2 * rot_offset;
    y_index = 2 * rot_offset + 1;
    cos = APHRODITE_LDG(cos_ptr + x_index / 2);
    sin = APHRODITE_LDG(sin_ptr + x_index / 2);
  }
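  // Rotate the selected pair in place. For GPT-NeoX the pair is element i of
  // each half of the rotary slice, (i, i + embed_dim); for GPT-J it is the
  // interleaved pair (2i, 2i + 1). Both read the same cached cos/sin entry i.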
  const scalar_t x = arr[x_index];
  const scalar_t y = arr[y_index];
  arr[x_index] = x * cos - y * sin;
  arr[y_index] = y * cos + x * sin;
}

template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding(
    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                   // head_size] or [num_tokens, num_heads,
                                   // head_size]
    scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
                                   // head_size] or [num_tokens, num_kv_heads,
                                   // head_size]
    const scalar_t* cache_ptr, const int head_size, const int num_heads,
    const int num_kv_heads, const int rot_dim, const int token_idx,
    const int64_t query_stride, const int64_t key_stride) {
  const int embed_dim = rot_dim / 2;
  const scalar_t* cos_ptr = cache_ptr;
  const scalar_t* sin_ptr = cache_ptr + embed_dim;
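  // cache_ptr addresses one position's row of the cache: rot_dim contiguous
  // values, of which the first embed_dim are cosines and the remaining
  // embed_dim are sines.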
  const int nq = num_heads * embed_dim;
  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int64_t token_head = token_idx * query_stride + head_idx * head_size;
    const int rot_offset = i % embed_dim;
    apply_token_rotary_embedding<scalar_t, IS_NEOX>(
        query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
  }

  const int nk = num_kv_heads * embed_dim;
  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int64_t token_head = token_idx * key_stride + head_idx * head_size;
    const int rot_offset = i % embed_dim;
    apply_token_rotary_embedding<scalar_t, IS_NEOX>(
        key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
  }
}

template <typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
                                            // [num_tokens]
    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                   // head_size] or [num_tokens, num_heads,
                                   // head_size]
    scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
                                   // head_size] or [num_tokens, num_kv_heads,
                                   // head_size]
    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
                                                 // 2]
    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
    const int num_heads, const int num_kv_heads, const int head_size) {
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;
  int64_t pos = positions[token_idx];
  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

  apply_rotary_embedding<scalar_t, IS_NEOX>(
      query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
      token_idx, query_stride, key_stride);
}

template <typename scalar_t, bool IS_NEOX>
__global__ void batched_rotary_embedding_kernel(
    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
                                            // [num_tokens]
    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                   // head_size] or [num_tokens, num_heads,
                                   // head_size]
    scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
                                   // head_size] or [num_tokens, num_kv_heads,
                                   // head_size]
    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
                                                 // 2]
    const int64_t* __restrict__ cos_sin_cache_offsets,  // [batch_size, seq_len]
                                                        // or [num_tokens]
    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
    const int num_heads, const int num_kv_heads, const int head_size) {
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;
  int64_t pos = positions[token_idx];
  int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx];
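  // The per-token offset is expressed in positions, so (offset + pos) below
  // selects a row from one of several RoPE tables packed back-to-back in
  // cos_sin_cache (see the comment on batched_rotary_embedding below).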
  const scalar_t* cache_ptr =
      cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;

  apply_rotary_embedding<scalar_t, IS_NEOX>(
      query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
      token_idx, query_stride, key_stride);
}

}  // namespace aphrodite

void rotary_embedding(
    torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
    torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
                           // [num_tokens, num_heads * head_size]
    torch::Tensor& key,    // [batch_size, seq_len, num_kv_heads * head_size] or
                           // [num_tokens, num_kv_heads * head_size]
    int64_t head_size,
    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
    bool is_neox) {
  int64_t num_tokens = query.numel() / query.size(-1);
  int rot_dim = cos_sin_cache.size(1);
  int num_heads = query.size(-1) / head_size;
  int num_kv_heads = key.size(-1) / head_size;
  int64_t query_stride = query.stride(-2);
  int64_t key_stride = key.stride(-2);

  dim3 grid(num_tokens);
  dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
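  // Launch configuration: one thread block per token, with up to 512 threads;
  // each thread handles one (head, rotation pair) element via the strided
  // loops in apply_rotary_embedding.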
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
      query.scalar_type(), "rotary_embedding", [&] {
        if (is_neox) {
          aphrodite::rotary_embedding_kernel<scalar_t, true>
              <<<grid, block, 0, stream>>>(
                  positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
                  key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
                  rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
                  head_size);
        } else {
          aphrodite::rotary_embedding_kernel<scalar_t, false>
              <<<grid, block, 0, stream>>>(
                  positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
                  key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
                  rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
                  head_size);
        }
      });
}
/*
  Batched version of rotary embedding: packs multiple LoRAs together and
  processes them in a batched manner.
*/
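// Illustrative example (assumed layout, not taken from this file): if two
// RoPE tables of max_position rows each are stacked into cos_sin_cache, a
// token that should use the second table carries
// cos_sin_cache_offsets[i] == max_position, so the kernel reads row
// (max_position + pos) of the packed cache.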
void batched_rotary_embedding(
    torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
    torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
                           // [num_tokens, num_heads * head_size]
    torch::Tensor& key,    // [batch_size, seq_len, num_kv_heads * head_size] or
                           // [num_tokens, num_kv_heads * head_size]
    int64_t head_size,
    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
    bool is_neox, int64_t rot_dim,
    torch::Tensor& cos_sin_cache_offsets  // [num_tokens]
) {
  int64_t num_tokens = cos_sin_cache_offsets.size(0);
  int num_heads = query.size(-1) / head_size;
  int num_kv_heads = key.size(-1) / head_size;
  int64_t query_stride = query.stride(-2);
  int64_t key_stride = key.stride(-2);

  dim3 grid(num_tokens);
  dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
      query.scalar_type(), "rotary_embedding", [&] {
        if (is_neox) {
          aphrodite::batched_rotary_embedding_kernel<scalar_t, true>
              <<<grid, block, 0, stream>>>(
                  positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
                  key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
                  cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim,
                  query_stride, key_stride, num_heads, num_kv_heads, head_size);
        } else {
          aphrodite::batched_rotary_embedding_kernel<scalar_t, false>
              <<<grid, block, 0, stream>>>(
                  positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
                  key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
                  cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim,
                  query_stride, key_stride, num_heads, num_kv_heads, head_size);
        }
      });
}
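
// Minimal host-side sketch of how rotary_embedding might be driven; the
// shapes, dtype, and the randn-filled cache below are assumptions for
// illustration only, not taken from this file:
//
//   int64_t num_tokens = 4, num_heads = 8, num_kv_heads = 8, head_size = 64;
//   int64_t rot_dim = 64, max_position = 2048;
//   auto opts =
//       torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA);
//   auto positions = torch::arange(num_tokens, opts.dtype(torch::kInt64));
//   auto query = torch::randn({num_tokens, num_heads * head_size}, opts);
//   auto key = torch::randn({num_tokens, num_kv_heads * head_size}, opts);
//   // Each row p holds [cos_0..cos_{rot_dim/2-1}, sin_0..sin_{rot_dim/2-1}]
//   // for position p, matching the cos_ptr / sin_ptr split in the kernel.
//   auto cos_sin_cache = torch::randn({max_position, rot_dim}, opts);
//   rotary_embedding(positions, query, key, head_size, cos_sin_cache,
//                    /*is_neox=*/true);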