epilogue_fwd.hpp

/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cutlass/cutlass.h>
#include <cutlass/fast_math.h>  // For FastDivMod
#include "cute/tensor.hpp"

#include "cutlass/gemm/collective/builders/sm90_common.inl"
#include "cutlass/epilogue/collective/builders/sm90_common.inl"

#include "seqlen.h"
#include "named_barrier.hpp"
#include "pack_gqa.h"
#include "utils.h"

namespace flash {

using namespace cute;

template <class TileShape_MNK_, class ClusterShape_, class Element_, class ArchTag_,
          int NumEpilogueThreads_, bool Varlen_, bool PackGQA_, bool FP8PermuteCol=false>
struct CollectiveEpilogueFwd {

    using TileShape_MNK = TileShape_MNK_;
    using ClusterShape = ClusterShape_;
    using Element = Element_;
    using ArchTag = ArchTag_;
    static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
    static constexpr bool Varlen = Varlen_;
    static constexpr bool PackGQA = PackGQA_;
    static constexpr bool Use_smem = sizeof(Element) <= 2;
    static constexpr bool Use_TMA_O = ArchTag::kMinComputeCapability >= 90 && !Varlen && Use_smem && !PackGQA;

    static_assert(ArchTag::kMinComputeCapability >= 80);
    static_assert(ArchTag::kMinComputeCapability >= 90 || CUTE_STATIC_V(size(ClusterShape{})) == 1);

    static constexpr int kBlockM = get<0>(TileShape_MNK{});
    static constexpr int kHeadDim = get<2>(TileShape_MNK{});

    using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;

    // These are for storing the output tensor without TMA (e.g., for setting output to zero)
    static constexpr int kGmemElemsPerStore = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerStore == 0, "Headdim must be a multiple of kGmemElemsPerStore");
    // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). We want each thread to have 4 elements
    // in the M direction and 2 elements in the K direction. In the case of PackGQA, this reduces the number of times
    // we need to call divmod.
    static constexpr int kBytePerRow = kHeadDim * sizeof(Element);
    static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element);
    // static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
    // static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerStore, NumEpilogueThreads);
    static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerStore;
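    // Illustrative arithmetic (assuming fp16/bf16, i.e. sizeof(Element) == 2, and kHeadDim == 128):
    //   kGmemElemsPerStore = 16 / 2 = 8 elements per 128-bit store,
    //   kBytePerRow        = 128 * 2 = 256 bytes, so kBlockKGmem = 128 / 2 = 64 elements,
    //   kGmemThreadsPerRow = 64 / 8  = 8 threads cooperating on one 128-byte row.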
    // If PackGQA, we split the work of computing O_ptr among threads in the same row, so we need this to be within a warp
    static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0);
    static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout<Shape <Int<NumEpilogueThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
    static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0, "kBlockM must be a multiple of NumEpilogueThreads / kGmemThreadsPerRow");
    using GmemTiledCopyO = decltype(
        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, Int<kGmemElemsPerStore>>>{}));  // Val layout, 8 or 16 vals per store
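    // Sketch of the resulting thread arrangement (illustrative; assumes NumEpilogueThreads == 128 together with
    // the fp16 numbers above): GmemLayoutAtom is a 16 x 8 row-major grid of threads, each thread storing a
    // 1 x 8 element vector, so one pass of GmemTiledCopyO covers a 16 x 64 element tile of O.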
    using SmemLayoutAtomOTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutOTMA = decltype(tile_to_shape(SmemLayoutAtomOTMA{}, select<0, 2>(TileShape_MNK{})));
    static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1));
    static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 3 : 4);
    using SmemLayoutAtomO = decltype(
        composition(Swizzle<kSwizzle, kSwizzleBase, kSwizzleBase>{},
                    Layout<Shape<_8, Int<kBlockKGmem>>,
                           Stride<Int<kBlockKGmem>, _1>>{}));
    using SmemLayoutOSTS = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
    using SmemLayoutO = std::conditional_t<ArchTag::kMinComputeCapability >= 90, SmemLayoutOTMA, SmemLayoutOSTS>;
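    // Illustrative instantiation of the non-TMA smem layout (fp16 numbers from above): kBlockKGmem == 64 gives
    // kSwizzle == 3 and kSwizzleBase == 3, i.e. Swizzle<3, 3, 3> composed with an (8, 64) row-major tile --
    // roughly, each row's eight 8-element (16-byte) chunks are XOR-permuted by the row index to reduce
    // shared-memory bank conflicts.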
    using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t, int32_t>;  // (seqlen_q, d, head, batch, num_splits)
    using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t, int64_t>;
    using StrideLSE = cute::Stride<_1, int64_t, int64_t, int64_t>;  // (seqlen_q, head, batch, num_splits)
    // ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits)
    using ShapeOPacked = std::conditional_t<!PackGQA, ShapeO, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t, int32_t>>;
    using StrideOPacked = std::conditional_t<!PackGQA, StrideO, cute::Stride<cute::Stride<int64_t, int64_t>, _1, int64_t, int64_t, int64_t>>;
    // ((qhead_per_khead, seqlen_q), nheads_kv, batch, num_splits)
    using ShapeLSEPacked = std::conditional_t<!PackGQA, cute::Shape<int32_t, int32_t, int32_t, int32_t>, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t>>;
    using StrideLSEPacked = std::conditional_t<!PackGQA, StrideLSE, cute::Stride<cute::Stride<int64_t, _1>, int64_t, int64_t, int64_t>>;
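    // Illustrative PackGQA example (the head counts here are hypothetical): with 32 query heads and
    // nheads_kv == 4, qhead_per_khead == 8, so the logical O of shape (seqlen_q, d, 32, batch, num_splits)
    // is viewed as ((8, seqlen_q), d, 4, batch, num_splits): the 8 query heads that share one KV head are
    // folded into the row dimension, letting one CTA's M-tile cover rows from several query heads at once.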
    using CopyOpR2S = std::conditional_t<
        ArchTag::kMinComputeCapability >= 90,
        // cute::SM90_U32x4_STSM_N if Element size is 2 bytes (fp16, bf16)
        decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator<StrideO, Element>()),
        AutoVectorizingCopyWithAssumedAlignment<128>
    >;
    using SmemCopyAtomO = Copy_Atom<CopyOpR2S, Element>;

    // static constexpr size_t SmemAlignmentO = cutlass::detail::alignment_for_swizzle(SmemLayoutO{});
    // static_assert(SmemAlignmentO >= 128, "Require at least 128B alignment");
    // struct TensorStorage : cute::aligned_struct<SmemAlignmentO> {
    //     cute::array_aligned<Element, Use_smem ? cute::cosize_v<SmemLayoutO> : 0, SmemAlignmentO> smem_o;
    // };

    struct TensorStorage : cute::aligned_struct<128> {
        cute::array_aligned<Element, Use_smem ? cute::cosize_v<SmemLayoutO> : 0> smem_o;
    };

    using TMA_O = std::conditional_t<
        Use_TMA_O,
        decltype(make_tma_copy(
            GmemTiledCopyOTMA{},
            make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapeO{}, StrideO{}),
            SmemLayoutOTMA{},
            select<0, 2>(TileShape_MNK{}),
            _1{})),  // no mcast for O
        std::nullptr_t
    >;

    // Host side kernel arguments
    struct Arguments {
        Element* ptr_O;
        ShapeO const shape_O;
        StrideO const stride_O;
        float* ptr_LSE;
        StrideLSE const stride_LSE;
        int32_t const nheads_kv;
        int const* cu_seqlens = nullptr;
        int const* seqused = nullptr;
    };
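    // Hedged host-side sketch (hypothetical values, not part of this header): for a contiguous output laid out
    // as [batch, seqlen_q, num_heads, headdim] with num_splits == 1, the arguments could look roughly like
    //     Arguments args {
    //         ptr_O,
    //         /*shape_O=*/   {seqlen_q, headdim, num_heads, batch, /*num_splits=*/1},
    //         /*stride_O=*/  {int64_t(num_heads) * headdim, _1{}, headdim, int64_t(seqlen_q) * num_heads * headdim, 0},
    //         ptr_LSE,
    //         /*stride_LSE=*/{_1{}, seqlen_q, int64_t(seqlen_q) * num_heads, 0},
    //         nheads_kv,
    //         /*cu_seqlens=*/nullptr, /*seqused=*/nullptr};
    // with strides given in elements, matching the (seqlen_q, d, head, batch, num_splits) ordering of ShapeO.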
    // Device side kernel params
    struct Params {
        Element* ptr_O;
        ShapeO const shape_O;
        StrideO const stride_O;
        ShapeOPacked const shape_O_packed;
        StrideOPacked const stride_O_packed;
        float* ptr_LSE;
        StrideLSE const stride_LSE;
        ShapeLSEPacked const shape_LSE_packed;
        StrideLSEPacked const stride_LSE_packed;
        cutlass::FastDivmod qhead_per_khead_divmod;
        TMA_O tma_store_O;
        int const* cu_seqlens = nullptr;
        int const* seqused = nullptr;
    };

    static Params
    to_underlying_arguments(Arguments const& args) {
        Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O);
        TMA_O tma_store_O = [&]{
            if constexpr (Use_TMA_O) {
                return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 2>(TileShape_MNK{}), _1{});  // no mcast
            } else {
                return nullptr;
            }
        }();
        // If PackGQA, reshape O to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size, num_splits)
        int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_O), args.nheads_kv);
        auto const shape_O_packed = cute::conditional_return<!PackGQA>(
            args.shape_O,
            make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), get<1>(args.shape_O), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O))
        );
        auto const stride_O_packed = cute::conditional_return<!PackGQA>(
            args.stride_O,
            make_stride(make_stride(get<2>(args.stride_O), get<0>(args.stride_O)), get<1>(args.stride_O), get<2>(args.stride_O) * qhead_per_khead, get<3>(args.stride_O), get<4>(args.stride_O))
        );
        // If PackGQA, reshape LSE to be ((qhead_per_khead, seqlen_q), nhead_k, batch_size, num_splits)
        auto const shape_LSE_packed = cute::conditional_return<!PackGQA>(
            select<0, 2, 3, 4>(args.shape_O),
            make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O))
        );
        auto const stride_LSE_packed = cute::conditional_return<!PackGQA>(
            args.stride_LSE,
            make_stride(make_stride(get<1>(args.stride_LSE), get<0>(args.stride_LSE)), get<1>(args.stride_LSE) * qhead_per_khead, get<2>(args.stride_LSE), get<3>(args.stride_LSE))
        );
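        // How the packed view works (illustrative): within the packed row coordinate (h, s), the h index uses the
        // original head stride and the s index uses the original seqlen stride, so walking packed rows visits the
        // qhead_per_khead query heads of one KV head first, then advances to the next sequence position; the packed
        // head dimension then steps by head_stride * qhead_per_khead, i.e. from one KV-head group to the next.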
        return {args.ptr_O, args.shape_O, args.stride_O, shape_O_packed, stride_O_packed,
                args.ptr_LSE, args.stride_LSE, shape_LSE_packed, stride_LSE_packed,
                cutlass::FastDivmod(qhead_per_khead),
                tma_store_O, args.cu_seqlens, args.seqused};
    }

    /// Issue TMA descriptor prefetch -- ideally from a single thread for best performance
    CUTLASS_DEVICE
    static void prefetch_tma_descriptors(Params const& params) {
        if constexpr (Use_TMA_O) {
            cute::prefetch_tma_descriptor(params.tma_store_O.get_tma_descriptor());
        }
    }

    template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE, typename TiledMma>
    CUTLASS_DEVICE void
    store(Params const& params,
          FrgTensorO const& tOrO,
          FrgTensorLSE const& lse,
          SharedStorage& shared_storage,
          TiledMma tiled_mma,
          int thread_idx,
          cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord
          ) {

        auto [m_block, bidh, bidb, split_idx] = block_coord;
        Tensor sO = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_o.data()), SmemLayoutO{});
        // Tensor sO_pi = cute::as_position_independent_swizzle_tensor(sO);

        Tensor tOrO_out = make_tensor_like<Element>(tOrO);
        flash::convert_type_out(tOrO, tOrO_out);
        if constexpr (FP8PermuteCol && (sizeof(Element) == 2 || sizeof(Element) == 4)) { flash::permute_output_fp8_Vcolmajor(tOrO_out); }

        // Make sure all WGs have finished reading V
        // Technically we don't need this if we're not using smem, but the mainloop makes the assumption that
        // all epilogue threads sync at least once during the epilogue (so that we can start loading Q with
        // cp.async if we need to).
        flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);

        // Step 1: Write O from rmem -> smem
        if constexpr (Use_smem) {
            auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
            auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
            Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out);  // ((Atom,AtomNum), MMA_M, MMA_N)
            Tensor taccOsO = smem_thr_copy_O.partition_D(sO);     // ((Atom,AtomNum),PIPE_M,PIPE_N)
            // Tensor taccOsO = smem_thr_copy_O.partition_D(sO_pi);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
            cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
            if constexpr (Use_TMA_O) {
                cutlass::arch::fence_view_async_shared();  // ensure smem writes are visible to TMA
                cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
                                                    cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
            } else {
                flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
            }
        } else {
            if constexpr (ArchTag::kMinComputeCapability >= 90) {
                #pragma unroll
                for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
                    shared_storage.pipelines.barrier_O.arrive(cta_id);
                }
            }
        }
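        // Note on barrier_O (based on the fence comment further below): the mainloop waits on it before the next
        // TMA load into smem_v. When O is not staged through smem there is nothing left for the epilogue to drain,
        // so every CTA in the cluster is signalled right away; in the smem paths the arrive happens only after O
        // has been read back out of shared memory.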
        flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused};
        bool is_varlen = Varlen && params.cu_seqlens;
        int offset_o = seqlen_info.offset;
        int seqlen_o = seqlen_info.seqlen;

        // Step 2: Write LSE from rmem -> gmem
        auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
        // (MMA,MMA_M,MMA_K)
        Tensor taccOcO = thread_mma.partition_C(cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})));
        static_assert(decltype(size<0, 0>(taccOcO))::value == 2);
        static_assert(decltype(size<0, 1>(taccOcO))::value == 2);
        Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout()));
        Tensor taccOcO_row = taccOcO_rowcol(_, _0{});
        CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));  // MMA_M
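        // Layout note (illustrative): taccOcO holds the (row, col) tile coordinates of the accumulator elements
        // owned by this thread, in the MMA's native ((2, 2, V), MMA_M, MMA_N) arrangement; convert_layout_acc_rowcol
        // views it as (nrow=(2, MMA_M), ncol=(2, V, MMA_N)), so taccOcO_row(mi) is the tile-local coordinate of the
        // mi-th accumulator row this thread touches, matching lse(mi).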
        using PackGQAt = flash::PackGQAManager<get<0>(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumEpilogueThreads, Element>;

        Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset_o * get<0>(params.stride_LSE)), params.shape_LSE_packed, params.stride_LSE_packed)(_, bidh, !is_varlen ? bidb : 0, split_idx);
        // if (thread_idx == 0) { printf("Before LSE write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o); print(mLSE); printf("\n"); }
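        // In the non-PackGQA branch below, only the threads whose first accumulator column coordinate is 0 write
        // LSE, so each row is written at most once and only when it lies within the actual sequence length.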
        if constexpr (!PackGQA) {
            #pragma unroll
            for (int mi = 0; mi < size(lse); ++mi) {
                int const row = m_block * kBlockM + get<0>(taccOcO_row(mi));
                if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o) { mLSE(row) = lse(mi); }
            }
        } else {
            PackGQAt::store_LSE(mLSE, lse, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
        }

        // Step 3: Write O from smem -> gmem
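        // TMA path (sketch of the choreography below): the last warp of the epilogue waits on the EpilogueBarrier
        // until the rmem -> smem copy above is visible, then a single elected thread issues the bulk TMA store,
        // waits for it to complete, and finally signals barrier_O to all CTAs in the cluster.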
        if constexpr (Use_TMA_O) {
            Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O)(_, _, bidh, bidb, split_idx);
            Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{}));  // (M, K)
            auto block_tma_O = params.tma_store_O.get_slice(_0{});
            Tensor tOgO = block_tma_O.partition_D(gO);  // (TMA, TMA_M, TMA_K)
            Tensor tOsO = block_tma_O.partition_S(sO);  // (TMA, TMA_M, TMA_K)
            int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0);
            if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) {
                cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
                                                  cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
                if (cute::elect_one_sync()) {
                    cute::copy(params.tma_store_O, tOsO, tOgO);
                    tma_store_arrive();
                    tma_store_wait<0>();
                    #pragma unroll
                    for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
                        shared_storage.pipelines.barrier_O.arrive(cta_id);
                    }
                }
            }
        } else {  // Don't use TMA in Varlen case since we don't want to overwrite the output of another sequence
            Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx);
            Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{}));  // (M, K)
            // if (thread_idx == 0) { printf("Before O write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d, mO_addr = %p, addr diff = %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o, mO.data(), reinterpret_cast<int>(&mO(0)) - reinterpret_cast<int>(params.ptr_O)); }
            if constexpr (Use_smem) {
                GmemTiledCopyO gmem_tiled_copy_O;
                auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
                Tensor tOsO = gmem_thr_copy_O.partition_S(sO);  // ((Atom,AtomNum),ATOM_M,ATOM_N)
                // Tensor tOsO = gmem_thr_copy_O.partition_S(sO_pi);  // ((Atom,AtomNum),ATOM_M,ATOM_N)
                Tensor tOrO = make_fragment_like(tOsO);
                cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
                if constexpr (ArchTag::kMinComputeCapability >= 90) {
                    cutlass::arch::fence_view_async_shared();  // ensure smem reads are done before next TMA to smem_v
                    #pragma unroll
                    for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
                        shared_storage.pipelines.barrier_O.arrive(cta_id);
                    }
                }
                if constexpr (!PackGQA) {
                    // (BLK_M,BLK_K) -> (blk_m,blk_k)
                    Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})));
                    Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOsO)));
                    #pragma unroll
                    for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
                    Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
                    // Clear_OOB_K must be false since we don't want to write zeros to gmem
                    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
                        gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM
                    );
                } else {
                    // If PackGQA, we split the work of computing O_ptr among threads in the same row
                    PackGQAt::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
                }
            } else {
                // We already arrived on barrier_O earlier
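                // Direct rmem -> gmem store (no smem staging), used e.g. for fp32 output. Each thread writes its
                // own accumulator fragment: pairs of adjacent columns (kGmemElemsPerStoreDirect == 2) go out with
                // a single vectorized copy, with per-row and per-column bounds checks against seqlen_o and the
                // head dimension.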
                if constexpr (!PackGQA) {
                    static constexpr int kGmemElemsPerStoreDirect = 2;
                    cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element> gmem_copy_direct;
                    // Reshape acc from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
                    Tensor tOrO_rowcol = make_tensor(tOrO_out.data(), flash::convert_layout_acc_rowcol(tOrO.layout()));
                    Tensor tOrO_copy = cute::tiled_divide(tOrO_rowcol, Shape<_1, Int<kGmemElemsPerStoreDirect>>{});
                    Tensor tOgO = thread_mma.partition_C(gO);
                    Tensor tOgO_rowcol = make_tensor(tOgO.data(), flash::convert_layout_acc_rowcol(tOgO.layout()));
                    Tensor tOgO_copy = cute::tiled_divide(tOgO_rowcol, Shape<_1, Int<kGmemElemsPerStoreDirect>>{});
                    Tensor taccOcO_col = taccOcO_rowcol(_0{}, _);
                    #pragma unroll
                    for (int m = 0; m < size(taccOcO_row); ++m) {
                        if (get<0>(taccOcO_row(m)) < seqlen_o - m_block * kBlockM) {
                            #pragma unroll
                            for (int k = 0; k < size(taccOcO_col) / kGmemElemsPerStoreDirect; ++k) {
                                if (get<1>(taccOcO_col(k * kGmemElemsPerStoreDirect)) < get<1>(params.shape_O)) {
                                    cute::copy(gmem_copy_direct, tOrO_copy(_, m, k), tOgO_copy(_, m, k));
                                }
                            }
                        }
                    }
                } else {
                    PackGQAt::store_O_direct(mO, tOrO_out, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
                }
            }
        }
    }

    CUTLASS_DEVICE void
    store_tail() {
        // Don't need to do tma_store_wait<0>() here since we already did in @store
    }

    // Write 0 to output and -inf to LSE
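    // A tile can end up here when it has no keys to attend to (e.g. a fully masked block or an empty
    // variable-length segment); writing O = 0 and LSE = -inf keeps downstream consumers such as the split-KV
    // combine step numerically consistent, since exp(-inf) contributes nothing.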
    template <bool Clear_O=true>
    CUTLASS_DEVICE void
    store_zero(
        Params const& params,
        int thread_idx,
        cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord
        ) {
        static constexpr int kBlockM = get<0>(TileShape_MNK{});
        auto [m_block, bidh, bidb, split_idx] = block_coord;
        flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused};
        bool const is_varlen = Varlen && params.cu_seqlens;
        int offset_o = seqlen_info.offset;
        int seqlen_o = seqlen_info.seqlen;
        int qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor;
        Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx);
        Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset_o * get<0>(params.stride_LSE)), params.shape_LSE_packed, params.stride_LSE_packed)(_, bidh, !is_varlen ? bidb : 0, split_idx);
        Tensor gLSE = local_tile(mLSE, Shape<Int<kBlockM>>{}, make_coord(m_block));

        static_assert(kBlockM <= NumEpilogueThreads);
        if (thread_idx < kBlockM) {
            const int row = m_block * kBlockM + thread_idx;
            if constexpr (!PackGQA) {
                if (row < seqlen_o) { mLSE(row) = -INFINITY; }
            } else {
                if (row < seqlen_o * qhead_per_khead) {
                    int m_idx, h_idx;
                    m_idx = params.qhead_per_khead_divmod.divmod(h_idx, row);
                    // mLSE has shape ((qhead_per_khead, seqlen_q)), and indexing it needs the nested make_coord
                    mLSE(make_coord(make_coord(h_idx, m_idx))) = -INFINITY;
                }
            }
        }
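        // Illustrative PackGQA indexing (hypothetical numbers): with qhead_per_khead == 8, a packed row index of
        // 11 splits as m_idx = 11 / 8 = 1 and h_idx = 11 % 8 = 3, i.e. the LSE entry for query head 3 (within this
        // KV-head group) at sequence position 1 is set to -inf.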
        if constexpr (!Clear_O) { return; }

        GmemTiledCopyO gmem_tiled_copy_O;
        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
        Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})));
        if constexpr (!PackGQA) {
            Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOcO)));
            #pragma unroll
            for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
            Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{}));  // (M, K)
            Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
            Tensor tOrO = make_fragment_like(tOgO);
            cute::clear(tOrO);
            // Clear_OOB_K must be false since we don't want to write zeros to gmem
            flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
                gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM
            );
        } else {
            // If PackGQA, we split the work of computing O_ptr among threads in the same row
            using PackGQAt = flash::PackGQAManager<get<0>(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumEpilogueThreads, Element>;
            Tensor tOrO = make_tensor<Element>(make_shape(Shape<_1, Int<kGmemElemsPerStore>>{}, size<1>(tOcO), size<2>(tOcO)));
            cute::clear(tOrO);
            PackGQAt::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
        }
    }

};

} // namespace flash