epilogue_bwd_sm90_tma.hpp 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include <cutlass/cutlass.h>
  6. #include <cutlass/barrier.h>
  7. #include "cute/tensor.hpp"
  8. #include "cutlass/gemm/collective/builders/sm90_common.inl"
  9. #include "seqlen.h"
  10. #include "named_barrier.hpp"
  11. #include "utils.h"
  12. namespace flash {
  13. using namespace cute;
  14. template <class TileShape_MNK_, class Element_, int NumEpilogueThreads_, bool Varlen_,
  15. bool dKV_swapAB_, int AtomLayoutKdKV=1>
  16. struct CollectiveEpilogueBwd {
  17. using TileShape_MNK = TileShape_MNK_;
  18. using Element = Element_;
  19. static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
  20. static constexpr bool Varlen = Varlen_;
  21. static constexpr bool dKV_swapAB = dKV_swapAB_;
  22. using GmemTiledCopydKVTMA = cute::SM90_TMA_STORE;
  23. // These are for storing the output tensor without TMA (e.g., for setting output to zero)
  24. static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
  25. static_assert(get<2>(TileShape_MNK{}) % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
  26. static constexpr int kHeadDim = get<2>(TileShape_MNK{});
  27. static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, NumEpilogueThreads);
  28. static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow");
  29. using GmemLayoutAtom = Layout<Shape <Int<NumEpilogueThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
  30. Stride<Int<kGmemThreadsPerRow>, _1>>;
  31. using GmemTiledCopydKV = decltype(
  32. make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
  33. GmemLayoutAtom{},
  34. Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per store
  35. using SmemLayoutAtomdKVTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  36. // TODO: do we have to change this if dKV_swapAB is true?
  37. decltype(cute::get<1>(TileShape_MNK{})), Int<CUTE_STATIC_V(cute::get<2>(TileShape_MNK{})) / AtomLayoutKdKV>>());
  38. using SmemLayoutdKVTMA = decltype(tile_to_shape(SmemLayoutAtomdKVTMA{}, select<1, 2>(TileShape_MNK{})));
  39. using SmemLayoutdKVtTMA =
  40. decltype(cute::composition(SmemLayoutdKVTMA{},
  41. make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
  42. make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{}))));
  43. // If we don't use TMA
  44. static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : (kHeadDim % 32 == 0 ? 32 : 16);
  45. static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 2 : 1);
  46. using SmemLayoutAtomdKVSTG =
  47. decltype(composition(Swizzle<kSwizzle, 3, 3>{},
  48. Layout<Shape<Int<8>, Int<kBlockKSmem>>,
  49. Stride<Int<kBlockKSmem>, _1>>{}));
  50. using SmemLayoutAtomdKV = std::conditional_t<!Varlen, SmemLayoutAtomdKVTMA, SmemLayoutAtomdKVSTG>;
  51. using SmemLayoutdKV = decltype(tile_to_shape(SmemLayoutAtomdKV{}, select<1, 2>(TileShape_MNK{})));
  52. using SmemLayoutdKVt =
  53. decltype(cute::composition(SmemLayoutdKV{},
  54. make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
  55. make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{}))));
  56. using SmemCopyAtomdKV = Copy_Atom<
  57. std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
  58. Element>;
  59. static constexpr size_t SmemAlignmentdKV = cutlass::detail::alignment_for_swizzle(SmemLayoutdKV{});
  60. static_assert(SmemAlignmentdKV >= 128, "Require at least 128B alignment");
  61. struct TensorStorage : cute::aligned_struct<SmemAlignmentdKV> {
  62. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdKV>, SmemAlignmentdKV> smem_dk;
  63. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdKV>, SmemAlignmentdKV> smem_dv;
  64. };
  65. using ShapedKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_k, d, head, batch)
  66. using StridedKV = cute::Stride<int64_t, _1, int64_t, int64_t>;
  67. using TMA_dKV = decltype(make_tma_copy(
  68. GmemTiledCopydKVTMA{},
  69. make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapedKV{}, StridedKV{}),
  70. SmemLayoutdKVTMA{},
  71. select<1, 2>(TileShape_MNK{}),
  72. _1{})); // no mcast for dKV
  73. // Host side kernel arguments
  74. struct Arguments {
  75. Element* ptr_dK;
  76. ShapedKV const shape_dK;
  77. StridedKV const stride_dK;
  78. Element* ptr_dV;
  79. StridedKV const stride_dV;
  80. int const num_heads_q;
  81. int* dk_semaphore;
  82. int* dv_semaphore;
  83. int const* cu_seqlens;
  84. int const* seqused;
  85. };
  86. // Device side kernel params
  87. struct Params {
  88. Element* ptr_dK;
  89. ShapedKV const shape_dK;
  90. StridedKV const stride_dK;
  91. Element* ptr_dV;
  92. StridedKV const stride_dV;
  93. TMA_dKV tma_store_dK, tma_store_dV;
  94. int const* cu_seqlens = nullptr;
  95. int const* seqused = nullptr;
  96. };
  97. static Params
  98. to_underlying_arguments(Arguments const& args) {
  99. Tensor mdK = make_tensor(make_gmem_ptr(args.ptr_dK), args.shape_dK, args.stride_dK);
  100. Tensor mdV = make_tensor(make_gmem_ptr(args.ptr_dV), args.shape_dK, args.stride_dV);
  101. TMA_dKV tma_store_dK = make_tma_copy(
  102. GmemTiledCopydKVTMA{},
  103. mdK,
  104. SmemLayoutdKVTMA{},
  105. select<1, 2>(TileShape_MNK{}),
  106. _1{}); // no mcast for dKV
  107. TMA_dKV tma_store_dV = make_tma_copy(
  108. GmemTiledCopydKVTMA{},
  109. mdV,
  110. SmemLayoutdKVTMA{},
  111. select<1, 2>(TileShape_MNK{}),
  112. _1{}); // no mcast for dKV
  113. return {args.ptr_dK, args.shape_dK, args.stride_dK, args.ptr_dV, args.stride_dV,
  114. tma_store_dK, tma_store_dV, args.cu_seqlens, args.seqused};
  115. }
  116. /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
  117. CUTLASS_DEVICE
  118. static void prefetch_tma_descriptors(Params const& params) {
  119. if constexpr (!Varlen) {
  120. cute::prefetch_tma_descriptor(params.tma_store_dK.get_tma_descriptor());
  121. cute::prefetch_tma_descriptor(params.tma_store_dV.get_tma_descriptor());
  122. }
  123. }
  124. template <typename SharedStorage, typename FrgTensorO, typename TiledMma>
  125. CUTLASS_DEVICE void
  126. store(Params const& params,
  127. FrgTensorO const& tdKrdK,
  128. FrgTensorO const& tdVrdV,
  129. SharedStorage& shared_storage,
  130. TiledMma tiled_mma,
  131. int thread_idx,
  132. cute::tuple<int32_t, int32_t, int32_t> const& block_coord
  133. ) {
  134. auto [n_block, bidh, bidb] = block_coord;
  135. Tensor sdK = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKV{}));
  136. Tensor sdV = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKV{}));
  137. Tensor sdKt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKVt{}));
  138. Tensor sdVt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKVt{}));
  139. auto smem_tiled_copy_dKV = make_tiled_copy_C(SmemCopyAtomdKV{}, tiled_mma);
  140. auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(thread_idx);
  141. Tensor tdVrdV_out = make_tensor_like<Element>(tdVrdV);
  142. flash::convert_type_out(tdVrdV, tdVrdV_out);
  143. Tensor tdKrdK_out = make_tensor_like<Element>(tdKrdK);
  144. flash::convert_type_out(tdKrdK, tdKrdK_out);
  145. Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(tdKrdK_out); // ((Atom,AtomNum), MMA_M, MMA_N)
  146. Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(tdVrdV_out); // ((Atom,AtomNum), MMA_M, MMA_N)
  147. // if (blockIdx.x == 0 && threadIdx.x == 128) { print(smem_thr_copy_dKV); print(sdK); printf("\n"); print(sdKt); printf("\n"); }
  148. Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(cute::conditional_return<!dKV_swapAB>(sdK, sdKt)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
  149. Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(cute::conditional_return<!dKV_swapAB>(sdV, sdVt)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
  150. // Make sure all WGs have finished reading K and V
  151. cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  152. cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
  153. cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
  154. if constexpr (!Varlen) {
  155. cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
  156. cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
  157. cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  158. Tensor mdK = params.tma_store_dK.get_tma_tensor(params.shape_dK);
  159. Tensor mdV = params.tma_store_dV.get_tma_tensor(params.shape_dK);
  160. Tensor gdK = local_tile(mdK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
  161. Tensor gdV = local_tile(mdV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
  162. auto block_tma_dK = params.tma_store_dK.get_slice(_0{});
  163. auto block_tma_dV = params.tma_store_dV.get_slice(_0{});
  164. Tensor tdKgdK = block_tma_dK.partition_D(gdK); // (TMA, TMA_M, TMA_K)
  165. Tensor tdKsdK = block_tma_dK.partition_S(sdK); // (TMA, TMA_M, TMA_K)
  166. Tensor tdVgdV = block_tma_dV.partition_D(gdV); // (TMA, TMA_M, TMA_K)
  167. Tensor tdVsdV = block_tma_dV.partition_S(sdV); // (TMA, TMA_M, TMA_K)
  168. int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0);
  169. if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) {
  170. cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
  171. cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  172. if (cute::elect_one_sync()) {
  173. cute::copy(params.tma_store_dV, tdVsdV, tdVgdV);
  174. cute::copy(params.tma_store_dK, tdKsdK, tdKgdK);
  175. tma_store_arrive();
  176. }
  177. }
  178. tma_store_wait<0>();
  179. // // Tell warp 0 that smem_k and smem_v are ready
  180. // cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::KVEmpty) /*id*/);
  181. } else {
  182. cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  183. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  184. flash::SeqlenInfo<Varlen, kBlockN> seqlen_info{bidb, size<0>(params.shape_dK), params.cu_seqlens, params.seqused};
  185. bool const is_varlen = Varlen && params.cu_seqlens;
  186. Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);
  187. Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
  188. Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dK, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0);
  189. Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
  190. GmemTiledCopydKV gmem_tiled_copy_dKV;
  191. auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx);
  192. Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV);
  193. Tensor tdKVsdV = gmem_thr_copy_dKV.partition_S(sdV); // (TMA, TMA_M, TMA_K)
  194. Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK);
  195. Tensor tdKVsdK = gmem_thr_copy_dKV.partition_S(sdK); // (TMA, TMA_M, TMA_K)
  196. Tensor tdKVrdV = make_fragment_like(tdKVgdV);
  197. Tensor tdKVrdK = make_fragment_like(tdKVgdK);
  198. cute::copy(gmem_tiled_copy_dKV, tdKVsdV, tdKVrdV);
  199. cute::copy(gmem_tiled_copy_dKV, tdKVsdK, tdKVrdK);
  200. // Construct identity layout for gdKV
  201. Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k)
  202. // Repeat the partitioning with identity layouts
  203. Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
  204. Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKVgdV)));
  205. #pragma unroll
  206. for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); }
  207. // Clear_OOB_K must be false since we don't want to write zeros to gmem
  208. flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
  209. gmem_tiled_copy_dKV, tdKVrdV, tdKVgdV, tdKVcdKV, tdKVpdKV, seqlen_info.seqlen - n_block * kBlockN
  210. );
  211. flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
  212. gmem_tiled_copy_dKV, tdKVrdK, tdKVgdK, tdKVcdKV, tdKVpdKV, seqlen_info.seqlen - n_block * kBlockN
  213. );
  214. }
  215. }
  216. CUTLASS_DEVICE void
  217. store_tail() {
  218. // if constexpr (!Varlen) { tma_store_wait<0>(); }
  219. }
  220. // Write 0 to dK and dV
  221. CUTLASS_DEVICE void
  222. store_zero(
  223. Params const& params,
  224. int thread_idx,
  225. cute::tuple<int32_t, int32_t, int32_t> const& block_coord
  226. ) {
  227. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  228. auto [n_block, bidh, bidb] = block_coord;
  229. flash::SeqlenInfo<Varlen, kBlockN> seqlen_info{bidb, size<0>(params.shape_dK), params.cu_seqlens, params.seqused};
  230. bool const is_varlen = Varlen && params.cu_seqlens;
  231. Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);
  232. Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
  233. Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dK, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0);
  234. Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
  235. GmemTiledCopydKV gmem_tiled_copy_dKV;
  236. auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx);
  237. Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK);
  238. Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV);
  239. Tensor tdKVrdKV = make_fragment_like(tdKVgdK);
  240. clear(tdKVrdKV);
  241. // Construct identity layout for gdKV
  242. Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k)
  243. // Repeat the partitioning with identity layouts
  244. Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
  245. Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKVgdK)));
  246. #pragma unroll
  247. for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); }
  248. // Clear_OOB_K must be false since we don't want to write zeros to gmem
  249. flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
  250. gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdK, tdKVcdKV, tdKVpdKV, seqlen_info.seqlen - n_block * kBlockN
  251. );
  252. flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
  253. gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdV, tdKVcdKV, tdKVpdKV, seqlen_info.seqlen - n_block * kBlockN
  254. );
  255. }
  256. };
  257. template <class TileShape_MNK_, class ElementAccum, int NumEpilogueThreads_, bool Varlen_, bool Deterministic>
  258. struct CollectiveEpilogueBwdGQA {
  259. using TileShape_MNK = TileShape_MNK_;
  260. using Element = ElementAccum;
  261. static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
  262. static constexpr bool Varlen = Varlen_;
  263. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  264. static constexpr int kHeadDim = get<2>(TileShape_MNK{});
  265. static_assert(NumEpilogueThreads % cutlass::NumThreadsPerWarp == 0, "NumEpilogueThreads must be a multiple of NumThreadsPerWarp");
  266. static constexpr int NumWarpGroups = NumEpilogueThreads / cutlass::NumThreadsPerWarpGroup;
  267. // Thread layout, 256 or 384 threads per row
  268. // We split into NumWarpGroups so that we can use the same postprocessing kernel as dQ
  269. using R2SLayoutAtomdKVaccum = Layout<Shape<Int<cutlass::NumThreadsPerWarpGroup>, Int<NumWarpGroups>>>;
  270. using R2STiledCopydKVaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2SLayoutAtomdKVaccum{},
  271. Layout<Shape < _4>>{})); // Val layout, 4 vals per store
  272. using SmemLayoutdKVaccum = Layout<Shape<Int<kBlockN * kHeadDim / NumWarpGroups>, Int<NumWarpGroups>>>;
  273. using SmemLayoutdKVaccumFlat = Layout<Shape<Int<kBlockN * kHeadDim>>>;
  274. // Strangely without this SmemAlignment, the total smem for hdim 128 (80 x 128) is 228KB even though we
  275. // only need 227KB. We use the same alignment as the non-GQA epilogue to avoid this issue.
  276. static constexpr int SmemAlignment = kHeadDim % 64 == 0 ? 1024 : (kHeadDim % 32 == 0 ? 512 : 256);
  277. struct TensorStorage : cute::aligned_struct<SmemAlignment> {
  278. cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdKVaccum>, SmemAlignment> smem_dkv;
  279. };
  280. using ShapedKV = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_k_rounded * d, head, batch)
  281. using StridedKV = cute::Stride<_1, int64_t, int64_t>;
  282. // Host side kernel arguments
  283. struct Arguments {
  284. ElementAccum* ptr_dKaccum;
  285. ShapedKV const shape_dKaccum;
  286. StridedKV const stride_dKaccum;
  287. ElementAccum* ptr_dVaccum;
  288. StridedKV const stride_dVaccum;
  289. int num_heads_q;
  290. int* dk_semaphore;
  291. int* dv_semaphore;
  292. int const* cu_seqlens;
  293. int const* seqused;
  294. };
  295. // Device side kernel params
  296. struct Params {
  297. ElementAccum* ptr_dKaccum;
  298. ShapedKV const shape_dKaccum;
  299. StridedKV const stride_dKaccum;
  300. ElementAccum* ptr_dVaccum;
  301. StridedKV const stride_dVaccum;
  302. cutlass::FastDivmod qhead_per_khead_divmod;
  303. int* dk_semaphore;
  304. int* dv_semaphore;
  305. int const* cu_seqlens = nullptr;
  306. int const* seqused = nullptr;
  307. };
  308. static Params
  309. to_underlying_arguments(Arguments const& args) {
  310. if constexpr (Deterministic) {
  311. assert(args.dk_semaphore != nullptr);
  312. assert(args.dv_semaphore != nullptr);
  313. }
  314. return {args.ptr_dKaccum, args.shape_dKaccum, args.stride_dKaccum, args.ptr_dVaccum, args.stride_dVaccum,
  315. cutlass::FastDivmod(cute::ceil_div(args.num_heads_q, get<1>(args.shape_dKaccum))),
  316. args.dk_semaphore, args.dv_semaphore,
  317. args.cu_seqlens, args.seqused};
  318. }
  319. /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
  320. CUTLASS_DEVICE
  321. static void prefetch_tma_descriptors(Params const& params) {
  322. }
  323. template <typename SharedStorage, typename FrgTensorO, typename TiledMma>
  324. CUTLASS_DEVICE void
  325. store(Params const& params,
  326. FrgTensorO const& tdKrdK,
  327. FrgTensorO const& tdVrdV,
  328. SharedStorage& shared_storage,
  329. TiledMma tiled_mma,
  330. int thread_idx,
  331. cute::tuple<int32_t, int32_t, int32_t> const& block_coord
  332. ) {
  333. auto [n_block, bidh, bidb] = block_coord;
  334. int bidh_idx_in_group;
  335. int bidh_kv = params.qhead_per_khead_divmod.divmod(bidh_idx_in_group, bidh);
  336. Tensor sdKV = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccum{});
  337. Tensor sdKV_flat = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccumFlat{});
  338. static constexpr int dKV_TMA_num_bytes = CUTE_STATIC_V(size(sdKV_flat)) * sizeof(ElementAccum);
  339. flash::SeqlenInfo<Varlen, kBlockN> seqlen_info{bidb, size<0>(params.shape_dKaccum), params.cu_seqlens, params.seqused};
  340. bool const is_varlen = Varlen && params.cu_seqlens;
  341. Tensor mdKaccum = make_tensor(make_gmem_ptr(params.ptr_dKaccum), params.shape_dKaccum, params.stride_dKaccum)(_, bidh_kv, !is_varlen ? bidb : 0);
  342. Tensor mdVaccum = make_tensor(make_gmem_ptr(params.ptr_dVaccum), params.shape_dKaccum, params.stride_dVaccum)(_, bidh_kv, !is_varlen ? bidb : 0);
  343. Tensor gdKaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdKaccum), Shape<Int<kBlockN * kHeadDim>>{}, make_coord(n_block)); // (M * K)
  344. Tensor gdVaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdVaccum), Shape<Int<kBlockN * kHeadDim>>{}, make_coord(n_block)); // (M * K)
  345. R2STiledCopydKVaccum r2s_tiled_copy_dKVaccum;
  346. auto r2s_thr_copy_dKVaccum = r2s_tiled_copy_dKVaccum.get_thread_slice(thread_idx);
  347. Tensor tdKVsdKVaccum = r2s_thr_copy_dKVaccum.partition_D(sdKV);
  348. // Make sure all WGs have finished reading K and V, otherwise we get racy dQ
  349. // because smem_q could be changed.
  350. cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  351. Tensor taccdKVrdV = r2s_thr_copy_dKVaccum.retile_S(tdVrdV); // ((Atom,AtomNum), MMA_M, MMA_N)
  352. cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdV, tdKVsdKVaccum);
  353. // int const num_batch = params.num_batch;
  354. int const num_batch = get<2>(params.shape_dKaccum);
  355. int const num_head_kv = get<1>(params.shape_dKaccum);
  356. int *lock_ptr = !Deterministic ? nullptr : params.dv_semaphore + bidb * num_head_kv + bidh_kv;
  357. using Barrier = cutlass::GenericBarrier<cutlass::detail::SyncwarpSync>;
  358. // if (thread_idx == 0) { printf("blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dv_semaphore = %p, num_batch = %d, num_head_kv = %d, n_block = %d, bihd_idx_in_group = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dv_semaphore, num_batch, num_head_kv, n_block, bidh_idx_in_group);}
  359. if constexpr (Deterministic) {
  360. Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group);
  361. }
  362. // if (thread_idx == 0) { printf("After barrier blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dv_semaphore = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dv_semaphore);}
  363. cutlass::arch::fence_view_async_shared();
  364. cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  365. if (thread_idx == 0) {
  366. SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdKV_flat.data()), raw_pointer_cast(gdVaccum.data()), dKV_TMA_num_bytes, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_LAST));
  367. tma_store_arrive();
  368. tma_store_wait<0>();
  369. }
  370. if constexpr (Deterministic) {
  371. Barrier::arrive_inc(lock_ptr, thread_idx, n_block * num_batch * num_head_kv);
  372. }
  373. cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  374. Tensor taccdKVrdK = r2s_thr_copy_dKVaccum.retile_S(tdKrdK); // ((Atom,AtomNum), MMA_M, MMA_N)
  375. cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdK, tdKVsdKVaccum);
  376. lock_ptr = !Deterministic ? nullptr : params.dk_semaphore + bidb * num_head_kv + bidh_kv;
  377. // if (thread_idx == 0) { printf("blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dk_semaphore = %p, num_batch = %d, num_head_kv = %d, n_block = %d, bihd_idx_in_group = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dk_semaphore, num_batch, num_head_kv, n_block, bidh_idx_in_group);}
  378. if constexpr (Deterministic) {
  379. Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group);
  380. }
  381. // if (thread_idx == 0) { printf("After barrier blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dk_semaphore = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dk_semaphore);}
  382. cutlass::arch::fence_view_async_shared();
  383. cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  384. if (thread_idx == 0) {
  385. SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdKV_flat.data()), raw_pointer_cast(gdKaccum.data()), dKV_TMA_num_bytes, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_LAST));
  386. tma_store_arrive();
  387. tma_store_wait<0>();
  388. }
  389. if constexpr (Deterministic) {
  390. Barrier::arrive_inc(lock_ptr, thread_idx, n_block * num_batch * num_head_kv);
  391. }
  392. }
  393. CUTLASS_DEVICE void
  394. store_tail() {
  395. }
  396. // Write 0 to dK and dV
  397. CUTLASS_DEVICE void
  398. store_zero(
  399. Params const& params,
  400. int thread_idx,
  401. cute::tuple<int32_t, int32_t, int32_t> const& block_coord
  402. ) {
  403. // Don't need to do anything since dKaccum and dVaccum are already zero-initialized
  404. }
  405. };
  406. } // namespace flash