mask.h

/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cute/tensor.hpp>

#include "cutlass/fast_math.h"  // For cutlass::FastDivmod

#include "utils.h"

namespace flash {

using namespace cute;
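
// Applies masking to one tile of the attention-score accumulator (the S = Q @ K^T scores)
// by setting masked entries to -INFINITY, so they contribute nothing after softmax.
// Three compile-time mask modes are combined in apply():
//   - Seqlenk_mask: mask key columns past the end of the key sequence (>= seqlen_k),
//   - Causal_mask:  causal masking, aligned to the bottom-right of the (seqlen_q, seqlen_k) score matrix,
//   - Local_mask:   sliding-window attention with window_size_left / window_size_right,
//                   optionally keeping the first sink_token_length key tokens ("attention sinks").
// PackGQA means query heads are packed along the row dimension (qhead_per_khead query heads
// per KV head); SwapAB means the row/column roles of the accumulator are swapped, so the mask
// is applied column-by-column instead of row-by-row.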
template <int kBlockM, int kBlockN, bool PackGQA, typename TiledMma, bool SwapAB=false>
struct Mask {

    static_assert(!(PackGQA && SwapAB), "Cannot be both PackGQA and SwapAB");

    int const thread_idx;
    int const seqlen_q, seqlen_k;
    int const window_size_left, window_size_right, sink_token_length;
    cutlass::FastDivmod const qhead_per_khead_divmod;

    CUTLASS_DEVICE
    Mask(const int thread_idx, const int seqlen_q, const int seqlen_k,
         const int window_size_left, const int window_size_right, const int sink_token_length,
         cutlass::FastDivmod const &qhead_per_khead_divmod)
        : thread_idx(thread_idx)
        , seqlen_q(seqlen_q)
        , seqlen_k(seqlen_k)
        , window_size_left(window_size_left)
        , window_size_right(window_size_right)
        , sink_token_length(sink_token_length)
        , qhead_per_khead_divmod(qhead_per_khead_divmod)
    {
    };
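
    // Mask one (kBlockM, kBlockN) tile of scores held in registers (tSrS), given the tile's
    // block coordinates (m_block, n_block). The boolean flags select which masks to apply;
    // with no flag set this is a no-op.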
    template <bool Seqlenk_mask=false, bool Causal_mask=false, bool Local_mask=false,
        typename Engine, typename Layout>
    CUTLASS_DEVICE
    void apply(Tensor<Engine, Layout> &tSrS, const int m_block, const int n_block) const {
        static_assert(!(Causal_mask && Local_mask), "Cannot be both causal and local");
        static_assert(Layout::rank == 3, "Only support 3D Tensor");
        if (!Seqlenk_mask && !Causal_mask && !Local_mask) { return; }
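        // Build an identity tensor of (row, col) coordinates over the tile and partition it the
        // same way the MMA partitions the accumulator, so every register element of tSrS can be
        // matched to its (row, col) position. The accumulator layouts are then viewed in
        // (row, col) form via convert_layout_acc_rowcol.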
        auto thread_mma = TiledMma{}.get_thread_slice(thread_idx);
        auto thread0_mma = TiledMma{}.get_thread_slice(_0{});
        static constexpr int Row = !SwapAB ? 0 : 1, Col = !SwapAB ? 1 : 0;
        Tensor cS = cute::make_identity_tensor(Shape<Int<!SwapAB ? kBlockM : kBlockN>, Int<!SwapAB ? kBlockN : kBlockM>>{});
        Tensor tScS = thread_mma.partition_C(cS);
        Tensor tSrS_rowcol = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol</*Transposed=*/SwapAB>(tSrS.layout()));
        Tensor tScS_rowcol = make_tensor(tScS.data(), flash::convert_layout_acc_rowcol</*Transposed=*/SwapAB>(tScS.layout()));
        Tensor t0ScS = thread0_mma.partition_C(cS);
        Tensor t0ScS_rowcol = make_tensor(t0ScS.data(), flash::convert_layout_acc_rowcol</*Transposed=*/SwapAB>(t0ScS.layout()));
        // We want to use the col indices of thread0 for the comparisons, since those are known
        // at compile time. So we subtract this thread's first col index
        // (get<Col>(tScS_rowcol(_0{}, _0{}))) from the limit instead.
        int const thread_col_offset = get<Col>(tScS_rowcol(_0{}, _0{}));
        int const seqlenk_col_limit = seqlen_k - n_block * kBlockN - thread_col_offset;
        if constexpr (!Causal_mask && !Local_mask) {
            if constexpr (Seqlenk_mask) {  // Just masking based on col
                #pragma unroll
                for (int n = 0; n < size<1>(tSrS_rowcol); ++n) {
                    if (int(get<Col>(t0ScS_rowcol(_0{}, n))) >= seqlenk_col_limit) {
                        #pragma unroll
                        for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { tSrS_rowcol(m, n) = -INFINITY; }
                    }
                }
            }
        } else {  // mask based on both row and col
            if constexpr (!SwapAB) {
                // If PackGQA, we split the work of computing the divmod among the threads in the
                // same MMA row: each thread computes the divide for one of the row indices, and
                // the results are exchanged with __shfl_sync in the loops below.
                static constexpr int kMmaThreadsPerRow = size<0, 0>(typename TiledMma::AtomLayoutC_TV{});
                static_assert(cutlass::NumThreadsPerWarp % kMmaThreadsPerRow == 0);
                static_assert(CUTE_STATIC_V(size<0>(tSrS_rowcol)) <= kMmaThreadsPerRow);
                int mma_m_idx;
                // Might be out of bounds, but that's OK since we check it later
                if constexpr (PackGQA) {
                    mma_m_idx = qhead_per_khead_divmod.divide(m_block * kBlockM + get<Row>(tScS_rowcol(thread_idx % kMmaThreadsPerRow, _0{})));
                }
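                // The causal mask is aligned to the bottom-right of the (seqlen_q, seqlen_k) score
                // matrix: query row i may attend to key columns j <= i + seqlen_k - seqlen_q.
                // causal_row_offset folds that shift, the n_block offset, and thread_col_offset
                // into one constant, so a row's exclusive column limit is row_idx + causal_row_offset.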
                int const causal_row_offset = 1 + seqlen_k - n_block * kBlockN - seqlen_q - thread_col_offset;
                if constexpr (Causal_mask) {
                    #pragma unroll
                    for (int m = 0; m < size<0>(tSrS_rowcol); ++m) {
                        int const row_idx = !PackGQA
                            ? get<Row>(tScS_rowcol(m, _0{})) + m_block * kBlockM
                            : __shfl_sync(0xffffffff, mma_m_idx, m % kMmaThreadsPerRow, kMmaThreadsPerRow);
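                        // __viaddmin_s32(a, b, c) computes min(a + b, c) in a single instruction,
                        // folding the seqlen_k bound into the causal column limit.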
                        int const col_limit_right = !Seqlenk_mask
                            ? row_idx + causal_row_offset
                            : __viaddmin_s32(row_idx, causal_row_offset, seqlenk_col_limit);
                        #pragma unroll
                        for (int n = 0; n < size<1>(tSrS_rowcol); ++n) {
                            if (int(get<Col>(t0ScS_rowcol(_0{}, n))) >= col_limit_right) { tSrS_rowcol(m, n) = -INFINITY; }
                        }
                    }
                } else {
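                    // Local (sliding-window) attention: each row keeps columns within
                    // window_size_left / window_size_right of its (bottom-right-aligned) diagonal
                    // position; key columns before sink_token_length ("attention sinks") are kept
                    // as well.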
                    int const local_row_offset_right = causal_row_offset + window_size_right;
                    int const local_row_offset_left = causal_row_offset - 1 - window_size_left;
                    int const col_limit_sink = sink_token_length - n_block * kBlockN;
                    #pragma unroll
                    for (int m = 0; m < size<0>(tSrS_rowcol); ++m) {
                        int const row_idx = !PackGQA
                            ? get<Row>(tScS_rowcol(m, _0{})) + m_block * kBlockM
                            : __shfl_sync(0xffffffff, mma_m_idx, m % kMmaThreadsPerRow, kMmaThreadsPerRow);
                        int const col_limit_right = !Seqlenk_mask
                            ? row_idx + local_row_offset_right
                            : __viaddmin_s32(row_idx, local_row_offset_right, seqlenk_col_limit);
                        int const col_limit_left = row_idx + local_row_offset_left;
                        #pragma unroll
                        for (int n = 0; n < size<1>(tSrS_rowcol); ++n) {
                            int const col_idx = int(get<Col>(t0ScS_rowcol(m, n)));
                            if (col_idx >= col_limit_right || (col_idx < col_limit_left && col_idx >= col_limit_sink)) { tSrS_rowcol(m, n) = -INFINITY; }
                        }
                    }
                }
            } else {
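                // SwapAB: the score tile is stored transposed (key positions along the MMA M mode,
                // query positions along the N mode), so instead of per-row column limits we compute
                // per-column row limits, reusing thread0's indices in the same way.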
                int const thread_row_offset = get<Row>(tScS_rowcol(_0{}, _0{}));
                int const causal_row_offset = seqlenk_col_limit - seqlen_q + m_block * kBlockM + thread_row_offset;
                if constexpr (Causal_mask) {
                    #pragma unroll
                    for (int n = 0; n < size<1>(tSrS_rowcol); ++n) {
                        int const col0 = int(get<Col>(t0ScS_rowcol(_0{}, n)));
                        // If col0 is beyond the column limit, we want to mask out the entire
                        // column by setting the row limit to kBlockM.
                        int const row_limit_top = col0 >= seqlenk_col_limit ? kBlockM : col0 - causal_row_offset;
                        #pragma unroll
                        for (int m = 0; m < size<0>(tSrS_rowcol); ++m) {
                            if (int(get<Row>(t0ScS_rowcol(m, _0{}))) < row_limit_top) { tSrS_rowcol(m, n) = -INFINITY; }
                        }
                    }
                } else {
                    int const col_limit_sink = sink_token_length - n_block * kBlockN - thread_col_offset;
                    #pragma unroll
                    for (int n = 0; n < size<1>(tSrS_rowcol); ++n) {
                        int const col0 = int(get<Col>(t0ScS_rowcol(_0{}, n)));
                        // If col0 is beyond the column limit, we want to mask out the entire
                        // column by setting the row limit to kBlockM.
                        int const row_limit_top = col0 >= seqlenk_col_limit ? kBlockM : col0 - causal_row_offset - window_size_right;
                        int const row_limit_bot = col0 < col_limit_sink ? kBlockM : col0 - causal_row_offset + window_size_left;
                        #pragma unroll
                        for (int m = 0; m < size<0>(tSrS_rowcol); ++m) {
                            int const row_idx = int(get<Row>(t0ScS_rowcol(m, _0{})));
                            if (row_idx < row_limit_top || row_idx > row_limit_bot) { tSrS_rowcol(m, n) = -INFINITY; }
                        }
                    }
                }
            }
        }
    };

};

} // namespace flash