|
@@ -355,65 +355,161 @@ inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
-template <bool Is_2_sources=false, bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
|
|
|
- typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
|
|
+template <bool Is_even_K=true,
|
|
|
+ typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
|
|
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
|
|
|
-inline __device__ void copy_2_sources(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S0,
|
|
|
- Tensor<Engine0, Layout0> const &S1,
|
|
|
+inline __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
|
|
|
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
|
|
|
Tensor<Engine3, Layout3> const &predicate_K,
|
|
|
- const int max_MN=0, const int row_idx_switch=0) {
|
|
|
- CUTE_STATIC_ASSERT_V(rank(S0) == Int<3>{} && rank(S1) == Int<3>{});
|
|
|
+ const int max_MN=0, const int min_MN=0) {
|
|
|
+ CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
|
|
|
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
|
|
|
- CUTE_STATIC_ASSERT_V(size<0>(S0) == size<0>(D) && size<0>(S1) == size<0>(D)); // MMA
|
|
|
- CUTE_STATIC_ASSERT_V(size<1>(S0) == size<1>(D) && size<1>(S1) == size<1>(D)); // MMA_M
|
|
|
- CUTE_STATIC_ASSERT_V(size<2>(S0) == size<2>(D) && size<2>(S1) == size<2>(D)); // MMA_K
|
|
|
- // There's no case where !Clear_OOB_K && Clear_OOB_MN
|
|
|
- static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
|
|
|
- // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", Is_2_sources, max_MN, row_idx_switch); }
|
|
|
- // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", blockIdx.y, Is_2_sources, max_MN, row_idx_switch); }
|
|
|
+ CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
|
|
|
+ CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
|
|
|
+ CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
|
|
|
+ // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); }
|
|
|
#pragma unroll
|
|
|
- for (int m = 0; m < size<1>(S0); ++m) {
|
|
|
- auto &S = !Is_2_sources || get<0>(identity_MN(0, m, 0)) < row_idx_switch ? S0 : S1;
|
|
|
- if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
|
|
|
+ for (int m = 0; m < size<1>(S); ++m) {
|
|
|
+ // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
|
|
|
+ if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
|
|
|
+ // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
|
|
|
#pragma unroll
|
|
|
- for (int k = 0; k < size<2>(S0); ++k) {
|
|
|
+ for (int k = 0; k < size<2>(S); ++k) {
|
|
|
if (Is_even_K || predicate_K(k)) {
|
|
|
- cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
|
|
|
+ cute::copy(S(_, m, k), D(_, m, k));
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
+
|
|
|
+template <bool Is_even_K=true, bool Clear_OOB_K=true,
|
|
|
+ typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
|
|
+ typename Engine2, typename Layout2, typename Engine3, typename Layout3>
|
|
|
+inline __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const &S,
|
|
|
+ Tensor<Engine1, Layout1> &D,
|
|
|
+ Tensor<Engine2, Layout2> const &Cos,
|
|
|
+ Tensor<Engine2, Layout2> const &Sin,
|
|
|
+ Tensor<Engine3, Layout3> const &identity_MN,
|
|
|
+ const int max_MN, const int min_MN,
|
|
|
+ const int dim, const int rotary_dim) {
|
|
|
+ CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
|
|
|
+ CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
|
|
|
+ CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
|
|
|
+ CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
|
|
|
+ CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
|
|
|
+ CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M
|
|
|
+ CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K
|
|
|
+ CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M
|
|
|
+ CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K
|
|
|
+ CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K
|
|
|
+ static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2);
|
|
|
+ static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
|
|
|
+ Tensor rCos = make_fragment_like(Cos);
|
|
|
+ Tensor rSin = make_fragment_like(Sin);
|
|
|
+ Tensor rS = make_fragment_like(S);
|
|
|
+ #pragma unroll
|
|
|
+ for (int m = 0; m < size<1>(S); ++m) {
|
|
|
+ if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
|
|
|
+ #pragma unroll
|
|
|
+ for (int k = 0; k < size<2>(S); ++k) {
|
|
|
+ if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
|
|
|
+ cute::copy(S(_, m, k), rS(_, m, k));
|
|
|
+ if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
|
|
|
+ cute::copy(Cos(_, m, k), rCos(_, m, k));
|
|
|
+ cute::copy(Sin(_, m, k), rSin(_, m, k));
|
|
|
+ Tensor S_fp32 = convert_type<float>(rS(_, m, k));
|
|
|
+ Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
|
|
|
+ Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < size<0>(rS) / 2; ++i) {
|
|
|
+ float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i);
|
|
|
+ float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i);
|
|
|
+ S_fp32(2 * i) = real;
|
|
|
+ S_fp32(2 * i + 1) = imag;
|
|
|
+ }
|
|
|
+ // Idk but I need to copy for the convert_type to work
|
|
|
+ Tensor S_fp32_copy = make_fragment_like(S_fp32);
|
|
|
+ cute::copy(S_fp32, S_fp32_copy);
|
|
|
+ using T = typename Engine0::value_type;
|
|
|
+ Tensor S_og_type = convert_type<T>(S_fp32_copy);
|
|
|
+ cute::copy(S_og_type, rS(_, m, k));
|
|
|
+ }
|
|
|
+ cute::copy(rS(_, m, k), D(_, m, k));
|
|
|
} else if (Clear_OOB_K) {
|
|
|
cute::clear(D(_, m, k));
|
|
|
}
|
|
|
}
|
|
|
- } else if (Clear_OOB_MN) {
|
|
|
- cute::clear(D(_, m, _));
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
-template <bool Is_even_K=true,
|
|
|
+template <bool Is_even_K=true, bool Clear_OOB_K=true,
|
|
|
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
|
|
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
|
|
|
-inline __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
|
|
|
- Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
|
|
|
- Tensor<Engine3, Layout3> const &predicate_K,
|
|
|
- const int max_MN=0, const int min_MN=0) {
|
|
|
+inline __device__ void copy_rotary_contiguous(Tensor<Engine0, Layout0> const &S,
|
|
|
+ Tensor<Engine1, Layout1> &D,
|
|
|
+ Tensor<Engine2, Layout2> const &Cos,
|
|
|
+ Tensor<Engine2, Layout2> const &Sin,
|
|
|
+ Tensor<Engine3, Layout3> const &identity_MN,
|
|
|
+ const int max_MN, const int min_MN,
|
|
|
+ const int dim, const int rotary_dim) {
|
|
|
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
|
|
|
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
|
|
|
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
|
|
|
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
|
|
|
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
|
|
|
- // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); }
|
|
|
+ CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M
|
|
|
+ CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K
|
|
|
+ CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M
|
|
|
+ CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K
|
|
|
+ CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA
|
|
|
+ CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));
|
|
|
+ static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
|
|
|
+ Tensor rCos = make_fragment_like(Cos);
|
|
|
+ Tensor rSin = make_fragment_like(Sin);
|
|
|
+ Tensor rS = make_fragment_like(S);
|
|
|
+ Tensor rS_other = make_fragment_like(rS(_, 0, 0));
|
|
|
#pragma unroll
|
|
|
for (int m = 0; m < size<1>(S); ++m) {
|
|
|
- // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
|
|
|
if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
|
|
|
- // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
|
|
|
#pragma unroll
|
|
|
for (int k = 0; k < size<2>(S); ++k) {
|
|
|
- if (Is_even_K || predicate_K(k)) {
|
|
|
- cute::copy(S(_, m, k), D(_, m, k));
|
|
|
+ if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
|
|
|
+ cute::copy(S(_, m, k), rS(_, m, k));
|
|
|
+ if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
|
|
|
+ const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2;
|
|
|
+ Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout());
|
|
|
+ cute::copy(gS_other, rS_other);
|
|
|
+ // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); }
|
|
|
+ Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout());
|
|
|
+ Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout());
|
|
|
+ cute::copy(gCos, rCos(_, m, k));
|
|
|
+ cute::copy(gSin, rSin(_, m, k));
|
|
|
+ // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); }
|
|
|
+ Tensor S_fp32 = convert_type<float>(rS(_, m, k));
|
|
|
+ Tensor S_other_fp32 = convert_type<float>(rS_other);
|
|
|
+ Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
|
|
|
+ Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < size<0>(rS); ++i) {
|
|
|
+ S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i));
|
|
|
+ }
|
|
|
+ // Idk but I need to copy for the convert_type to work
|
|
|
+ Tensor S_fp32_copy = make_fragment_like(S_fp32);
|
|
|
+ cute::copy(S_fp32, S_fp32_copy);
|
|
|
+ using T = typename Engine0::value_type;
|
|
|
+ Tensor S_og_type = convert_type<T>(S_fp32_copy);
|
|
|
+ cute::copy(S_og_type, rS(_, m, k));
|
|
|
+ // if (cute::thread0()) { print_tensor(rS(_, m, k)); }
|
|
|
+ }
|
|
|
+ cute::copy(rS(_, m, k), D(_, m, k));
|
|
|
+ } else if (Clear_OOB_K) {
|
|
|
+ cute::clear(D(_, m, k));
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -422,4 +518,4 @@ inline __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
-} // namespace flash
|
|
|
+} // namespace flash
|