/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cute/tensor.hpp>

#include "cutlass/fast_math.h"  // For cutlass::FastDivmod

#include "utils.h"

namespace flash {

using namespace cute;

// Applies the sequence-length, causal, and local (sliding-window + attention-sink)
// masks to one tile of attention scores S held in registers.
template <int kBlockM, int kBlockN, bool PackGQA, typename TiledMma, bool SwapAB=false>
struct Mask {

    static_assert(!(PackGQA && SwapAB), "Cannot be both PackGQA and SwapAB");

    int const thread_idx;
    int const seqlen_q, seqlen_k;
    int const window_size_left, window_size_right, sink_token_length;
    cutlass::FastDivmod const qhead_per_khead_divmod;

    CUTLASS_DEVICE
    Mask(const int thread_idx, const int seqlen_q, const int seqlen_k,
         const int window_size_left, const int window_size_right, const int sink_token_length,
         cutlass::FastDivmod const &qhead_per_khead_divmod)
        : thread_idx(thread_idx)
        , seqlen_q(seqlen_q)
        , seqlen_k(seqlen_k)
        , window_size_left(window_size_left)
        , window_size_right(window_size_right)
        , sink_token_length(sink_token_length)
        , qhead_per_khead_divmod(qhead_per_khead_divmod)
    {
    };

    template <bool Seqlenk_mask=false, bool Causal_mask=false, bool Local_mask=false,
              typename Engine, typename Layout>
    CUTLASS_DEVICE
    void apply(Tensor<Engine, Layout> &tSrS, const int m_block, const int n_block) const {
        static_assert(!(Causal_mask && Local_mask), "Cannot be both causal and local");
        static_assert(Layout::rank == 3, "Only support 3D Tensor");
        if (!Seqlenk_mask && !Causal_mask && !Local_mask) { return; }

        auto thread_mma = TiledMma{}.get_thread_slice(thread_idx);
        auto thread0_mma = TiledMma{}.get_thread_slice(_0{});

        static constexpr int Row = !SwapAB ? 0 : 1, Col = !SwapAB ? 1 : 0;

        Tensor cS = cute::make_identity_tensor(Shape<Int<!SwapAB ? kBlockM : kBlockN>, Int<!SwapAB ? kBlockN : kBlockM>>{});
        Tensor tScS = thread_mma.partition_C(cS);
        Tensor tSrS_rowcol = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol</*Transposed=*/SwapAB>(tSrS.layout()));
        Tensor tScS_rowcol = make_tensor(tScS.data(), flash::convert_layout_acc_rowcol</*Transposed=*/SwapAB>(tScS.layout()));
        Tensor t0ScS = thread0_mma.partition_C(cS);
        Tensor t0ScS_rowcol = make_tensor(t0ScS.data(), flash::convert_layout_acc_rowcol</*Transposed=*/SwapAB>(t0ScS.layout()));
        // We want to use the col indices of thread0 to compare, since those are known at compile time.
        // So we subtract the limit by the first col index of this thread (get<Col>(tScS_rowcol(_0{}, _0{}))).
        int const thread_col_offset = get<Col>(tScS_rowcol(_0{}, _0{}));
        int const seqlenk_col_limit = seqlen_k - n_block * kBlockN - thread_col_offset;
        if constexpr (!Causal_mask && !Local_mask) {
            if constexpr (Seqlenk_mask) {  // Just masking based on col
                #pragma unroll
                for (int n = 0; n < size<1>(tSrS_rowcol); ++n) {
                    if (int(get<Col>(t0ScS_rowcol(_0{}, n))) >= seqlenk_col_limit) {
                        #pragma unroll
                        for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { tSrS_rowcol(m, n) = -INFINITY; }
                    }
                }
            }
        } else {  // Mask based on both row and col
            if constexpr (!SwapAB) {
                // If PackGQA, we split the work of computing the divmod among threads in the same row
                static constexpr int kMmaThreadsPerRow = size<0, 0>(typename TiledMma::AtomLayoutC_TV{});
                static_assert(cutlass::NumThreadsPerWarp % kMmaThreadsPerRow == 0);
                static_assert(CUTE_STATIC_V(size<0>(tSrS_rowcol)) <= kMmaThreadsPerRow);
                int mma_m_idx;
                // Might get OOB but it's ok since we'll check it later
                if constexpr (PackGQA) {
                    mma_m_idx = qhead_per_khead_divmod.divide(m_block * kBlockM + get<Row>(tScS_rowcol(thread_idx % kMmaThreadsPerRow, _0{})));
                }
                // Rightmost allowed column for a given row under a bottom-right-aligned causal
                // mask (col <= row + seqlen_k - seqlen_q), shifted into thread0's column coordinates.
                int const causal_row_offset = 1 + seqlen_k - n_block * kBlockN - seqlen_q - thread_col_offset;
                if constexpr (Causal_mask) {
                    #pragma unroll
                    for (int m = 0; m < size<0>(tSrS_rowcol); ++m) {
                        // With PackGQA, each lane computed the divmod for one row; broadcast it to
                        // the other lanes of the same group via warp shuffle.
                        int const row_idx = !PackGQA
                            ? get<Row>(tScS_rowcol(m, _0{})) + m_block * kBlockM
                            : __shfl_sync(0xffffffff, mma_m_idx, m % kMmaThreadsPerRow, kMmaThreadsPerRow);
                        // __viaddmin_s32(a, b, c) = min(a + b, c): fuses the causal limit with the
                        // seqlen_k limit in one instruction.
                        int const col_limit_right = !Seqlenk_mask
                            ? row_idx + causal_row_offset
                            : __viaddmin_s32(row_idx, causal_row_offset, seqlenk_col_limit);
                        #pragma unroll
                        for (int n = 0; n < size<1>(tSrS_rowcol); ++n) {
                            if (int(get<Col>(t0ScS_rowcol(_0{}, n))) >= col_limit_right) { tSrS_rowcol(m, n) = -INFINITY; }
                        }
                    }
                } else {
                    int const local_row_offset_right = causal_row_offset + window_size_right;
                    int const local_row_offset_left = causal_row_offset - 1 - window_size_left;
                    int const col_limit_sink = sink_token_length - n_block * kBlockN;
                    #pragma unroll
                    for (int m = 0; m < size<0>(tSrS_rowcol); ++m) {
                        int const row_idx = !PackGQA
                            ? get<Row>(tScS_rowcol(m, _0{})) + m_block * kBlockM
                            : __shfl_sync(0xffffffff, mma_m_idx, m % kMmaThreadsPerRow, kMmaThreadsPerRow);
                        int const col_limit_right = !Seqlenk_mask
                            ? row_idx + local_row_offset_right
                            : __viaddmin_s32(row_idx, local_row_offset_right, seqlenk_col_limit);
                        int const col_limit_left = row_idx + local_row_offset_left;
                        #pragma unroll
                        for (int n = 0; n < size<1>(tSrS_rowcol); ++n) {
                            int const col_idx = int(get<Col>(t0ScS_rowcol(m, n)));
                            // Keep columns inside the sliding window, plus the sink tokens at the front.
                            if (col_idx >= col_limit_right || (col_idx < col_limit_left && col_idx >= col_limit_sink)) { tSrS_rowcol(m, n) = -INFINITY; }
                        }
                    }
                }
            } else {
                // SwapAB: rows of the accumulator are K tokens and columns are Q tokens,
                // so we mask rows per column instead of columns per row.
                int const thread_row_offset = get<Row>(tScS_rowcol(_0{}, _0{}));
                int const causal_row_offset = seqlenk_col_limit - seqlen_q + m_block * kBlockM + thread_row_offset;
                if constexpr (Causal_mask) {
                    #pragma unroll
                    for (int n = 0; n < size<1>(tSrS_rowcol); ++n) {
                        int const col0 = int(get<Col>(t0ScS_rowcol(_0{}, n)));
                        // If col0 is beyond the column limit, we want to mask out the entire column,
                        // by setting the row limit to kBlockM.
                        int const row_limit_top = col0 >= seqlenk_col_limit ? kBlockM : col0 - causal_row_offset;
                        #pragma unroll
                        for (int m = 0; m < size<0>(tSrS_rowcol); ++m) {
                            if (int(get<Row>(t0ScS_rowcol(m, _0{}))) < row_limit_top) { tSrS_rowcol(m, n) = -INFINITY; }
                        }
                    }
                } else {
                    int const col_limit_sink = sink_token_length - n_block * kBlockN - thread_col_offset;
                    #pragma unroll
                    for (int n = 0; n < size<1>(tSrS_rowcol); ++n) {
                        int const col0 = int(get<Col>(t0ScS_rowcol(_0{}, n)));
                        // If col0 is beyond the column limit, we want to mask out the entire column,
                        // by setting the row limit to kBlockM. Sink-token columns (col0 < col_limit_sink)
                        // similarly get no bottom limit, so they are never masked by the left window.
                        int const row_limit_top = col0 >= seqlenk_col_limit ? kBlockM : col0 - causal_row_offset - window_size_right;
                        int const row_limit_bot = col0 < col_limit_sink ? kBlockM : col0 - causal_row_offset + window_size_left;
                        #pragma unroll
                        for (int m = 0; m < size<0>(tSrS_rowcol); ++m) {
                            int const row_idx = int(get<Row>(t0ScS_rowcol(m, _0{})));
                            if (row_idx < row_limit_top || row_idx > row_limit_bot) { tSrS_rowcol(m, n) = -INFINITY; }
                        }
                    }
                }
            }
        }
    };

};

} // namespace flash
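
// ----------------------------------------------------------------------------
// Usage sketch (illustrative only; not compiled, hence the #if 0). It shows
// roughly how a kernel mainloop might drive Mask for a causal kernel: the mask
// flags are compile-time template arguments, and apply() with all flags false
// returns immediately, so interior tiles skip masking entirely; only the
// blocks straddling the causal diagonal or the end of the key sequence pay
// for the comparisons. Every name below that does not appear in this header
// (TiledMmaS, AttnParams, mask_tile_example, n_block_diag_min, n_block_max)
// is a hypothetical stand-in for whatever the surrounding kernel defines.
#if 0
template <typename TiledMmaS, typename TensorS, typename AttnParams>
CUTLASS_DEVICE void mask_tile_example(TensorS &tSrS, AttnParams const &params,
                                      int thread_idx, int m_block, int n_block,
                                      int n_block_diag_min, int n_block_max) {
    static constexpr int kBlockM = 128, kBlockN = 128;  // hypothetical tile shape
    flash::Mask<kBlockM, kBlockN, /*PackGQA=*/false, TiledMmaS> mask(
        thread_idx, params.seqlen_q, params.seqlen_k,
        params.window_size_left, params.window_size_right, params.sink_token_length,
        params.qhead_per_khead_divmod);
    if (n_block >= n_block_diag_min) {
        // Blocks on the causal diagonal also need the seqlen_k bound.
        mask.template apply</*Seqlenk_mask=*/true, /*Causal_mask=*/true>(tSrS, m_block, n_block);
    } else if (n_block == n_block_max - 1) {
        // Last key block before the diagonal region: only out-of-range columns need masking.
        mask.template apply</*Seqlenk_mask=*/true>(tSrS, m_block, n_block);
    }  // Interior blocks: no masking needed.
}
#endif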