Parcourir la source

fix: pos encodings for CPU

AlpinDale il y a 7 mois
Parent
commit
ba02fb65c9
1 fichiers modifiés avec 52 ajouts et 50 suppressions
  1. 52 50
      kernels/cpu/pos_encoding.cpp

+ 52 - 50
kernels/cpu/pos_encoding.cpp

@@ -1,3 +1,4 @@
+
 #include "cpu_types.hpp"
 #include "cpu_types.hpp"
 
 
 namespace {
 namespace {
@@ -20,73 +21,74 @@ void rotary_embedding_impl(
   constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
   constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
 
 
   const int embed_dim = rot_dim / 2;
   const int embed_dim = rot_dim / 2;
-  TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0);
+  bool flag = (embed_dim % VEC_ELEM_NUM == 0);
+  const int loop_upper = flag ? embed_dim : embed_dim - VEC_ELEM_NUM;
 
 
-#pragma omp parallel for
-  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
-    int64_t pos = positions[token_idx];
-    const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
-
-    for (int i = 0; i < num_heads; ++i) {
-      const int head_idx = i;
-      const int64_t token_head =
-          token_idx * query_stride + head_idx * head_size;
-      for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) {
-        const int rot_offset = j;
-        const int x_index = rot_offset;
-        const int y_index = embed_dim + rot_offset;
+  auto compute_loop = [&](const int64_t token_head, const scalar_t* cache_ptr,
+                          scalar_t* qk) {
+    int j = 0;
+    for (; j < loop_upper; j += VEC_ELEM_NUM) {
+      const int rot_offset = j;
+      const int x_index = rot_offset;
+      const int y_index = embed_dim + rot_offset;
 
 
-        const int64_t out_x = token_head + x_index;
-        const int64_t out_y = token_head + y_index;
+      const int64_t out_x = token_head + x_index;
+      const int64_t out_y = token_head + y_index;
 
 
-        const scalar_vec_t cos(cache_ptr + x_index);
-        const scalar_vec_t sin(cache_ptr + y_index);
+      const scalar_vec_t cos(cache_ptr + x_index);
+      const scalar_vec_t sin(cache_ptr + y_index);
 
 
-        const scalar_vec_t q_x(query + out_x);
-        const scalar_vec_t q_y(query + out_y);
+      const scalar_vec_t q_x(qk + out_x);
+      const scalar_vec_t q_y(qk + out_y);
 
 
-        vec_op::FP32Vec8 fp32_cos(cos);
-        vec_op::FP32Vec8 fp32_sin(sin);
+      vec_op::FP32Vec8 fp32_cos(cos);
+      vec_op::FP32Vec8 fp32_sin(sin);
 
 
-        vec_op::FP32Vec8 fp32_q_x(q_x);
-        vec_op::FP32Vec8 fp32_q_y(q_y);
+      vec_op::FP32Vec8 fp32_q_x(q_x);
+      vec_op::FP32Vec8 fp32_q_y(q_y);
 
 
-        auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
-        scalar_vec_t(out1).save(query + out_x);
+      auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
+      scalar_vec_t(out1).save(qk + out_x);
 
 
-        auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
-        scalar_vec_t(out2).save(query + out_y);
-      }
+      auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
+      scalar_vec_t(out2).save(qk + out_y);
     }
     }
-
-    for (int i = 0; i < num_kv_heads; ++i) {
-      const int head_idx = i;
-      const int64_t token_head = token_idx * key_stride + head_idx * head_size;
-      for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) {
-        const int rot_offset = j;
-        const int x_index = rot_offset;
-        const int y_index = embed_dim + rot_offset;
+    if (!flag) {
+      for (; j < embed_dim; ++j) {
+        const int x_index = j;
+        const int y_index = embed_dim + j;
 
 
         const int64_t out_x = token_head + x_index;
         const int64_t out_x = token_head + x_index;
         const int64_t out_y = token_head + y_index;
         const int64_t out_y = token_head + y_index;
 
 
-        const scalar_vec_t cos(cache_ptr + x_index);
-        const scalar_vec_t sin(cache_ptr + y_index);
+        const float fp32_cos = cache_ptr[x_index];
+        const float fp32_sin = cache_ptr[y_index];
 
 
-        const scalar_vec_t k_x(key + out_x);
-        const scalar_vec_t k_y(key + out_y);
+        const float fp32_q_x = qk[out_x];
+        const float fp32_q_y = qk[out_y];
 
 
-        vec_op::FP32Vec8 fp32_cos(cos);
-        vec_op::FP32Vec8 fp32_sin(sin);
+        qk[out_x] = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
+        qk[out_y] = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
+      }
+    }
+  };
 
 
-        vec_op::FP32Vec8 fp32_k_x(k_x);
-        vec_op::FP32Vec8 fp32_k_y(k_y);
+#pragma omp parallel for
+  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
+    int64_t pos = positions[token_idx];
+    const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
 
 
-        auto out1 = fp32_k_x * fp32_cos - fp32_k_y * fp32_sin;
-        scalar_vec_t(out1).save(key + out_x);
-        auto out2 = fp32_k_y * fp32_cos + fp32_k_x * fp32_sin;
-        scalar_vec_t(out2).save(key + out_y);
-      }
+    for (int i = 0; i < num_heads; ++i) {
+      const int head_idx = i;
+      const int64_t token_head =
+          token_idx * query_stride + head_idx * head_size;
+      compute_loop(token_head, cache_ptr, query);
+    }
+
+    for (int i = 0; i < num_kv_heads; ++i) {
+      const int head_idx = i;
+      const int64_t token_head = token_idx * key_stride + head_idx * head_size;
+      compute_loop(token_head, cache_ptr, key);
     }
     }
   }
   }
 }
 }