@@ -657,10 +657,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
+ (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
+ (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
- const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
- + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
- const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
- + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;

Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
@@ -672,18 +668,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.v_row_stride, _1{}));
- // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
- // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
- // This maps to accessing the first 64 rows of knew_ptr.
- Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.knew_ptr)
- + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride),
- Shape<Int<kBlockN>, Int<kHeadDim>>{},
- make_stride(params.knew_row_stride, _1{}));
- // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); }
- Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.vnew_ptr)
- + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
- Shape<Int<kBlockN>, Int<kHeadDim>>{},
- make_stride(params.vnew_row_stride, _1{}));

Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutQ{});
@@ -698,10 +682,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
- Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K)
Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
- Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K)
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);

typename Kernel_traits::TiledMma tiled_mma;
@@ -762,6 +744,49 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons

// Prologue

+ if constexpr (Append_KV) {
+ // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
+ // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
+ // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
+ const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+ + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
+ const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+ + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
+ // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
+ // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
+ // This maps to accessing the first 64 rows of knew_ptr.
+ Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.knew_ptr)
+ + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride),
+ Shape<Int<kBlockN>, Int<kHeadDim>>{},
+ make_stride(params.knew_row_stride, _1{}));
+ // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); }
+ Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.vnew_ptr)
+ + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
+ Shape<Int<kBlockN>, Int<kHeadDim>>{},
+ make_stride(params.vnew_row_stride, _1{}));
+ Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K)
+ Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K)
+
+ const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);
+ for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) {
+ flash::copy_w_min_idx<Is_even_K>(
+ tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
+ );
+ flash::copy_w_min_idx<Is_even_K>(
+ tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
+ );
+ tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+ tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
+ tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
+ tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
+ }
+ __syncthreads();
+ if (n_block_max > n_block_copy_min) {
+ tKgK.data() = tKgK.data() + (n_block_max - n_block_copy_min) * kBlockN * params.k_row_stride;
+ tVgV.data() = tVgV.data() + (n_block_max - n_block_copy_min) * kBlockN * params.v_row_stride;
+ }
+ }
+
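
To make the index arithmetic in this added prologue concrete, here is a minimal host-side sketch, not part of the kernel: offsetting knew_ptr back by seqlen_k_cache rows makes the new tokens appear as rows [seqlen_k_cache, actual_seqlen_k) of the same logical K sequence, so the copy loop only needs to start at n_block_copy_min = max(n_block_min, seqlen_k_cache / kBlockN). The variable names mirror the kernel's, but the numbers below are made up for illustration.

    // Illustrative only: scalar model of the "line up" trick used in the prologue.
    // Row r of the combined K sequence (cache + new tokens) lives at:
    //   r <  seqlen_k_cache  -> k_cache[r]
    //   r >= seqlen_k_cache  -> knew[r - seqlen_k_cache]
    // Pre-subtracting seqlen_k_cache rows from knew's base pointer lets both
    // sides be addressed with the same global row index r.
    #include <algorithm>
    #include <cstdio>

    int main() {
        const int kBlockN = 64, seqlen_k_cache = 128, seqlen_knew = 70;
        const int actual_seqlen_k = seqlen_k_cache + seqlen_knew;               // 198
        const int n_block_min = 0;
        const int n_block_max = (actual_seqlen_k + kBlockN - 1) / kBlockN;      // 4 blocks
        // Blocks that lie entirely inside the cache contain no new tokens, so
        // copying starts at the first block that touches row seqlen_k_cache.
        const int n_block_copy_min = std::max(n_block_min, seqlen_k_cache / kBlockN);  // 2
        for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; --n_block) {
            // Global rows covered by this block, clipped to rows that exist and are new.
            const int row_begin = std::max(n_block * kBlockN, seqlen_k_cache);
            const int row_end   = std::min((n_block + 1) * kBlockN, actual_seqlen_k);
            printf("block %d: copy knew rows [%d, %d) into cache rows [%d, %d)\n",
                   n_block, row_begin - seqlen_k_cache, row_end - seqlen_k_cache,
                   row_begin, row_end);
        }
        return 0;
    }

With these numbers the loop visits blocks 3 and 2 and copies exactly seqlen_knew = 70 rows in total, matching the bound n_block_copy_min above.
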
Tensor tQrQ = make_fragment_like(tQgQ);
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
@@ -769,10 +794,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons

int n_block = n_block_max - 1;
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
- flash::copy_2_sources</*Is_2_sources=*/Append_KV, Is_even_MN, Is_even_K>(
- gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV,
- binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
- );
+ flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
+ binfo.actual_seqlen_k - n_block * kBlockN);
cute::cp_async_fence();

// flash::cp_async_wait<0>();
@@ -800,32 +823,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
flash::cp_async_wait<0>();
__syncthreads();

- if constexpr (Append_KV) {
- // if (cute::thread0()) { print(tKgK); }
- // if (cute::thread0()) { print(tKsK); }
- // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
- if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
- flash::copy_w_min_idx<Is_even_K>(
- tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
- );
- }
- // __syncthreads();
- // if (cute::thread0()) { print(tKgK); }
- // __syncthreads();
- }
-
// Advance gV
if (masking_step > 0) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
- if (Append_KV) { tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); }
- flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
- gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN
- );
+ flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
} else {
// Clear the smem tiles to account for predicated off loads
- flash::copy_2_sources</*Is_2_sources=*/Append_KV, Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
- gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV,
- binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
+ flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
+ gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
);
}
cute::cp_async_fence();
@@ -856,26 +861,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); }
// __syncthreads();

- // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("n_block = %d, n_block_min = %d\n", n_block, n_block_min); }
- if constexpr (Append_KV) {
- // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
- if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
- flash::copy_w_min_idx<Is_even_K>(
- tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
- );
- }
- }
-
if (n_block > n_block_min) {
// Advance gK
- // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p\n", tKgKnew.data()); }
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
- if (Append_KV) { tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); }
- // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p, row_idx_switch = %d\n", tKgKnew.data(), binfo.seqlen_k_cache - (n_block - 1) * kBlockN); }
- flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
- gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0,
- binfo.seqlen_k_cache - (n_block - 1) * kBlockN
- );
+ flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
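
A side note on the pointer-advance idiom that recurs in these hunks, tKgK.data() + (-int(kBlockN * params.k_row_stride)): the int cast matters when the stride type is unsigned, because negating the unsigned product would wrap around instead of producing a negative offset. A minimal sketch of the difference, assuming a 32-bit unsigned stride type and made-up stride values:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int kBlockN = 64;
        const uint32_t k_row_stride = 128;        // stand-in for params.k_row_stride
        const uint32_t step = kBlockN * k_row_stride;
        const long long wrapped  = -step;         // unsigned negation wraps: 4294959104
        const long long intended = -int(step);    // cast to signed first: -8192
        printf("wrapped = %lld, intended = %lld\n", wrapped, intended);
        return 0;
    }
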
@@ -909,20 +898,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
clear(acc_s);
flash::cp_async_wait<0>();
__syncthreads();
- if constexpr (Append_KV) {
- // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
- if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
- flash::copy_w_min_idx<Is_even_K>(
- tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
- );
- }
- }
// Advance gV
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
- if (Append_KV) { tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); }
- flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
- gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN
- );
+ flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
cute::cp_async_fence();

flash::gemm(
@@ -932,22 +910,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons

flash::cp_async_wait<0>();
__syncthreads();
- if constexpr (Append_KV) {
- // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
- if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
- flash::copy_w_min_idx<Is_even_K>(
- tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
- );
- }
- }
if (n_block > n_block_min) {
// Advance gK
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
- if (Append_KV) { tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); }
- flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
- gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0,
- binfo.seqlen_k_cache - (n_block - 1) * kBlockN
- );
+ flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
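
As a reading aid for the prologue loop added earlier: the two trailing arguments passed to flash::copy_w_min_idx act as an upper and a lower row bound within the current kBlockN-row tile, so only rows that exist (below actual_seqlen_k) and are genuinely new (at or beyond seqlen_k_cache) get written back to the KV cache. A scalar sketch of that per-row predicate, assuming the bounds behave as described; the helper below is hypothetical and not part of the kernel:

    // Row r (0 <= r < kBlockN) of block n_block is written to the cache iff it is
    // a valid row of the combined sequence and not one already present in the cache.
    bool should_copy_row(int r, int n_block, int kBlockN,
                         int actual_seqlen_k, int seqlen_k_cache) {
        const int max_idx = actual_seqlen_k - n_block * kBlockN;  // first out-of-range row
        const int min_idx = seqlen_k_cache - n_block * kBlockN;   // first new row in this tile
        return r >= min_idx && r < max_idx;
    }

With the numbers from the earlier sketch (seqlen_k_cache = 128, kBlockN = 64, actual_seqlen_k = 198), block 2 gets min_idx = 0 and max_idx = 70, so all 64 of its rows are written, while block 3 gets min_idx = -64 and max_idx = 6, so only its first 6 rows are.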