// pos_encoding_kernels.cu (2.8 KB)
  #include <algorithm>

  #include <torch/extension.h>
  #include <ATen/cuda/CUDAContext.h>
  #include <c10/cuda/CUDAException.h>
  #include <c10/cuda/CUDAGuard.h>
  namespace aphrodite {

  // Rotates one (x, y) element pair of a single head in place:
  //   x' = x * cos - y * sin
  //   y' = y * cos + x * sin
  // `head_ptr` points at the first element of the head; `x_offset` and
  // `y_offset` are element offsets within that head. The arithmetic order
  // matches the original inline expressions exactly.
  template<typename scalar_t>
  __device__ __forceinline__ void rotate_neox_pair(
    scalar_t* __restrict__ head_ptr,
    const int x_offset,
    const int y_offset,
    const scalar_t cos_val,
    const scalar_t sin_val) {
    const scalar_t x = head_ptr[x_offset];
    const scalar_t y = head_ptr[y_offset];
    head_ptr[x_offset] = x * cos_val - y * sin_val;
    head_ptr[y_offset] = y * cos_val + x * sin_val;
  }

  // Applies rotary position embedding to `query` and `key` in place.
  //
  // Each head's first rot_dim elements are split into two halves. Element
  // `r` of the first half is paired with element `rot_dim / 2 + r` of the
  // second half; this is the pairing that gives the kernel its "neox" name.
  //
  // Launch layout: one thread block per token (blockIdx.x == token index).
  // Any blockDim.x >= 1 is valid, because the threads of a block loop over the
  // token's num_heads * (rot_dim / 2) element pairs in steps of blockDim.x.
  // No shared memory is used. `__ldg` requires compute capability 3.5+.
  //
  // Preconditions (not checked here; callers must ensure):
  //  * query and key use the same per-token `stride` and the same `num_heads`.
  //  * The head_size elements of each head are contiguous.
  //  * Each cos_sin_cache row is contiguous and holds rot_dim values:
  //    rot_dim / 2 cos values followed by rot_dim / 2 sin values.
  //  * rot_dim <= head_size. Elements past rot_dim are left untouched.
  //  * 0 <= positions[t] < max_position.
  template<typename scalar_t>
  __global__ void rotary_embedding_neox_kernel(
    const int64_t* __restrict__ positions,        // [num_tokens]
    scalar_t* __restrict__ query,                 // [num_tokens, num_heads, head_size]
    scalar_t* __restrict__ key,                   // [num_tokens, num_heads, head_size]
    const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
    const int rot_dim,
    const int stride,
    const int num_heads,
    const int head_size) {
    // Each thread block is responsible for one token.
    const int token_idx = blockIdx.x;
    const int64_t pos = positions[token_idx];
    const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
    const int embed_dim = rot_dim / 2;
    const int n = num_heads * embed_dim;
    // Computed in 64 bits so token_idx * stride cannot overflow 32-bit int
    // for large batches or wide row strides (e.g. views into a fused buffer).
    const int64_t token_offset = static_cast<int64_t>(token_idx) * stride;
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
      const int head_idx = i / embed_dim;
      const int rot_offset = i % embed_dim;
      const int x_index = rot_offset;
      const int y_index = embed_dim + rot_offset;
      const int64_t head_offset =
        token_offset + static_cast<int64_t>(head_idx) * head_size;
      const scalar_t cos_val = __ldg(cache_ptr + x_index);
      const scalar_t sin_val = __ldg(cache_ptr + y_index);
      rotate_neox_pair(query + head_offset, x_index, y_index, cos_val, sin_val);
      rotate_neox_pair(key + head_offset, x_index, y_index, cos_val, sin_val);
    }
  }

  } // namespace aphrodite
  // Applies rotary position embedding to `query` and `key` in place by
  // launching aphrodite::rotary_embedding_neox_kernel with one thread block
  // per token.
  //
  // positions:     int64 token positions. The first num_tokens entries are
  //                used as row indices into cos_sin_cache. Their values are
  //                NOT range-checked against max_position.
  // query, key:    [num_tokens, num_heads * head_size]. Both must have the same
  //                dtype, shape and row stride, and unit stride along dim 1.
  // head_size:     number of elements per attention head (must be > 0).
  // cos_sin_cache: [max_position, rot_dim], contiguous, same dtype as query.
  //                Each row holds rot_dim / 2 cos values followed by
  //                rot_dim / 2 sin values. rot_dim must be even and at most
  //                head_size; only the first rot_dim elements of each head are
  //                rotated.
  //
  // Raises c10::Error (via TORCH_CHECK) on invalid shapes or layouts, and on
  // a failed kernel launch. The kernel is enqueued on the current CUDA stream
  // of query's device and is not synchronized here.
  void rotary_embedding_neox(
    torch::Tensor& positions,      // [num_tokens]
    torch::Tensor& query,          // [num_tokens, num_heads * head_size]
    torch::Tensor& key,            // [num_tokens, num_heads * head_size]
    int head_size,
    torch::Tensor& cos_sin_cache)  // [max_position, rot_dim]
  {
    TORCH_CHECK(head_size > 0,
                "rotary_embedding_neox: head_size must be positive, got ",
                head_size);
    int num_tokens = query.size(0);
    int rot_dim = cos_sin_cache.size(1);
    TORCH_CHECK(query.size(1) % head_size == 0,
                "rotary_embedding_neox: query.size(1) (", query.size(1),
                ") is not a multiple of head_size (", head_size, ")");
    int num_heads = query.size(1) / head_size;
    int stride = query.stride(0);
    TORCH_CHECK(stride == key.stride(0));
    // The kernel rotates `key` using query's head count and layout, so a
    // mismatched key would be read and written out of bounds.
    TORCH_CHECK(key.size(0) == query.size(0) && key.size(1) == query.size(1),
                "rotary_embedding_neox: key shape ", key.sizes(),
                " must match query shape ", query.sizes());
    TORCH_CHECK(rot_dim % 2 == 0 && rot_dim <= head_size,
                "rotary_embedding_neox: rot_dim (", rot_dim,
                ") must be even and <= head_size (", head_size, ")");
    TORCH_CHECK(positions.numel() >= num_tokens,
                "rotary_embedding_neox: expected at least ", num_tokens,
                " positions, got ", positions.numel());

    // Nothing to rotate. Returning here also avoids a zero-sized grid/block,
    // which would be an invalid launch configuration.
    if (num_tokens == 0 || num_heads == 0 || rot_dim == 0) {
      return;
    }

    // The kernel indexes head elements and cache rows with unit stride.
    TORCH_CHECK(query.stride(1) == 1 && key.stride(1) == 1,
                "rotary_embedding_neox: query and key must have unit stride "
                "along dim 1");
    TORCH_CHECK(cos_sin_cache.is_contiguous(),
                "rotary_embedding_neox: cos_sin_cache must be contiguous");

    // Launch on query's device, even if it is not the current device.
    const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
    dim3 grid(num_tokens);
    dim3 block(std::min(num_heads * rot_dim / 2, 512));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      query.scalar_type(),
      "rotary_embedding_neox",
      [&] {
        aphrodite::rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
          positions.data_ptr<int64_t>(),
          query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(),
          cos_sin_cache.data_ptr<scalar_t>(),
          rot_dim,
          stride,
          num_heads,
          head_size);
      });
    // Surface launch-configuration errors now instead of at a later CUDA call.
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }