mask.h

/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cute/tensor.hpp>

namespace flash {

using namespace cute;

template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
                                           const int col_idx_offset_ = 0) {
    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    const int lane_id = threadIdx.x % 32;
    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
    #pragma unroll
    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
        const int col_idx_base = col_idx_offset + nj * 8;
        #pragma unroll
        for (int j = 0; j < size<1, 0>(tensor); ++j) {
            const int col_idx = col_idx_base + j;
            if (col_idx >= max_seqlen_k) {
                // Without the "make_coord" we get wrong results
                #pragma unroll
                for (int mi = 0; mi < size<0>(tensor); ++mi) {
                    tensor(mi, make_coord(j, nj)) = -INFINITY;
                }
            }
        }
    }
}
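
// Usage sketch (illustrative assumption, not part of the original header): after the
// S = Q K^T GEMM, the accumulator is reshaped from (MMA=4, MMA_M, MMA_N) to
// (nrow=(2, MMA_M), ncol=(2, MMA_N)) and masked so that key columns past the actual
// sequence length become -INFINITY and therefore receive zero weight after softmax:
//
//     Tensor scores = make_tensor(acc_s.data(),
//                                 flash::convert_layout_acc_rowcol(acc_s.layout()));
//     flash::apply_mask(scores, actual_seqlen_k - n_block * kBlockN);
//
// acc_s, actual_seqlen_k, n_block, and kBlockN are hypothetical caller-side names for
// the accumulator, the true key length, the key-block index, and the key-block size.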

template <bool HasWSLeft=true, typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
                                                 const int max_seqlen_k, const int row_idx_offset,
                                                 const int max_seqlen_q, const int warp_row_stride,
                                                 const int window_size_left, const int window_size_right) {
    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    const int lane_id = threadIdx.x % 32;
    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
    #pragma unroll
    for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
        const int row_idx_base = row_idx_offset + mi * warp_row_stride;
        #pragma unroll
        for (int i = 0; i < size<0, 0>(tensor); ++i) {
            const int row_idx = row_idx_base + i * 8;
            const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
            const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
            #pragma unroll
            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                const int col_idx_base = col_idx_offset + nj * 8;
                #pragma unroll
                for (int j = 0; j < size<1, 0>(tensor); ++j) {
                    const int col_idx = col_idx_base + j;
                    if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                    }
                }
            }
            // if (cute::thread0()) {
            //     printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k);
            //     print(tensor(make_coord(i, mi), _));
            //     // print(tensor(_, j + nj * size<1, 0>(tensor)));
            // }
        }
    }
}

template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
                                                  const int max_seqlen_k, const int row_idx_offset,
                                                  const int max_seqlen_q, const int warp_row_stride) {
    // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
    apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
                                          max_seqlen_q, warp_row_stride, -1, 0);
}
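
// Reference sketch (illustrative assumption, not part of the original header): the same
// keep/drop decision as apply_mask_local, written for a single (row_idx, col_idx) pair.
// The max_seqlen_k - max_seqlen_q shift right-aligns the query block with the key block;
// causal masking is the window_size_right = 0 case with no left limit, which is exactly
// what apply_mask_causal instantiates via HasWSLeft=false.
//
//     inline bool kept_by_local_window(int row_idx, int col_idx,
//                                      int max_seqlen_q, int max_seqlen_k,
//                                      int window_size_left, int window_size_right) {
//         const int shift = max_seqlen_k - max_seqlen_q;
//         return col_idx < std::min(max_seqlen_k, row_idx + 1 + shift + window_size_right)
//             && col_idx >= std::max(0, row_idx + shift - window_size_left);
//     }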

template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void apply_mask_causal_w_idx(
    Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
    const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
{
    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 2, "Only support 2D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));
    CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));
        #pragma unroll
        for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
            if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
                tensor(mi, ni) = -INFINITY;
            }
        }
        // if (cute::thread0()) {
        //     printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k);
        //     print(tensor(_, make_coord(j, ni)));
        //     // print(tensor(_, j + ni * size<1, 0>(tensor)));
        // }
    }
}
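
// Usage sketch (an assumption about typical callers, not part of the original header):
// idx_rowcol carries each element's (row, col) coordinates, e.g. a cute identity tensor
// partitioned the same way as the accumulator and reshaped with convert_layout_acc_rowcol:
//
//     Tensor cS   = cute::make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});
//     Tensor tScS = thr_mma.partition_C(cS);
//     Tensor tScS_rowcol = make_tensor(tScS.data(),
//                                      flash::convert_layout_acc_rowcol(tScS.layout()));
//     flash::apply_mask_causal_w_idx(scores, tScS_rowcol, n_block * kBlockN,
//                                    actual_seqlen_k, m_block * kBlockM);
//
// kBlockM, kBlockN, thr_mma, scores, n_block, m_block, and actual_seqlen_k are
// hypothetical kernel-side names.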

template <bool Is_causal, bool Is_local, bool Has_alibi>
struct Mask {

    const int max_seqlen_k, max_seqlen_q;
    const int window_size_left, window_size_right;
    const float alibi_slope;

    __forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q,
                                    const int window_size_left, const int window_size_right,
                                    const float alibi_slope=0.f)
        : max_seqlen_k(max_seqlen_k)
        , max_seqlen_q(max_seqlen_q)
        , window_size_left(window_size_left)
        , window_size_right(window_size_right)
        , alibi_slope(!Has_alibi ? 0.0 : alibi_slope) {
    };

    // Causal_mask: whether this particular iteration needs causal masking
    template <bool Causal_mask=false, bool Is_even_MN=true, typename Engine, typename Layout>
    __forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor_,
                                               const int col_idx_offset_,
                                               const int row_idx_offset,
                                               const int warp_row_stride) {
        static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local");
        static_assert(Layout::rank == 3, "Only support 3D Tensor");
        static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4");
        static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN;
        // if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); }
        if constexpr (Need_masking) {
            // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
            Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout()));
            // Do we need both row and column indices, or just column indices?
            static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask;
            const int lane_id = threadIdx.x % 32;
            const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
            if constexpr (Col_idx_only) {
                #pragma unroll
                for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                    const int col_idx_base = col_idx_offset + nj * 8;
                    #pragma unroll
                    for (int j = 0; j < size<1, 0>(tensor); ++j) {
                        const int col_idx = col_idx_base + j;
                        #pragma unroll
                        for (int mi = 0; mi < size<0>(tensor); ++mi) {
                            // No causal, no local
                            if constexpr (Has_alibi) {
                                tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
                            }
                            if constexpr (!Is_even_MN) {
                                if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; }
                            }
                        }
                    }
                }
            } else {
                #pragma unroll
                for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
                    const int row_idx_base = row_idx_offset + mi * warp_row_stride;
                    #pragma unroll
                    for (int i = 0; i < size<0, 0>(tensor); ++i) {
                        const int row_idx = row_idx_base + i * 8;
                        const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
                        const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
                        #pragma unroll
                        for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                            const int col_idx_base = col_idx_offset + nj * 8;
                            #pragma unroll
                            for (int j = 0; j < size<1, 0>(tensor); ++j) {
                                const int col_idx = col_idx_base + j;
                                if constexpr (Has_alibi) {
                                    if constexpr (Is_causal) {
                                        tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx;
                                    } else {
                                        tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
                                    }
                                }
                                if constexpr (Causal_mask) {
                                    if (col_idx >= col_idx_limit_right) {
                                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                                    }
                                }
                                if constexpr (Is_local) {
                                    if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) {
                                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                                    }
                                }
                                if constexpr (!Causal_mask && !Is_local && !Is_even_MN) {
                                    // Causal and Local already handle MN masking
                                    if (col_idx >= max_seqlen_k) {
                                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    };
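
    // Note on the ALiBi bias above (explanatory, a hedged reading of the code rather than
    // part of the original header): in the non-causal case the bias applied is
    //     -alibi_slope * |row_idx + (max_seqlen_k - max_seqlen_q) - col_idx|,
    // i.e. a penalty proportional to the query/key distance after right-aligning the two
    // sequences. In the causal case the code adds alibi_slope * col_idx instead; since
    // kept columns satisfy col_idx <= row_idx + (max_seqlen_k - max_seqlen_q), this differs
    // from the full bias only by a per-row constant, which softmax normalization cancels.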
};
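
// Usage sketch (illustrative assumption, not part of the original header): a forward
// kernel would typically build one Mask per CTA and apply it to the S = Q K^T
// accumulator of each key block, passing the block's starting key column, this
// thread's starting query row, and the stride in rows between consecutive MMA_M
// tiles of one warp (16 rows per tile times the number of warps along M):
//
//     flash::Mask</*Is_causal=*/true, /*Is_local=*/false, /*Has_alibi=*/false>
//         mask(actual_seqlen_k, actual_seqlen_q, window_size_left, window_size_right);
//     mask.template apply_mask</*Causal_mask=*/true, /*Is_even_MN=*/false>(
//         acc_s,
//         n_block * kBlockN,                                        // col_idx_offset_
//         m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,   // row_idx_offset
//         kNWarps * 16);                                            // warp_row_stride
//
// acc_s, actual_seqlen_*, window_size_*, n_block, m_block, kBlockM, kBlockN, kNWarps,
// and tidx are hypothetical kernel-side names for the accumulator and tiling parameters.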

}  // namespace flash