// pos_encoding_kernels.cu
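// CUDA kernels that apply rotary position embeddings (RoPE) to query and key
// tensors in place. Both the GPT-NeoX layout (rotary dimensions split into two
// contiguous halves) and the GPT-J layout (adjacent dimension pairs rotated
// together) are supported, selected through the IS_NEOX template flag.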

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace aphrodite {
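
// Rotates one (x, y) pair of a single head in place:
//   x' = x * cos - y * sin,  y' = y * cos + x * sin.
// In the NeoX layout the pair is (arr[i], arr[embed_dim + i]); in the GPT-J
// layout it is the adjacent pair (arr[2 * i], arr[2 * i + 1]).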
template<typename scalar_t, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding(
  scalar_t* __restrict__ arr,
  const scalar_t* __restrict__ cos_ptr,
  const scalar_t* __restrict__ sin_ptr,
  int rot_offset,
  int embed_dim)
{
  int x_index, y_index;
  scalar_t cos, sin;
  if (IS_NEOX) {
    // GPT-NeoX style rotary embedding.
    x_index = rot_offset;
    y_index = embed_dim + rot_offset;
    cos = APHRODITE_LDG(cos_ptr + x_index);
    sin = APHRODITE_LDG(sin_ptr + x_index);
  } else {
    // GPT-J style rotary embedding.
    x_index = 2 * rot_offset;
    y_index = 2 * rot_offset + 1;
    cos = APHRODITE_LDG(cos_ptr + x_index / 2);
    sin = APHRODITE_LDG(sin_ptr + x_index / 2);
  }

  const scalar_t x = arr[x_index];
  const scalar_t y = arr[y_index];
  arr[x_index] = x * cos - y * sin;
  arr[y_index] = y * cos + x * sin;
}
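
// Applies the rotation to every head of the query and key of one token.
// cache_ptr points at the cached values for this token's position: the first
// embed_dim (= rot_dim / 2) entries are cosines, the next embed_dim are sines.
// The threads of the block stride over the num_heads * embed_dim (and
// num_kv_heads * embed_dim) rotation pairs.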
template<typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding(
  scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
  const scalar_t* cache_ptr,
  const int head_size,
  const int num_heads,
  const int num_kv_heads,
  const int rot_dim,
  const int token_idx,
  const int64_t query_stride,
  const int64_t key_stride)
{
  const int embed_dim = rot_dim / 2;
  const scalar_t* cos_ptr = cache_ptr;
  const scalar_t* sin_ptr = cache_ptr + embed_dim;

  const int nq = num_heads * embed_dim;
  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int64_t token_head = token_idx * query_stride + head_idx * head_size;
    const int rot_offset = i % embed_dim;
    apply_token_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
                                                    sin_ptr, rot_offset, embed_dim);
  }

  const int nk = num_kv_heads * embed_dim;
  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int64_t token_head = token_idx * key_stride + head_idx * head_size;
    const int rot_offset = i % embed_dim;
    apply_token_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
                                                    sin_ptr, rot_offset, embed_dim);
  }
}
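
// One thread block per token; the token's position selects the row of the
// cos/sin cache used for its rotation.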
template<typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
  const int64_t* __restrict__ positions,       // [batch_size, seq_len] or [num_tokens]
  scalar_t* __restrict__ query,                // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key,                  // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
  const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim // 2]
  const int rot_dim,
  const int64_t query_stride,
  const int64_t key_stride,
  const int num_heads,
  const int num_kv_heads,
  const int head_size) {
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;
  int64_t pos = positions[token_idx];
  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

  apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads,
                                            num_kv_heads, rot_dim, token_idx, query_stride,
                                            key_stride);
}
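
// Same as rotary_embedding_kernel, but each token additionally carries an
// offset into the cos/sin cache, so tokens in the same batch can read from
// different cache regions (e.g. when multiple LoRAs with different caches are
// packed into one batch).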
template<typename scalar_t, bool IS_NEOX>
__global__ void batched_rotary_embedding_kernel(
  const int64_t* __restrict__ positions,             // [batch_size, seq_len] or [num_tokens]
  scalar_t* __restrict__ query,                      // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key,                        // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
  const scalar_t* __restrict__ cos_sin_cache,        // [max_position, 2, rot_dim // 2]
  const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] or [num_tokens]
  const int rot_dim,
  const int64_t query_stride,
  const int64_t key_stride,
  const int num_heads,
  const int num_kv_heads,
  const int head_size) {
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;
  int64_t pos = positions[token_idx];
  int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx];
  const scalar_t* cache_ptr = cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;

  apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads,
                                            num_kv_heads, rot_dim, token_idx, query_stride,
                                            key_stride);
}

} // namespace aphrodite
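
// Host-side launcher. Rotates query and key in place: one thread block per
// token, with up to 512 threads covering the num_heads * rot_dim / 2 rotation
// pairs of that token.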
void rotary_embedding(
  torch::Tensor& positions,      // [batch_size, seq_len] or [num_tokens]
  torch::Tensor& query,          // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
  torch::Tensor& key,            // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
  int head_size,
  torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
  bool is_neox) {
  int64_t num_tokens = query.numel() / query.size(-1);
  int rot_dim = cos_sin_cache.size(1);
  int num_heads = query.size(-1) / head_size;
  int num_kv_heads = key.size(-1) / head_size;
  int64_t query_stride = query.stride(-2);
  int64_t key_stride = key.stride(-2);

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * rot_dim / 2, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
    query.scalar_type(),
    "rotary_embedding",
    [&] {
      if (is_neox) {
        aphrodite::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
          positions.data_ptr<int64_t>(),
          query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(),
          cos_sin_cache.data_ptr<scalar_t>(),
          rot_dim,
          query_stride,
          key_stride,
          num_heads,
          num_kv_heads,
          head_size);
      } else {
        aphrodite::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
          positions.data_ptr<int64_t>(),
          query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(),
          cos_sin_cache.data_ptr<scalar_t>(),
          rot_dim,
          query_stride,
          key_stride,
          num_heads,
          num_kv_heads,
          head_size);
      }
    });
}

/*
  Batched version of rotary embedding: packs multiple LoRAs together
  and processes them in a batched manner.
*/
void batched_rotary_embedding(
  torch::Tensor& positions,             // [batch_size, seq_len] or [num_tokens]
  torch::Tensor& query,                 // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
  torch::Tensor& key,                   // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
  int head_size,
  torch::Tensor& cos_sin_cache,         // [max_position, rot_dim]
  bool is_neox,
  int rot_dim,
  torch::Tensor& cos_sin_cache_offsets  // [num_tokens]
) {
  int64_t num_tokens = cos_sin_cache_offsets.size(0);
  int num_heads = query.size(-1) / head_size;
  int num_kv_heads = key.size(-1) / head_size;
  int64_t query_stride = query.stride(-2);
  int64_t key_stride = key.stride(-2);

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * rot_dim / 2, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
    query.scalar_type(),
    "rotary_embedding",
    [&] {
      if (is_neox) {
        aphrodite::batched_rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
          positions.data_ptr<int64_t>(),
          query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(),
          cos_sin_cache.data_ptr<scalar_t>(),
          cos_sin_cache_offsets.data_ptr<int64_t>(),
          rot_dim,
          query_stride,
          key_stride,
          num_heads,
          num_kv_heads,
          head_size);
      } else {
        aphrodite::batched_rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
          positions.data_ptr<int64_t>(),
          query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(),
          cos_sin_cache.data_ptr<scalar_t>(),
          cos_sin_cache_offsets.data_ptr<int64_t>(),
          rot_dim,
          query_stride,
          key_stride,
          num_heads,
          num_kv_heads,
          head_size);
      }
    });
}