|
@@ -91,7 +91,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
|
|
constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;
|
|
|
|
|
|
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
|
|
|
- if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
|
|
|
+ if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
|
|
|
|
|
|
const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
|
|
|
int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
|
|
@@ -101,50 +101,50 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
|
|
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
|
|
|
// printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
|
|
|
// }
|
|
|
- // We exit early and write 0 to gO and gLSE.
|
|
|
- // Otherwise we might read OOB elements from gK and gV.
|
|
|
- if (n_block_max <= n_block_min) {
|
|
|
- // Save seed and offset for backward. If we don't have this here, the 0-th thread block might
|
|
|
- // exit early and no one saves the rng state.
|
|
|
- if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
|
|
|
- auto seeds = at::cuda::philox::unpack(params.philox_args);
|
|
|
- params.rng_state[0] = std::get<0>(seeds);
|
|
|
- params.rng_state[1] = std::get<1>(seeds);
|
|
|
- }
|
|
|
- const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
|
|
|
- + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
|
|
|
- const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
|
|
|
- Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
|
|
|
- Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
|
|
- make_stride(params.o_row_stride, _1{}));
|
|
|
- Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
|
|
|
- Shape<Int<kBlockM>>{}, Stride<_1>{});
|
|
|
-
|
|
|
- typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
|
|
|
- auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
|
|
|
- Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
|
|
|
- Tensor tOrO = make_tensor<Element>(shape(tOgO));
|
|
|
- clear(tOrO);
|
|
|
- // Construct identity layout for sO
|
|
|
- Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
|
|
- // Repeat the partitioning with identity layouts
|
|
|
- Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
|
|
|
- Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
|
|
|
- if (!Is_even_K) {
|
|
|
- #pragma unroll
|
|
|
- for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
|
|
|
- }
|
|
|
- // Clear_OOB_K must be false since we don't want to write zeros to gmem
|
|
|
- flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
|
|
- gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
|
|
|
- );
|
|
|
+ }
|
|
|
+ // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0.
|
|
|
+ // Otherwise we might read OOB elements from gK and gV.
|
|
|
+ if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) {
|
|
|
+ // Save seed and offset for backward. If we don't have this here, the 0-th thread block might
|
|
|
+ // exit early and no one saves the rng state.
|
|
|
+ if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
|
|
|
+ auto seeds = at::cuda::philox::unpack(params.philox_args);
|
|
|
+ params.rng_state[0] = std::get<0>(seeds);
|
|
|
+ params.rng_state[1] = std::get<1>(seeds);
|
|
|
+ }
|
|
|
+ const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
|
|
|
+ + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
|
|
|
+ const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
|
|
|
+ Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
|
|
|
+ Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
|
|
+ make_stride(params.o_row_stride, _1{}));
|
|
|
+ Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
|
|
|
+ Shape<Int<kBlockM>>{}, Stride<_1>{});
|
|
|
+
|
|
|
+ typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
|
|
|
+ auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
|
|
|
+ Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
|
|
|
+ Tensor tOrO = make_tensor<Element>(shape(tOgO));
|
|
|
+ clear(tOrO);
|
|
|
+ // Construct identity layout for sO
|
|
|
+ Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
|
|
+ // Repeat the partitioning with identity layouts
|
|
|
+ Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
|
|
|
+ Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
|
|
|
+ if (!Is_even_K) {
|
|
|
#pragma unroll
|
|
|
- for (int m = 0; m < size<1>(tOgO); ++m) {
|
|
|
- const int row = get<0>(tOcO(0, m, 0));
|
|
|
- if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; }
|
|
|
- }
|
|
|
- return;
|
|
|
+ for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
|
|
|
}
|
|
|
+ // Clear_OOB_K must be false since we don't want to write zeros to gmem
|
|
|
+ flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
|
|
+ gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
|
|
|
+ );
|
|
|
+ #pragma unroll
|
|
|
+ for (int m = 0; m < size<1>(tOgO); ++m) {
|
|
|
+ const int row = get<0>(tOcO(0, m, 0));
|
|
|
+ if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; }
|
|
|
+ }
|
|
|
+ return;
|
|
|
}
|
|
|
// if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); }
|
|
|
|