Просмотр исходного кода

[FT] rotary_cos/sin should have shape (dim) instead of (seqlen, dim)

Tri Dao 1 год назад
Родитель
Сommit
3a9bfd076f

+ 34 - 34
csrc/ft_attention/decoder_masked_multihead_attention_utils.h

@@ -1549,7 +1549,7 @@ inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin);
     q               = rotary_embedding_transform(q, coef);
 }
 
@@ -1558,7 +1558,7 @@ inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin);
     q               = rotary_embedding_transform(q, coef);
     k               = rotary_embedding_transform(k, coef);
 }
@@ -1570,9 +1570,9 @@ inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_
     }
 
     Float4_&   q_    = *reinterpret_cast<Float4_*>(&q);
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin);
     q_.x             = rotary_embedding_transform(q_.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin);
     q_.y             = rotary_embedding_transform(q_.y, coef1);
 }
 
@@ -1584,10 +1584,10 @@ inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int
 
     Float4_&   q_    = *reinterpret_cast<Float4_*>(&q);
     Float4_&   k_    = *reinterpret_cast<Float4_*>(&k);
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin);
     q_.x             = rotary_embedding_transform(q_.x, coef0);
     k_.x             = rotary_embedding_transform(k_.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin);
     q_.y             = rotary_embedding_transform(q_.y, coef1);
     k_.y             = rotary_embedding_transform(k_.y, coef1);
 }
@@ -1597,7 +1597,7 @@ inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embe
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin);
     q               = rotary_embedding_transform(q, coef);
 }
 
@@ -1606,7 +1606,7 @@ inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid,
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin);
     q               = rotary_embedding_transform(q, coef);
     k               = rotary_embedding_transform(k, coef);
 }
@@ -1616,9 +1616,9 @@ inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_d
     if (4 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin);
     q.x              = rotary_embedding_transform(q.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin);
     q.y              = rotary_embedding_transform(q.y, coef1);
 }
 
@@ -1627,10 +1627,10 @@ inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int r
     if (4 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin);
     q.x              = rotary_embedding_transform(q.x, coef0);
     k.x              = rotary_embedding_transform(k.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin);
     q.y              = rotary_embedding_transform(q.y, coef1);
     k.y              = rotary_embedding_transform(k.y, coef1);
 }
@@ -1640,13 +1640,13 @@ inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_d
     if (8 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin);
     q.x              = rotary_embedding_transform(q.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin);
     q.y              = rotary_embedding_transform(q.y, coef1);
-    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin);
     q.z              = rotary_embedding_transform(q.z, coef2);
-    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin);
     q.w              = rotary_embedding_transform(q.w, coef3);
 }
 
@@ -1655,16 +1655,16 @@ inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int r
     if (8 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin);
     q.x              = rotary_embedding_transform(q.x, coef0);
     k.x              = rotary_embedding_transform(k.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin);
     q.y              = rotary_embedding_transform(q.y, coef1);
     k.y              = rotary_embedding_transform(k.y, coef1);
-    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin);
     q.z              = rotary_embedding_transform(q.z, coef2);
     k.z              = rotary_embedding_transform(k.z, coef2);
-    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin);
     q.w              = rotary_embedding_transform(q.w, coef3);
     k.w              = rotary_embedding_transform(k.w, coef3);
 }
@@ -1675,7 +1675,7 @@ inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int ro
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin);
     q               = rotary_embedding_transform(q, coef);
 }
 
@@ -1684,7 +1684,7 @@ inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162&
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin);
     q               = rotary_embedding_transform(q, coef);
     k               = rotary_embedding_transform(k, coef);
 }
@@ -1694,9 +1694,9 @@ inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embe
     if (4 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin);
     q.x              = rotary_embedding_transform(q.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin);
     q.y              = rotary_embedding_transform(q.y, coef1);
 }
 
@@ -1705,10 +1705,10 @@ inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid,
     if (4 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin);
     q.x              = rotary_embedding_transform(q.x, coef0);
     k.x              = rotary_embedding_transform(k.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin);
     q.y              = rotary_embedding_transform(q.y, coef1);
     k.y              = rotary_embedding_transform(k.y, coef1);
 }
@@ -1718,13 +1718,13 @@ inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embe
     if (8 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin);
     q.x              = rotary_embedding_transform(q.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin);
     q.y              = rotary_embedding_transform(q.y, coef1);
-    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin);
     q.z              = rotary_embedding_transform(q.z, coef2);
-    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin);
     q.w              = rotary_embedding_transform(q.w, coef3);
 }
 
@@ -1733,16 +1733,16 @@ inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid,
     if (8 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin);
     q.x              = rotary_embedding_transform(q.x, coef0);
     k.x              = rotary_embedding_transform(k.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin);
     q.y              = rotary_embedding_transform(q.y, coef1);
     k.y              = rotary_embedding_transform(k.y, coef1);
-    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin);
     q.z              = rotary_embedding_transform(q.z, coef2);
     k.z              = rotary_embedding_transform(k.z, coef2);
-    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin);
     q.w              = rotary_embedding_transform(q.w, coef3);
     k.w              = rotary_embedding_transform(k.w, coef3);
 }

+ 3 - 4
csrc/ft_attention/ft_attention.cpp

@@ -160,16 +160,15 @@ torch::Tensor single_query_attention(const torch::Tensor q,
     if (rotary_cos_.has_value()) {
         auto rotary_cos = rotary_cos_.value();
         CHECK_DEVICE(rotary_cos);
-        int rotary_seqlen = rotary_cos.size(0);
-        rotary_embedding_dim = rotary_cos.size(1) * 2;
-        CHECK_SHAPE(rotary_cos, rotary_seqlen, rotary_embedding_dim / 2);
+        rotary_embedding_dim = rotary_cos.size(0) * 2;
+        CHECK_SHAPE(rotary_cos, rotary_embedding_dim / 2);
         CHECK_CONTIGUOUS(rotary_cos);
         TORCH_CHECK(rotary_cos.scalar_type() == input_type);
 
         TORCH_CHECK(rotary_sin_.has_value());
         auto rotary_sin = rotary_sin_.value();
         CHECK_DEVICE(rotary_sin);
-        CHECK_SHAPE(rotary_cos, rotary_seqlen, rotary_embedding_dim / 2);
+        CHECK_SHAPE(rotary_cos, rotary_embedding_dim / 2);
         CHECK_CONTIGUOUS(rotary_sin);
         TORCH_CHECK(rotary_sin.scalar_type() == input_type);
     }