@@ -24,6 +24,27 @@ using namespace cute;

 ////////////////////////////////////////////////////////////////////////////////////////////////////

+template<typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN>
+__forceinline__ __device__ auto get_lse_tile(const Params &params, const int bidb, const int bidh, const int m_block, const BlockInfo</*Varlen=*/!Is_even_MN> &binfo) {
+    // When params.unpadded_lse is false, LSE is written as (b, h, seqlen_q) - this is the non-variable-seqlen path.
+    // Otherwise, when params.seqlenq_ngroups_swapped is true, it is written as (h, seqlen_q, b) to account for the seqlen_q <-> h swapping trick.
+    // Otherwise, it's written as (h, b, seqlen_q).
+    const bool varlen_q = params.unpadded_lse && !params.seqlenq_ngroups_swapped;
+    auto lse_offset = varlen_q ? binfo.q_offset(params.seqlen_q, 1, bidb) : 0;
+    auto gmem_ptr_lse = make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr) + lse_offset);
+
+    auto lse_shape = varlen_q ? make_shape(1, params.h, params.total_q) : make_shape(params.b, params.h, params.seqlen_q);
+    auto lse_stride = params.seqlenq_ngroups_swapped ? make_stride(1, params.seqlen_q * params.b, params.b) : (
+        params.unpadded_lse ? make_stride(params.h * params.total_q, params.total_q, 1) : make_stride(params.h * params.seqlen_q, params.seqlen_q, 1)
+    );
+
+    auto lse_layout = make_layout(lse_shape, lse_stride);
+    Tensor mLSE = make_tensor(gmem_ptr_lse, lse_layout);
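+    // In the varlen case the batch extent is 1 and gmem_ptr_lse is already offset to this sequence's
+    // first row via lse_offset, so the slice below uses batch index 0; otherwise it indexes the real bidb.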
+    auto mLSE_slice = varlen_q ? mLSE(0, bidh, _) : mLSE(bidb, bidh, _);
+    return local_tile(mLSE_slice, Shape<Int<kBlockM>>{}, make_coord(m_block));
+}
+
+
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
 inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
 
@@ -74,10 +95,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
                             make_stride(params.o_row_stride, params.o_head_stride, _1{}));
     Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
                            make_coord(m_block, 0));  // (kBlockM, kHeadDim)
-    Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
-                              make_shape(params.b, params.h, params.seqlen_q),
-                              make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
-    Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
+
+    Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);

     typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
     auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
@@ -424,10 +443,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
                             make_stride(params.o_row_stride, params.o_head_stride, _1{}));
     Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
                            make_coord(m_block, 0));  // (kBlockM, kHeadDim)
-    Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
-                              make_shape(params.b, params.h, params.seqlen_q),
-                              make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
-    Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
+    Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);

     typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
     auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
@@ -986,7 +1002,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
     const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
         + m_block * kBlockM) * params.d_rounded;
-    const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
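+    // With Split, or when the LSE is padded (!params.unpadded_lse), keep the padded
+    // (n_split, b, h, seqlen_q) indexing; otherwise index the unpadded LSE as (h, total_q),
+    // starting at this sequence's offset in the packed batch.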
+    const index_t row_offset_lseaccum = (Split || !params.unpadded_lse ?
+        ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q : bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb)
+        ) + m_block * kBlockM;

     Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},
@@ -1092,21 +1110,36 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     const int tidx = threadIdx.x;
     const int bidx = blockIdx.x;

+    const index_t lse_size = params.b * params.h * params.seqlen_q;
+
     const index_t row_offset_lse = bidx * kBlockM;
     Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lse),
                                    Shape<Int<kMaxSplits>, Int<kBlockM>>{},
-                                   make_stride(params.b * params.h * params.seqlen_q, _1{}));
+                                   make_stride(lse_size, _1{}));
+
+    // LSE format is different depending on params.unpadded_lse and params.seqlenq_ngroups_swapped, see comment in get_lse_tile.
+    // This tensor's layout maps row_offset_lse to {bidb, bidh, q_offset}.
     Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
                               Shape<Int<kBlockM>>{}, Stride<_1>{});
+
+    // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb, q_offset}.
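+    // final_layout takes a linear index into the padded (b, h, seqlen_q) LSE (q fastest), splits it into
+    // (q, h, b) coordinates via orig_layout, and remapped_layout then places those coordinates in the
+    // transposed/unpadded storage through transposed_stride.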
+    Layout flat_layout = make_layout(lse_size);
+    Layout orig_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b));
+    auto transposed_stride = params.seqlenq_ngroups_swapped ? make_stride(params.b, params.seqlen_q * params.b, 1) : make_stride(1, params.seqlen_q * params.b, params.seqlen_q);
+    Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride);
+    Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout));
+
+    Tensor gLSE_unpadded = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr)), final_layout);
+
     constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads;

-    // Read the LSE values from gmem and store them in shared memory, then tranpose them.
+    // Read the LSE values from gmem and store them in shared memory, then transpose them.
     constexpr int kRowsPerLoadLSE = kNThreads / kBlockM;
     #pragma unroll
     for (int l = 0; l < kNLsePerThread; ++l) {
         const int row = l * kRowsPerLoadLSE + tidx / kBlockM;
         const int col = tidx % kBlockM;
-        ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
+        ElementAccum lse = (row < params.num_splits && col < lse_size - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
         if (row < kMaxSplits) { sLSE[row][col] = lse; }
         // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); }
     }
@@ -1145,7 +1178,16 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum.
     ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max;
     // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); }
-    if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; }
+    if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) {
+        if (params.unpadded_lse) {
+            const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose;
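+            // lse_size need not be a multiple of kBlockM, so the last block's rows can run past the end
+            // of the LSE; skip writes beyond it (mirroring the bound on the gLSEaccum load above).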
+            if (lse_offset < lse_size) {
+                gLSE_unpadded(lse_offset) = lse_logsum;
+            }
+        } else {
+            gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum;
+        }
+    }
     // Store the scales exp(lse - lse_logsum) in shared memory.
     #pragma unroll
     for (int l = 0; l < kNLsePerThread; ++l) {