@@ -617,7 +617,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi

////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
 inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {

    using Element = typename Kernel_traits::Element;
@@ -635,7 +635,16 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
    constexpr int kHeadDim = Kernel_traits::kHeadDim;
    constexpr int kNWarps = Kernel_traits::kNWarps;

+    using GmemTiledCopyO = std::conditional_t<
+        !Split,
+        typename Kernel_traits::GmemTiledCopyOaccum,
+        typename Kernel_traits::GmemTiledCopyO
+    >;
+    using ElementO = std::conditional_t<!Split, Element, ElementAccum>;
+
    const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
+    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); }
+    // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_q = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_q)); }
    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;

    const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;
@@ -649,19 +658,21 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
        // We exit early and write 0 to gOaccum and -inf to gLSEaccum.
        // Otherwise we might read OOB elements from gK and gV,
        // or get wrong results when we combine gOaccum from different blocks.
+        const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+            + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
        const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
            + m_block * kBlockM) * params.d_rounded;
        const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
-        Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
-                                     Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                                     Stride<Int<kHeadDim>, _1>{});
-        Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lseaccum),
+        Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
+                                     Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                                     make_stride(Split ? kHeadDim : params.o_row_stride, _1{}));
+        Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum),
                                       Shape<Int<kBlockM>>{}, Stride<_1>{});

-        typename Kernel_traits::GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
+        GmemTiledCopyO gmem_tiled_copy_Oaccum;
        auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
        Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
-        Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
+        Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
        clear(tOrOaccum);
        // Construct identity layout for sO
        Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum)));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
@@ -679,7 +690,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
        #pragma unroll
        for (int m = 0; m < size<1>(tOgOaccum); ++m) {
            const int row = get<0>(tOcO(0, m, 0));
-            if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSEaccum(row) = -INFINITY; }
+            if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSEaccum(row) = Split ? -INFINITY : INFINITY; }
        }
        return;
    }
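For context on why an empty split writes -INFINITY here, and why that is harmless when the per-split results are merged later: each split produces a partial output row that is already normalized by its own softmax sum, together with its log-sum-exp. Below is a minimal host-side sketch of that merge, with illustrative names only; the actual combine kernel is not part of this excerpt.

#include <algorithm>
#include <cmath>
#include <vector>

// One row of one split's partial result: o is already divided by that split's
// softmax sum, lse = max_score * scale + log(sum). Illustrative types only.
struct SplitPartial {
    std::vector<float> o;
    float lse;
};

// Combine per-split partials into the final output row. A split that wrote
// lse == -INFINITY gets weight exp(-inf) == 0, so empty splits drop out.
std::vector<float> combine_splits(const std::vector<SplitPartial> &splits, int head_dim) {
    float lse_max = -INFINITY;
    for (const auto &s : splits) { lse_max = std::max(lse_max, s.lse); }
    std::vector<float> out(head_dim, 0.f);
    if (lse_max == -INFINITY) { return out; }  // every split was empty
    float sum_exp = 0.f;
    for (const auto &s : splits) { sum_exp += std::exp(s.lse - lse_max); }
    const float lse_total = lse_max + std::log(sum_exp);
    for (const auto &s : splits) {
        const float w = std::exp(s.lse - lse_total);
        for (int d = 0; d < head_dim; ++d) { out[d] += w * s.o[d]; }
    }
    return out;
}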
@@ -695,6 +706,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
        + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
    const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
        + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
+    const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+        + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
+    const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+        + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;

    Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
                            Shape<Int<kBlockM>, Int<kHeadDim>>{},
@@ -702,15 +717,26 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
    Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
                            Shape<Int<kBlockN>, Int<kHeadDim>>{},
                            make_stride(params.k_row_stride, _1{}));
+    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); }
    Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
                            Shape<Int<kBlockN>, Int<kHeadDim>>{},
                            make_stride(params.v_row_stride, _1{}));
+    // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
+    // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKnew[128:128 + 64].
+    // This maps to accessing the first 64 rows of knew_ptr.
+    Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.knew_ptr)
+                                             + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride),
+                               Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                               make_stride(params.knew_row_stride, _1{}));
+    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); }
+    Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.vnew_ptr)
+                                             + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
+                               Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                               make_stride(params.vnew_row_stride, _1{}));

    Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                            typename Kernel_traits::SmemLayoutQ{});
-    // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem;
-    Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)),
-                            typename Kernel_traits::SmemLayoutKV{});
+    Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{});
    Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
    Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
    Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
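The base-pointer shift described in the comment above is plain pointer arithmetic. A hedged scalar sketch, using a hypothetical helper and element type (the kernel expresses the same thing through CuTe tensor offsets):

#include <cstdint>

// Which address a given global key row resolves to once gKnew's base has been
// shifted back by seqlen_k_cache rows. Illustrative only; strides are in elements.
template <typename T>
const T *kv_row_ptr(const T *k_cache, const T *k_new,
                    int seqlen_k_cache, std::int64_t row_stride, int global_row) {
    // gKnew is built on (k_new - seqlen_k_cache * row_stride), so indexing it with a
    // global row >= seqlen_k_cache lands on row (global_row - seqlen_k_cache) of k_new.
    const T *k_new_shifted = k_new - std::int64_t(seqlen_k_cache) * row_stride;
    return global_row < seqlen_k_cache
        ? k_cache       + std::int64_t(global_row) * row_stride
        : k_new_shifted + std::int64_t(global_row) * row_stride;
}

With the comment's example (128 cached rows, 64 appended rows), global rows 0-127 resolve into k_cache and rows 128-191 resolve into rows 0-63 of k_new.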
@@ -721,8 +747,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
    Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
    Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
    Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K)
+    Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew);  // (KCPY, KCPY_N, KCPY_K)
    Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
    Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)
+    Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew);  // (VCPY, VCPY_N, VCPY_K)
    Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);

    typename Kernel_traits::TiledMma tiled_mma;
@@ -787,32 +815,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
    // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
                                       binfo.actual_seqlen_q - m_block * kBlockM);
-    if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }
-
-    if (Kernel_traits::Share_Q_K_smem) {
-        flash::cp_async_wait<0>();
-        __syncthreads();
-        Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
-        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));  // M
-        cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
-        __syncthreads();
-    }

    int n_block = n_block_max - 1;
    // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
-    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
-                                       binfo.actual_seqlen_k - n_block * kBlockN);
+    flash::copy_2_sources</*Is_2_sources=*/Append_KV, Is_even_MN, Is_even_K>(
+        gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV,
+        binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
+    );
    cute::cp_async_fence();
-    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
-    // __syncthreads();

-    if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) {
-        flash::cp_async_wait<1>();
-        __syncthreads();
-        Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
-        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));  // M
-        cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
-    }
+    // flash::cp_async_wait<0>();
+    // __syncthreads();
+    // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); }
+    // __syncthreads();

    clear(acc_o);

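flash::copy_2_sources, introduced above for the K load (and used for V in the next hunk), takes two row limits. A hedged scalar model of the per-row decision it is assumed to make (illustrative names; the helper's body is not shown in this excerpt):

// For a local row r of the kBlockN tile being staged into smem:
enum class RowSource { Skip, Cache, New };

inline RowSource row_source(int local_row, int max_MN, int row_idx_switch) {
    if (local_row >= max_MN) { return RowSource::Skip; }   // past actual_seqlen_k; optionally zero-filled
    return local_row < row_idx_switch ? RowSource::Cache   // read from gK / gV (the existing cache)
                                      : RowSource::New;    // read from gKnew / gVnew (the appended tokens)
}

In the call above, max_MN is binfo.actual_seqlen_k - n_block * kBlockN and row_idx_switch is binfo.seqlen_k_cache - n_block * kBlockN, so rows still covered by the cache presumably come from tKgK and rows past the cached length from tKgKnew; with Is_2_sources = false the second source would simply never be selected.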
@@ -834,19 +849,37 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
        flash::cp_async_wait<0>();
        __syncthreads();

+        if constexpr (Append_KV) {
+            // if (cute::thread0()) { print(tKgK); }
+            // if (cute::thread0()) { print(tKsK); }
+            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
+            if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
+                flash::copy_w_min_idx<Is_even_K>(
+                    tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
+                );
+            }
+            // __syncthreads();
+            // if (cute::thread0()) { print(tKgK); }
+            // __syncthreads();
+        }
+
        // Advance gV
        if (masking_step > 0) {
            tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+            if (Append_KV) { tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); }
+            flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
+                gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN
+            );
        } else {
            // Clear the smem tiles to account for predicated off loads
-            flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
-                gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
+            flash::copy_2_sources</*Is_2_sources=*/Append_KV, Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
+                gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV,
+                binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
            );
        }
        cute::cp_async_fence();

-        flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
+        flash::gemm(
            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
            smem_thr_copy_Q, smem_thr_copy_K
        );
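The Append_KV blocks above (and the analogous ones for V in the following hunks) copy the freshly appended rows from shared memory back into the KV cache in global memory. A hedged scalar model of when a row is written back, folding the guard around the call together with the two indices passed to copy_w_min_idx (illustrative; the helper's body is not shown in this excerpt):

// True iff this (head, row) pair should be written from smem back to the cache tile.
inline bool write_back_row(int local_row,
                           int max_MN,        // actual_seqlen_k - n_block * kBlockN
                           int min_MN,        // seqlen_k_cache  - n_block * kBlockN
                           int bidh, int h_h_k_ratio) {
    return (bidh % h_h_k_ratio == 0)   // only one query head per KV head writes (MQA/GQA)
        && local_row >= min_MN         // rows already present in the cache are left untouched
        && local_row < max_MN;         // rows past the updated sequence length are skipped
}

The extra binfo.seqlen_k_cache < (n_block + 1) * kBlockN check merely skips tiles that contain no appended rows at all.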
@@ -869,19 +902,39 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons

        flash::cp_async_wait<0>();
        __syncthreads();
+        // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); }
+        // __syncthreads();
+
+        // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("n_block = %d, n_block_min = %d\n", n_block, n_block_min); }
+        if constexpr (Append_KV) {
+            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
+            if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
+                flash::copy_w_min_idx<Is_even_K>(
+                    tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
+                );
+            }
+        }
+
        if (n_block > n_block_min) {
            // Advance gK
+            // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p\n", tKgKnew.data()); }
            tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+            if (Append_KV) { tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); }
+            // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p, row_idx_switch = %d\n", tKgKnew.data(), binfo.seqlen_k_cache - (n_block - 1) * kBlockN); }
+            flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
+                gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0,
+                binfo.seqlen_k_cache - (n_block - 1) * kBlockN
+            );
            // This cp_async_fence needs to be in the if block, otherwise the synchronization
            // isn't right and we get race conditions.
            cute::cp_async_fence();
        }

-        // TODO: when we have key_padding_mask we'll need to Check_inf
+        // We have key_padding_mask so we'll need to Check_inf
        masking_step == 0
-            ? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
-            : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+            ? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
+            : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+        // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }

        // Convert scores from fp32 to fp16/bf16
        Tensor rP = flash::convert_type<Element>(scores);
@@ -905,22 +958,45 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
        clear(acc_s);
        flash::cp_async_wait<0>();
        __syncthreads();
+        if constexpr (Append_KV) {
+            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
+            if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
+                flash::copy_w_min_idx<Is_even_K>(
+                    tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
+                );
+            }
+        }
        // Advance gV
        tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
-        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+        if (Append_KV) { tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); }
+        flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
+            gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN
+        );
        cute::cp_async_fence();

-        flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
+        flash::gemm(
            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
            smem_thr_copy_Q, smem_thr_copy_K
        );

        flash::cp_async_wait<0>();
        __syncthreads();
+        if constexpr (Append_KV) {
+            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
+            if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
+                flash::copy_w_min_idx<Is_even_K>(
+                    tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
+                );
+            }
+        }
        if (n_block > n_block_min) {
            // Advance gK
            tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+            if (Append_KV) { tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); }
+            flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
+                gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0,
+                binfo.seqlen_k_cache - (n_block - 1) * kBlockN
+            );
            // This cp_async_fence needs to be in the if block, otherwise the synchronization
            // isn't right and we get race conditions.
            cute::cp_async_fence();
@@ -942,49 +1018,60 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons

    // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
    Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
+    // if (cute::thread0()) { print(acc_o_rowcol); }
    Tensor lse = make_fragment_like(scores_sum);
    #pragma unroll
    for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
        float sum = scores_sum(mi);
        float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
-        lse(mi) = (sum == 0.f || sum != sum) ? -INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum);
+        lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : scores_max(mi) * params.scale_softmax + __logf(sum);
        float scale = inv_sum;
        #pragma unroll
        for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
    }
-
+    // if (cute::thread0()) { print(lse); }
    // if (cute::thread0()) { print(acc_o_rowcol); }

-    Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementAccum *>(smem_)), typename Kernel_traits::SmemLayoutO{});  // (SMEM_M,SMEM_N)
+    Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(smem_)), typename Kernel_traits::SmemLayoutO{});  // (SMEM_M,SMEM_N)
    // Partition sO to match the accumulator partitioning
-    auto smem_tiled_copy_Oaccum = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomOaccum{}, tiled_mma);
+    using SmemTiledCopyO = std::conditional_t<
+        !Split,
+        typename Kernel_traits::SmemCopyAtomO,
+        typename Kernel_traits::SmemCopyAtomOaccum
+    >;
+    auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma);
    auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
-    Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(acc_o);  // ((Atom,AtomNum), MMA_M, MMA_N)
+    Tensor rO = flash::convert_type<ElementO>(acc_o);
+    Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO);  // ((Atom,AtomNum), MMA_M, MMA_N)
    Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum);  // ((Atom,AtomNum),PIPE_M,PIPE_N)

-    // sO has the same size as sQ, so we don't need to sync here.
-    if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); }
+    // sOaccum is larger than sQ, so we need to syncthreads here
+    // TODO: allocate enough smem for sOaccum
+    if constexpr (Split) { __syncthreads(); }

    cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);

+    const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+        + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
    const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
        + m_block * kBlockM) * params.d_rounded;
    const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;

-    Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
+    Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
                                 Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                                 Stride<Int<kHeadDim>, _1>{});
-    Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lseaccum),
+                                 make_stride(Split ? kHeadDim : params.o_row_stride, _1{}));
+    Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum),
                                   Shape<Int<kBlockM>>{}, Stride<_1>{});
+    // if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n", row_offset_o, bidh, gOaccum.data()); }

-    typename Kernel_traits::GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
+    GmemTiledCopyO gmem_tiled_copy_Oaccum;
    auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
    Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum);  // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);

    __syncthreads();

-    Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
+    Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
    cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);

    Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
@@ -1014,6 +1101,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
    );
+    // __syncthreads();
+    // if (cute::thread0()) { print(tOgOaccum); }
}

////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -1039,16 +1128,16 @@ inline __device__ void compute_attn(const Params &params) {

////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
 inline __device__ void compute_attn_splitkv(const Params &params) {
    const int m_block = blockIdx.x;
    // The block index for the batch.
-    const int bidb = blockIdx.z / params.h;
+    const int bidb = Split ? blockIdx.z / params.h : blockIdx.y;
    // The block index for the head.
-    const int bidh = blockIdx.z - bidb * params.h;
-    const int n_split_idx = blockIdx.y;
-    const int num_n_splits = gridDim.y;
-    flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
+    const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
+    const int n_split_idx = Split ? blockIdx.y : 0;
+    const int num_n_splits = Split ? gridDim.y : 1;
+    flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
}
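The blockIdx decoding above implies different launch grids for the Split and non-Split paths. A hedged host-side sketch of the assumed grid shape (num_m_blocks and num_n_splits are illustrative names, not fields of Params):

#include <cuda_runtime.h>  // dim3

// Grid whose blockIdx layout matches the index math in compute_attn_splitkv:
// Split:     x = m_block, y = n_split_idx, z = bidb * params.h + bidh
// non-Split: x = m_block, y = bidb,        z = bidh
inline dim3 splitkv_grid(bool split, int num_m_blocks, int num_n_splits, int b, int h) {
    return split ? dim3(num_m_blocks, num_n_splits, b * h)
                 : dim3(num_m_blocks, b, h);
}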

////////////////////////////////////////////////////////////////////////////////////////////////////