pos_encoding.cpp

#include "cpu_types.hpp"

namespace {
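
// NeoX-style rotary embedding: within the first rot_dim elements of each head,
// element j is rotated against element j + rot_dim / 2. The inner loop is
// vectorized with vec_op vectors and accumulates in FP32.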
template <typename scalar_t>
void rotary_embedding_impl(
    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
                                            // [num_tokens]
    scalar_t* __restrict__ query,  /// [batch_size, seq_len, num_heads,
                                   /// head_size] or [num_tokens, num_heads,
                                   /// head_size]
    scalar_t* __restrict__ key,  // [batch_size, seq_len, num_kv_heads,
                                 // head_size] or [num_tokens, num_kv_heads,
                                 // head_size]
    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
                                                 // 2]
    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
    const int num_heads, const int num_kv_heads, const int head_size,
    const int num_tokens) {
  using scalar_vec_t = vec_op::vec_t<scalar_t>;
  constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();

  const int embed_dim = rot_dim / 2;
  TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0);
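
  // Each token reads one row of the cache: the first embed_dim entries of the
  // row hold cos values and the next embed_dim entries hold sin values.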
#pragma omp parallel for
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
    int64_t pos = positions[token_idx];
    const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

    for (int i = 0; i < num_heads; ++i) {
      const int head_idx = i;
      const int64_t token_head =
          token_idx * query_stride + head_idx * head_size;
      for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) {
        const int rot_offset = j;
        const int x_index = rot_offset;
        const int y_index = embed_dim + rot_offset;

        const int64_t out_x = token_head + x_index;
        const int64_t out_y = token_head + y_index;

        const scalar_vec_t cos(cache_ptr + x_index);
        const scalar_vec_t sin(cache_ptr + y_index);

        const scalar_vec_t q_x(query + out_x);
        const scalar_vec_t q_y(query + out_y);

        vec_op::FP32Vec8 fp32_cos(cos);
        vec_op::FP32Vec8 fp32_sin(sin);

        vec_op::FP32Vec8 fp32_q_x(q_x);
        vec_op::FP32Vec8 fp32_q_y(q_y);

        auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
        scalar_vec_t(out1).save(query + out_x);

        auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
        scalar_vec_t(out2).save(query + out_y);
      }
    }
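
    // Apply the same rotation to the key heads.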
    for (int i = 0; i < num_kv_heads; ++i) {
      const int head_idx = i;
      const int64_t token_head = token_idx * key_stride + head_idx * head_size;
      for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) {
        const int rot_offset = j;
        const int x_index = rot_offset;
        const int y_index = embed_dim + rot_offset;

        const int64_t out_x = token_head + x_index;
        const int64_t out_y = token_head + y_index;

        const scalar_vec_t cos(cache_ptr + x_index);
        const scalar_vec_t sin(cache_ptr + y_index);

        const scalar_vec_t k_x(key + out_x);
        const scalar_vec_t k_y(key + out_y);

        vec_op::FP32Vec8 fp32_cos(cos);
        vec_op::FP32Vec8 fp32_sin(sin);

        vec_op::FP32Vec8 fp32_k_x(k_x);
        vec_op::FP32Vec8 fp32_k_y(k_y);

        auto out1 = fp32_k_x * fp32_cos - fp32_k_y * fp32_sin;
        scalar_vec_t(out1).save(key + out_x);

        auto out2 = fp32_k_y * fp32_cos + fp32_k_x * fp32_sin;
        scalar_vec_t(out2).save(key + out_y);
      }
    }
  }
}
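
// GPT-J-style rotary embedding: the rotation pairs interleaved elements
// (2 * j, 2 * j + 1) within each head, so the inner loop is scalar rather
// than vectorized. Intermediate math is done in float.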
template <typename scalar_t>
void rotary_embedding_gptj_impl(
    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
                                            // [num_tokens]
    scalar_t* __restrict__ query,  /// [batch_size, seq_len, num_heads,
                                   /// head_size] or [num_tokens, num_heads,
                                   /// head_size]
    scalar_t* __restrict__ key,  // [batch_size, seq_len, num_kv_heads,
                                 // head_size] or [num_tokens, num_kv_heads,
                                 // head_size]
    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
                                                 // 2]
    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
    const int num_heads, const int num_kv_heads, const int head_size,
    const int num_tokens) {
  const int embed_dim = rot_dim / 2;
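
  // Rotate the query heads; collapse(2) parallelizes over tokens and heads.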
#pragma omp parallel for collapse(2)
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
    for (int i = 0; i < num_heads; ++i) {
      int64_t pos = positions[token_idx];
      const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
      const scalar_t* cos_cache_ptr = cache_ptr;
      const scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
      const int head_idx = i;
      const int64_t token_head =
          token_idx * query_stride + head_idx * head_size;
      scalar_t* head_query = query + token_head;
      for (int j = 0; j < embed_dim; j += 1) {
        const int rot_offset = j;
        const int x_index = 2 * rot_offset;
        const int y_index = 2 * rot_offset + 1;

        const float cos = cos_cache_ptr[rot_offset];
        const float sin = sin_cache_ptr[rot_offset];

        const float x = head_query[x_index];
        const float y = head_query[y_index];

        head_query[x_index] = x * cos - y * sin;
        head_query[y_index] = y * cos + x * sin;
      }
    }
  }
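
  // Apply the same interleaved rotation to the key heads.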
#pragma omp parallel for collapse(2)
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
    for (int i = 0; i < num_kv_heads; ++i) {
      int64_t pos = positions[token_idx];
      const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
      const scalar_t* cos_cache_ptr = cache_ptr;
      const scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
      const int head_idx = i;
      const int64_t token_head = token_idx * key_stride + head_idx * head_size;
      scalar_t* head_key = key + token_head;
      for (int j = 0; j < embed_dim; j += 1) {
        const int rot_offset = j;
        const int x_index = 2 * rot_offset;
        const int y_index = 2 * rot_offset + 1;

        const float cos = cos_cache_ptr[rot_offset];
        const float sin = sin_cache_ptr[rot_offset];

        const float x = head_key[x_index];
        const float y = head_key[y_index];

        head_key[x_index] = x * cos - y * sin;
        head_key[y_index] = y * cos + x * sin;
      }
    }
  }
}
}  // namespace
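
// Public entry point: derives head counts and strides from the tensors, then
// dispatches to the NeoX (blocked halves) or GPT-J (interleaved pairs) kernel
// based on is_neox.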
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
                      torch::Tensor& key, int head_size,
                      torch::Tensor& cos_sin_cache, bool is_neox) {
  int num_tokens = query.numel() / query.size(-1);
  int rot_dim = cos_sin_cache.size(1);
  int num_heads = query.size(-1) / head_size;
  int num_kv_heads = key.size(-1) / head_size;
  int64_t key_stride = key.stride(-2);
  int64_t query_stride = query.stride(-2);

  APHRODITE_DISPATCH_FLOATING_TYPES(
      query.scalar_type(), "rotary_embedding_impl", [&] {
        CPU_KERNEL_GUARD_IN(rotary_embedding_impl)
        if (is_neox) {
          rotary_embedding_impl(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
              rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
              head_size, num_tokens);
        } else {
          rotary_embedding_gptj_impl(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
              rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
              head_size, num_tokens);
        }
        CPU_KERNEL_GUARD_OUT(rotary_embedding_impl)
      });
}
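
// Illustrative call (assumed shapes, not part of this file): with 8 tokens,
// 32 query heads, 8 KV heads, head_size 128, and rot_dim 128:
//   positions:     int64 tensor of shape [8]
//   query:         [8, 32 * 128]        (last dimension contiguous)
//   key:           [8, 8 * 128]
//   cos_sin_cache: [max_position, 128]  (cos half followed by sin half per row)
//   rotary_embedding(positions, query, key, /*head_size=*/128, cos_sin_cache,
//                    /*is_neox=*/true);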