|
@@ -1549,7 +1549,7 @@ inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_
|
|
|
if (2 * tid >= rot_embed_dim) {
|
|
|
return;
|
|
|
}
|
|
|
- const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin);
|
|
|
q = rotary_embedding_transform(q, coef);
|
|
|
}
|
|
|
|
|
@@ -1558,7 +1558,7 @@ inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int
|
|
|
if (2 * tid >= rot_embed_dim) {
|
|
|
return;
|
|
|
}
|
|
|
- const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin);
|
|
|
q = rotary_embedding_transform(q, coef);
|
|
|
k = rotary_embedding_transform(k, coef);
|
|
|
}
|
|
@@ -1570,9 +1570,9 @@ inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_
|
|
|
}
|
|
|
|
|
|
Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
|
|
|
- const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin);
|
|
|
q_.x = rotary_embedding_transform(q_.x, coef0);
|
|
|
- const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin);
|
|
|
q_.y = rotary_embedding_transform(q_.y, coef1);
|
|
|
}
|
|
|
|
|
@@ -1584,10 +1584,10 @@ inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int
|
|
|
|
|
|
Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
|
|
|
Float4_& k_ = *reinterpret_cast<Float4_*>(&k);
|
|
|
- const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin);
|
|
|
q_.x = rotary_embedding_transform(q_.x, coef0);
|
|
|
k_.x = rotary_embedding_transform(k_.x, coef0);
|
|
|
- const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin);
|
|
|
q_.y = rotary_embedding_transform(q_.y, coef1);
|
|
|
k_.y = rotary_embedding_transform(k_.y, coef1);
|
|
|
}
|
|
@@ -1597,7 +1597,7 @@ inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embe
|
|
|
if (2 * tid >= rot_embed_dim) {
|
|
|
return;
|
|
|
}
|
|
|
- const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin);
|
|
|
q = rotary_embedding_transform(q, coef);
|
|
|
}
|
|
|
|
|
@@ -1606,7 +1606,7 @@ inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid,
|
|
|
if (2 * tid >= rot_embed_dim) {
|
|
|
return;
|
|
|
}
|
|
|
- const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin);
|
|
|
q = rotary_embedding_transform(q, coef);
|
|
|
k = rotary_embedding_transform(k, coef);
|
|
|
}
|
|
@@ -1616,9 +1616,9 @@ inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_d
|
|
|
if (4 * tid >= rot_embed_dim) {
|
|
|
return;
|
|
|
}
|
|
|
- const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin);
|
|
|
q.x = rotary_embedding_transform(q.x, coef0);
|
|
|
- const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin);
|
|
|
q.y = rotary_embedding_transform(q.y, coef1);
|
|
|
}
|
|
|
|
|
@@ -1627,10 +1627,10 @@ inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int r
|
|
|
if (4 * tid >= rot_embed_dim) {
|
|
|
return;
|
|
|
}
|
|
|
- const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin);
|
|
|
q.x = rotary_embedding_transform(q.x, coef0);
|
|
|
k.x = rotary_embedding_transform(k.x, coef0);
|
|
|
- const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin);
|
|
|
q.y = rotary_embedding_transform(q.y, coef1);
|
|
|
k.y = rotary_embedding_transform(k.y, coef1);
|
|
|
}
|
|
@@ -1640,13 +1640,13 @@ inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_d
|
|
|
if (8 * tid >= rot_embed_dim) {
|
|
|
return;
|
|
|
}
|
|
|
- const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin);
|
|
|
q.x = rotary_embedding_transform(q.x, coef0);
|
|
|
- const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin);
|
|
|
q.y = rotary_embedding_transform(q.y, coef1);
|
|
|
- const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin);
|
|
|
q.z = rotary_embedding_transform(q.z, coef2);
|
|
|
- const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin);
|
|
|
q.w = rotary_embedding_transform(q.w, coef3);
|
|
|
}
|
|
|
|
|
@@ -1655,16 +1655,16 @@ inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int r
|
|
|
if (8 * tid >= rot_embed_dim) {
|
|
|
return;
|
|
|
}
|
|
|
- const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin);
|
|
|
q.x = rotary_embedding_transform(q.x, coef0);
|
|
|
k.x = rotary_embedding_transform(k.x, coef0);
|
|
|
- const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin);
|
|
|
q.y = rotary_embedding_transform(q.y, coef1);
|
|
|
k.y = rotary_embedding_transform(k.y, coef1);
|
|
|
- const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin);
|
|
|
q.z = rotary_embedding_transform(q.z, coef2);
|
|
|
k.z = rotary_embedding_transform(k.z, coef2);
|
|
|
- const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin);
|
|
|
q.w = rotary_embedding_transform(q.w, coef3);
|
|
|
k.w = rotary_embedding_transform(k.w, coef3);
|
|
|
}
|
|
@@ -1675,7 +1675,7 @@ inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int ro
|
|
|
if (2 * tid >= rot_embed_dim) {
|
|
|
return;
|
|
|
}
|
|
|
- const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin);
|
|
|
q = rotary_embedding_transform(q, coef);
|
|
|
}
|
|
|
|
|
@@ -1684,7 +1684,7 @@ inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162&
|
|
|
if (2 * tid >= rot_embed_dim) {
|
|
|
return;
|
|
|
}
|
|
|
- const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin);
|
|
|
q = rotary_embedding_transform(q, coef);
|
|
|
k = rotary_embedding_transform(k, coef);
|
|
|
}
|
|
@@ -1694,9 +1694,9 @@ inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embe
|
|
|
if (4 * tid >= rot_embed_dim) {
|
|
|
return;
|
|
|
}
|
|
|
- const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin);
|
|
|
q.x = rotary_embedding_transform(q.x, coef0);
|
|
|
- const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin);
|
|
|
q.y = rotary_embedding_transform(q.y, coef1);
|
|
|
}
|
|
|
|
|
@@ -1705,10 +1705,10 @@ inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid,
|
|
|
if (4 * tid >= rot_embed_dim) {
|
|
|
return;
|
|
|
}
|
|
|
- const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin);
|
|
|
q.x = rotary_embedding_transform(q.x, coef0);
|
|
|
k.x = rotary_embedding_transform(k.x, coef0);
|
|
|
- const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin);
|
|
|
q.y = rotary_embedding_transform(q.y, coef1);
|
|
|
k.y = rotary_embedding_transform(k.y, coef1);
|
|
|
}
|
|
@@ -1718,13 +1718,13 @@ inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embe
|
|
|
if (8 * tid >= rot_embed_dim) {
|
|
|
return;
|
|
|
}
|
|
|
- const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin);
|
|
|
q.x = rotary_embedding_transform(q.x, coef0);
|
|
|
- const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin);
|
|
|
q.y = rotary_embedding_transform(q.y, coef1);
|
|
|
- const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin);
|
|
|
q.z = rotary_embedding_transform(q.z, coef2);
|
|
|
- const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin);
|
|
|
q.w = rotary_embedding_transform(q.w, coef3);
|
|
|
}
|
|
|
|
|
@@ -1733,16 +1733,16 @@ inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid,
|
|
|
if (8 * tid >= rot_embed_dim) {
|
|
|
return;
|
|
|
}
|
|
|
- const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin);
|
|
|
q.x = rotary_embedding_transform(q.x, coef0);
|
|
|
k.x = rotary_embedding_transform(k.x, coef0);
|
|
|
- const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin);
|
|
|
q.y = rotary_embedding_transform(q.y, coef1);
|
|
|
k.y = rotary_embedding_transform(k.y, coef1);
|
|
|
- const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin);
|
|
|
q.z = rotary_embedding_transform(q.z, coef2);
|
|
|
k.z = rotary_embedding_transform(k.z, coef2);
|
|
|
- const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
|
|
|
+ const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin);
|
|
|
q.w = rotary_embedding_transform(q.w, coef3);
|
|
|
k.w = rotary_embedding_transform(k.w, coef3);
|
|
|
}
|