1
0

epilogue_fwd_sm90_tma.hpp 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include <cutlass/cutlass.h>
  6. #include <cutlass/fast_math.h> // For FastDivMod
  7. #include "cute/tensor.hpp"
  8. #include "cutlass/gemm/collective/builders/sm90_common.inl"
  9. #include "cutlass/epilogue/collective/builders/sm90_common.inl"
  10. #include "seqlen.h"
  11. #include "named_barrier.hpp"
  12. #include "pack_gqa.h"
  13. #include "utils.h"
  14. namespace flash {
  15. using namespace cute;
  16. template <class TileShape_MNK_, class ClusterShape_, class Element_, int NumEpilogueThreads_, bool Varlen_, bool PackGQA_, bool FP8PermuteCol=false>
  17. struct CollectiveEpilogueFwd {
  18. using TileShape_MNK = TileShape_MNK_;
  19. using ClusterShape = ClusterShape_;
  20. using Element = Element_;
  21. static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
  22. static constexpr bool Varlen = Varlen_;
  23. static constexpr bool PackGQA = PackGQA_;
  24. static constexpr bool Use_smem = sizeof(Element) <= 2;
  25. static constexpr bool Use_TMA_O = !Varlen && Use_smem && !PackGQA;
  26. static constexpr int kBlockM = get<0>(TileShape_MNK{});
  27. static constexpr int kHeadDim = get<2>(TileShape_MNK{});
  28. using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;
  29. // These are for storing the output tensor without TMA (e.g., for setting output to zero)
  30. static constexpr int kGmemElemsPerStore = sizeof(cute::uint128_t) / sizeof(Element);
  31. static_assert(kHeadDim % kGmemElemsPerStore == 0, "Headdim must be a multiple of kGmemElemsPerStore");
  32. // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). We want each thread to have 4 elements
  33. // in the M direction and 2 elements in the K direction. In the case of PackGQA, this reduces the number of times
  34. // we need to call divmod.
  35. static constexpr int kBytePerRow = kHeadDim * sizeof(Element);
  36. static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element);
  37. // static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
  38. // static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerStore, NumEpilogueThreads);
  39. static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerStore;
  40. // If PackGQA, we split the work of compute O_ptr among threads in the same row, so we need this to within a warp
  41. static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0);
  42. static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow");
  43. using GmemLayoutAtom = Layout<Shape <Int<NumEpilogueThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
  44. Stride<Int<kGmemThreadsPerRow>, _1>>;
  45. static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0, "kBlockM must be a multiple of NumEpilogueThreads / kGmemThreadsPerRow");
  46. using GmemTiledCopyO = decltype(
  47. make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
  48. GmemLayoutAtom{},
  49. Layout<Shape<_1, Int<kGmemElemsPerStore>>>{})); // Val layout, 8 or 16 vals per store
  50. using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  51. decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  52. using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
  53. using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch, num_splits)
  54. using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t, int64_t>;
  55. using StrideLSE = cute::Stride<_1, int64_t, int64_t, int64_t>; // (seqlen_q, head, batch, num_splits)
  56. // ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits)
  57. using ShapeOPacked = std::conditional_t<!PackGQA, ShapeO, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t, int32_t>>;
  58. using StrideOPacked = std::conditional_t<!PackGQA, StrideO, cute::Stride<cute::Stride<int64_t, int64_t>, _1, int64_t, int64_t, int64_t>>;
  59. // ((qhead_per_khead, seqlen_q), nheads_kv, batch, num_splits)
  60. using ShapeLSEPacked = std::conditional_t<!PackGQA, cute::Shape<int32_t, int32_t, int32_t, int32_t>, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t>>;
  61. using StrideLSEPacked = std::conditional_t<!PackGQA, StrideLSE, cute::Stride<cute::Stride<int64_t, _1>, int64_t, int64_t, int64_t>>;
  62. // cute::SM90_U32x4_STSM_N if Element size is 2 bytes (fp16, bf16)
  63. using CopyOpR2S = decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator<StrideO, Element>());
  64. using SmemCopyAtomO = Copy_Atom<CopyOpR2S, Element>;
  65. // static constexpr size_t SmemAlignmentO = cutlass::detail::alignment_for_swizzle(SmemLayoutO{});
  66. // static_assert(SmemAlignmentO >= 128, "Require at least 128B alignment");
  67. // struct TensorStorage : cute::aligned_struct<SmemAlignmentO> {
  68. // cute::array_aligned<Element, Use_smem ? cute::cosize_v<SmemLayoutO> : 0, SmemAlignmentO> smem_o;
  69. // };
  70. struct TensorStorage : cute::aligned_struct<128> {
  71. cute::array_aligned<Element, Use_smem ? cute::cosize_v<SmemLayoutO> : 0> smem_o;
  72. };
  73. using TMA_O = decltype(make_tma_copy(
  74. GmemTiledCopyOTMA{},
  75. make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapeO{}, StrideO{}),
  76. SmemLayoutO{},
  77. select<0, 2>(TileShape_MNK{}),
  78. _1{})); // no mcast for O
  79. // Host side kernel arguments
  80. struct Arguments {
  81. Element* ptr_O;
  82. ShapeO const shape_O;
  83. StrideO const stride_O;
  84. float* ptr_LSE;
  85. StrideLSE const stride_LSE;
  86. int32_t const nheads_kv;
  87. int const* cu_seqlens = nullptr;
  88. int const* seqused = nullptr;
  89. };
  90. // Device side kernel params
  91. struct Params {
  92. Element* ptr_O;
  93. ShapeO const shape_O;
  94. StrideO const stride_O;
  95. ShapeOPacked const shape_O_packed;
  96. StrideOPacked const stride_O_packed;
  97. float* ptr_LSE;
  98. StrideLSE const stride_LSE;
  99. ShapeLSEPacked const shape_LSE_packed;
  100. StrideLSEPacked const stride_LSE_packed;
  101. cutlass::FastDivmod qhead_per_khead_divmod;
  102. TMA_O tma_store_O;
  103. int const* cu_seqlens = nullptr;
  104. int const* seqused = nullptr;
  105. };
  106. static Params
  107. to_underlying_arguments(Arguments const& args) {
  108. Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O);
  109. TMA_O tma_store_O = make_tma_copy(
  110. GmemTiledCopyOTMA{},
  111. mO,
  112. SmemLayoutO{},
  113. select<0, 2>(TileShape_MNK{}),
  114. _1{}); // no mcast for O
  115. // If PackGQA, reshape O to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size, num_splits)
  116. int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_O), args.nheads_kv);
  117. auto const shape_O_packed = cute::conditional_return<!PackGQA>(
  118. args.shape_O,
  119. make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), get<1>(args.shape_O), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O))
  120. );
  121. auto const stride_O_packed = cute::conditional_return<!PackGQA>(
  122. args.stride_O,
  123. make_stride(make_stride(get<2>(args.stride_O), get<0>(args.stride_O)), get<1>(args.stride_O), get<2>(args.stride_O) * qhead_per_khead, get<3>(args.stride_O), get<4>(args.stride_O))
  124. );
  125. // If PackGQA, Reshape LSE to be ((qhead_per_khead, seqlen_q), nhead_k, batch_size, num_splits)
  126. auto const shape_LSE_packed = cute::conditional_return<!PackGQA>(
  127. select<0, 2, 3, 4>(args.shape_O),
  128. make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O))
  129. );
  130. auto const stride_LSE_packed = cute::conditional_return<!PackGQA>(
  131. args.stride_LSE,
  132. make_stride(make_stride(get<1>(args.stride_LSE), get<0>(args.stride_LSE)), get<1>(args.stride_LSE) * qhead_per_khead, get<2>(args.stride_LSE), get<3>(args.stride_LSE))
  133. );
  134. return {args.ptr_O, args.shape_O, args.stride_O, shape_O_packed, stride_O_packed,
  135. args.ptr_LSE, args.stride_LSE, shape_LSE_packed, stride_LSE_packed,
  136. cutlass::FastDivmod(qhead_per_khead),
  137. tma_store_O, args.cu_seqlens, args.seqused};
  138. }
  139. /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
  140. CUTLASS_DEVICE
  141. static void prefetch_tma_descriptors(Params const& params) {
  142. if constexpr (Use_TMA_O) {
  143. cute::prefetch_tma_descriptor(params.tma_store_O.get_tma_descriptor());
  144. }
  145. }
  146. template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE, typename TiledMma>
  147. CUTLASS_DEVICE void
  148. store(Params const& params,
  149. FrgTensorO const& tOrO,
  150. FrgTensorLSE const& lse,
  151. SharedStorage& shared_storage,
  152. TiledMma tiled_mma,
  153. int thread_idx,
  154. cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord
  155. ) {
  156. auto [m_block, bidh, bidb, split_idx] = block_coord;
  157. Tensor sO = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_o.data()), SmemLayoutO{});
  158. // Tensor sO_pi = cute::as_position_independent_swizzle_tensor(sO);
  159. Tensor tOrO_out = make_tensor_like<Element>(tOrO);
  160. flash::convert_type_out(tOrO, tOrO_out);
  161. if constexpr (FP8PermuteCol && (sizeof(Element) == 2 || sizeof(Element) == 4)) { flash::permute_output_fp8_Vcolmajor(tOrO_out); }
  162. // Make sure all WGs have finished reading V
  163. // Technically we don't need this if we're not using smem, but the mainloop makes the assumption that
  164. // all epilogue threads sync at least once during the epilogue (so that we can start loading Q with
  165. // cp.async if we need).
  166. cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  167. // Step 1: Write O from rmem -> smem
  168. if constexpr (Use_smem) {
  169. auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
  170. auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
  171. Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N)
  172. Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
  173. // Tensor taccOsO = smem_thr_copy_O.partition_D(sO_pi); // ((Atom,AtomNum),PIPE_M,PIPE_N)
  174. cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
  175. if constexpr (!Varlen && !PackGQA) {
  176. cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
  177. cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
  178. cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  179. } else {
  180. cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  181. }
  182. } else {
  183. #pragma unroll
  184. for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
  185. shared_storage.pipelines.barrier_O.arrive(cta_id);
  186. }
  187. }
  188. flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused};
  189. bool is_varlen = Varlen && params.cu_seqlens;
  190. int offset_o = seqlen_info.offset;
  191. int seqlen_o = seqlen_info.seqlen;
  192. // Step 2: Write LSE from rmem -> gmem
  193. auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
  194. // (MMA,MMA_M,MMA_K)
  195. Tensor taccOcO = thread_mma.partition_C(cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})));
  196. static_assert(decltype(size<0, 0>(taccOcO))::value == 2);
  197. static_assert(decltype(size<0, 1>(taccOcO))::value == 2);
  198. // taccOcO has shape ((2, 2, V), MMA_M, MMA_K), we only take only the row indices.
  199. Tensor taccOcO_row = taccOcO(make_coord(_0{}, _, _0{}), _, _0{});
  200. CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
  201. using PackGQAt = flash::PackGQAManager<get<0>(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumEpilogueThreads, Element>;
  202. Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset_o * get<0>(params.stride_LSE)), params.shape_LSE_packed, params.stride_LSE_packed)(_, bidh, !is_varlen ? bidb : 0, split_idx);
  203. // if (thread_idx == 0) { printf("Before LSE write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o); print(mLSE); printf("\n"); }
  204. if constexpr (!PackGQA) {
  205. #pragma unroll
  206. for (int mi = 0; mi < size(lse); ++mi) {
  207. int const row = m_block * kBlockM + get<0>(taccOcO_row(mi));
  208. if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o) { mLSE(row) = lse(mi); }
  209. }
  210. } else {
  211. PackGQAt::store_LSE(mLSE, lse, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
  212. }
  213. // Step 3: Write O from smem -> gmem
  214. if constexpr (Use_TMA_O) {
  215. Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O)(_, _, bidh, bidb, split_idx);
  216. Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K)
  217. auto block_tma_O = params.tma_store_O.get_slice(_0{});
  218. Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K)
  219. Tensor tOsO = block_tma_O.partition_S(sO); // (TMA, TMA_M, TMA_K)
  220. int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0);
  221. if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) {
  222. cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
  223. cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  224. if (cute::elect_one_sync()) {
  225. cute::copy(params.tma_store_O, tOsO, tOgO);
  226. tma_store_arrive();
  227. tma_store_wait<0>();
  228. #pragma unroll
  229. for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
  230. shared_storage.pipelines.barrier_O.arrive(cta_id);
  231. }
  232. }
  233. }
  234. } else { // Don't use TMA since we don't want to overwrite the output of another sequence
  235. Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx);
  236. Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K)
  237. // if (thread_idx == 0) { printf("Before O write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d, mO_addr = %p, addr diff = %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o, mO.data(), reinterpret_cast<int>(&mO(0)) - reinterpret_cast<int>(params.ptr_O)); }
  238. if constexpr (Use_smem) {
  239. GmemTiledCopyO gmem_tiled_copy_O;
  240. auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
  241. Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N)
  242. // Tensor tOsO = gmem_thr_copy_O.partition_S(sO_pi); // ((Atom,AtomNum),ATOM_M,ATOM_N)
  243. Tensor tOrO = make_fragment_like(tOsO);
  244. cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
  245. cutlass::arch::fence_view_async_shared(); // ensure smem reads are done before next TMA to smem_v
  246. #pragma unroll
  247. for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
  248. shared_storage.pipelines.barrier_O.arrive(cta_id);
  249. }
  250. if constexpr (!PackGQA) {
  251. // (BLK_M,BLK_K) -> (blk_m,blk_k)
  252. Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})));
  253. Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOsO)));
  254. #pragma unroll
  255. for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
  256. Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
  257. // Clear_OOB_K must be false since we don't want to write zeros to gmem
  258. flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
  259. gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM
  260. );
  261. } else {
  262. // If PackGQA, we split the work of compute O_ptr among threads in the same row
  263. PackGQAt::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
  264. }
  265. } else {
  266. // We already arrived on barrier_O earlier
  267. if constexpr (!PackGQA) {
  268. static constexpr int kGmemElemsPerStoreDirect = 2;
  269. cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element> gmem_copy_direct;
  270. // Reshape acc from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
  271. Tensor tOrO_rowcol = make_tensor(tOrO_out.data(), flash::convert_layout_acc_rowcol(tOrO.layout()));
  272. Tensor tOrO_copy = cute::tiled_divide(tOrO_rowcol, Shape<_1, Int<kGmemElemsPerStoreDirect>>{});
  273. Tensor tOgO = thread_mma.partition_C(gO);
  274. Tensor tOgO_rowcol = make_tensor(tOgO.data(), flash::convert_layout_acc_rowcol(tOgO.layout()));
  275. Tensor tOgO_copy = cute::tiled_divide(tOgO_rowcol, Shape<_1, Int<kGmemElemsPerStoreDirect>>{});
  276. // taccOcO has shape ((2, 2, V), MMA_M, MMA_K), we only take only the col indices.
  277. Tensor taccOcO_col = taccOcO(make_coord(_, _0{}, _), _0{}, _);
  278. #pragma unroll
  279. for (int m = 0; m < size(taccOcO_row); ++m) {
  280. if (get<0>(taccOcO_row(m)) < seqlen_o - m_block * kBlockM) {
  281. #pragma unroll
  282. for (int k = 0; k < size(taccOcO_col) / kGmemElemsPerStoreDirect; ++k) {
  283. if (get<1>(taccOcO_col(k * kGmemElemsPerStoreDirect)) < get<1>(params.shape_O)) {
  284. cute::copy(gmem_copy_direct, tOrO_copy(_, m, k), tOgO_copy(_, m, k));
  285. }
  286. }
  287. }
  288. }
  289. } else {
  290. PackGQAt::store_O_direct(mO, tOrO_out, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
  291. }
  292. }
  293. }
  294. }
  295. CUTLASS_DEVICE void
  296. store_tail() {
  297. // Don't need to do tma_store_wait<0>() here since we already did in @store
  298. }
  299. // Write 0 to output and -inf to LSE
  300. template <bool Clear_O=true>
  301. CUTLASS_DEVICE void
  302. store_zero(
  303. Params const& params,
  304. int thread_idx,
  305. cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord
  306. ) {
  307. static constexpr int kBlockM = get<0>(TileShape_MNK{});
  308. auto [m_block, bidh, bidb, split_idx] = block_coord;
  309. flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused};
  310. bool const is_varlen = Varlen && params.cu_seqlens;
  311. int offset_o = seqlen_info.offset;
  312. int seqlen_o = seqlen_info.seqlen;
  313. int qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor;
  314. Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx);
  315. Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset_o * get<0>(params.stride_LSE)), params.shape_LSE_packed, params.stride_LSE_packed)(_, bidh, !is_varlen ? bidb : 0, split_idx);
  316. Tensor gLSE = local_tile(mLSE, Shape<Int<kBlockM>>{}, make_coord(m_block));
  317. static_assert(kBlockM <= NumEpilogueThreads);
  318. if (thread_idx < kBlockM) {
  319. const int row = m_block * kBlockM + thread_idx;
  320. if constexpr (!PackGQA) {
  321. if (row < seqlen_o) { mLSE(row) = -INFINITY; }
  322. } else {
  323. if (row < seqlen_o * qhead_per_khead) {
  324. int m_idx, h_idx;
  325. m_idx = params.qhead_per_khead_divmod.divmod(h_idx, row);
  326. // mLSE shape shape ((qhead_per_khead, seqlen_q)) and it's unhappy with just 1 "make_coord"
  327. mLSE(make_coord(make_coord(h_idx, m_idx))) = -INFINITY;
  328. }
  329. }
  330. }
  331. if constexpr (!Clear_O) { return; }
  332. GmemTiledCopyO gmem_tiled_copy_O;
  333. auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
  334. Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})));
  335. if constexpr (!PackGQA) {
  336. Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOcO)));
  337. #pragma unroll
  338. for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
  339. Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K)
  340. Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
  341. Tensor tOrO = make_fragment_like(tOgO);
  342. cute::clear(tOrO);
  343. // Clear_OOB_K must be false since we don't want to write zeros to gmem
  344. flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
  345. gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM
  346. );
  347. } else {
  348. // If PackGQA, we split the work of compute O_ptr among threads in the same row
  349. using PackGQAt = flash::PackGQAManager<get<0>(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumEpilogueThreads, Element>;
  350. Tensor tOrO = make_tensor<Element>(make_shape(Shape<_1, Int<kGmemElemsPerStore>>{}, size<1>(tOcO), size<2>(tOcO)));
  351. cute::clear(tOrO);
  352. PackGQAt::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
  353. }
  354. }
  355. };
  356. } // namespace flash