epilogue_bwd_sm90_tma.hpp 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include <cutlass/cutlass.h>
  6. #include <cutlass/barrier.h>
  7. #include "cute/tensor.hpp"
  8. #include "cutlass/gemm/collective/collective_builder.hpp"
  9. #include "named_barrier.hpp"
  10. #include "utils.h"
  11. namespace flash {
  12. using namespace cute;
  13. template <class TileShape_MNK_, class Element_, int NumEpilogueThreads_, bool Varlen_,
  14. bool dKV_swapAB_, int AtomLayoutKdKV=1>
  15. struct CollectiveEpilogueBwd {
  16. using TileShape_MNK = TileShape_MNK_;
  17. using Element = Element_;
  18. static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
  19. static constexpr bool Varlen = Varlen_;
  20. static constexpr bool dKV_swapAB = dKV_swapAB_;
  21. using GmemTiledCopydKVTMA = cute::SM90_TMA_STORE;
  22. // These are for storing the output tensor without TMA (e.g., for setting output to zero)
  23. static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
  24. static_assert(get<2>(TileShape_MNK{}) % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
  25. static constexpr int kHeadDim = get<2>(TileShape_MNK{});
  26. static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, NumEpilogueThreads);
  27. static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow");
  28. using GmemLayoutAtom = Layout<Shape <Int<NumEpilogueThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
  29. Stride<Int<kGmemThreadsPerRow>, _1>>;
  30. using GmemTiledCopydKV = decltype(
  31. make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
  32. GmemLayoutAtom{},
  33. Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per store
  34. using SmemLayoutAtomdKVTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  35. // TODO: do we have to change this if dKV_swapAB is true?
  36. decltype(cute::get<1>(TileShape_MNK{})), Int<CUTE_STATIC_V(cute::get<2>(TileShape_MNK{})) / AtomLayoutKdKV>>());
  37. using SmemLayoutdKVTMA = decltype(tile_to_shape(SmemLayoutAtomdKVTMA{}, select<1, 2>(TileShape_MNK{})));
  38. using SmemLayoutdKVtTMA =
  39. decltype(cute::composition(SmemLayoutdKVTMA{},
  40. make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
  41. make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{}))));
  42. // If we don't use TMA
  43. static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : (kHeadDim % 32 == 0 ? 32 : 16);
  44. static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 2 : 1);
  45. using SmemLayoutAtomdKVSTG =
  46. decltype(composition(Swizzle<kSwizzle, 3, 3>{},
  47. Layout<Shape<Int<8>, Int<kBlockKSmem>>,
  48. Stride<Int<kBlockKSmem>, _1>>{}));
  49. using SmemLayoutAtomdKV = std::conditional_t<!Varlen, SmemLayoutAtomdKVTMA, SmemLayoutAtomdKVSTG>;
  50. using SmemLayoutdKV = decltype(tile_to_shape(SmemLayoutAtomdKV{}, select<1, 2>(TileShape_MNK{})));
  51. using SmemLayoutdKVt =
  52. decltype(cute::composition(SmemLayoutdKV{},
  53. make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
  54. make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{}))));
  55. using SmemCopyAtomdKV = Copy_Atom<
  56. std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
  57. Element>;
  58. static constexpr size_t SmemAlignmentdKV = cutlass::detail::alignment_for_swizzle(SmemLayoutdKV{});
  59. static_assert(SmemAlignmentdKV >= 128, "Require at least 128B alignment");
  60. struct TensorStorage : cute::aligned_struct<SmemAlignmentdKV> {
  61. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdKV>, SmemAlignmentdKV> smem_dk;
  62. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdKV>, SmemAlignmentdKV> smem_dv;
  63. };
  64. using ShapedKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch)
  65. using StridedKV = cute::Stride<int64_t, _1, int64_t, int64_t>;
  66. using TMA_dKV = decltype(make_tma_copy(
  67. GmemTiledCopydKVTMA{},
  68. make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapedKV{}, StridedKV{}),
  69. SmemLayoutdKVTMA{},
  70. select<1, 2>(TileShape_MNK{}),
  71. _1{})); // no mcast for dKV
  72. // Host side kernel arguments
  73. struct Arguments {
  74. Element* ptr_dK;
  75. ShapedKV const shape_dK;
  76. StridedKV const stride_dK;
  77. Element* ptr_dV;
  78. StridedKV const stride_dV;
  79. int const num_heads_q;
  80. int* dk_semaphore;
  81. int* dv_semaphore;
  82. int const* cu_seqlens;
  83. int const* seqused;
  84. };
  85. // Device side kernel params
  86. struct Params {
  87. Element* ptr_dK;
  88. ShapedKV const shape_dK;
  89. StridedKV const stride_dK;
  90. Element* ptr_dV;
  91. StridedKV const stride_dV;
  92. TMA_dKV tma_store_dK, tma_store_dV;
  93. int const* cu_seqlens = nullptr;
  94. int const* seqused = nullptr;
  95. };
  96. static Params
  97. to_underlying_arguments(Arguments const& args) {
  98. if constexpr (Varlen) {
  99. assert (args.cu_seqlens != nullptr);
  100. }
  101. Tensor mdK = make_tensor(make_gmem_ptr(args.ptr_dK), args.shape_dK, args.stride_dK);
  102. Tensor mdV = make_tensor(make_gmem_ptr(args.ptr_dV), args.shape_dK, args.stride_dV);
  103. TMA_dKV tma_store_dK = make_tma_copy(
  104. GmemTiledCopydKVTMA{},
  105. mdK,
  106. SmemLayoutdKVTMA{},
  107. select<1, 2>(TileShape_MNK{}),
  108. _1{}); // no mcast for dKV
  109. TMA_dKV tma_store_dV = make_tma_copy(
  110. GmemTiledCopydKVTMA{},
  111. mdV,
  112. SmemLayoutdKVTMA{},
  113. select<1, 2>(TileShape_MNK{}),
  114. _1{}); // no mcast for dKV
  115. return {args.ptr_dK, args.shape_dK, args.stride_dK, args.ptr_dV, args.stride_dV,
  116. tma_store_dK, tma_store_dV, args.cu_seqlens, args.seqused};
  117. }
  118. /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
  119. CUTLASS_DEVICE
  120. static void prefetch_tma_descriptors(Params const& params) {
  121. if constexpr (!Varlen) {
  122. cute::prefetch_tma_descriptor(params.tma_store_dK.get_tma_descriptor());
  123. cute::prefetch_tma_descriptor(params.tma_store_dV.get_tma_descriptor());
  124. }
  125. }
  126. template <typename SharedStorage, typename FrgTensorO, typename TiledMma>
  127. CUTLASS_DEVICE void
  128. store(Params const& params,
  129. FrgTensorO const& tdKrdK,
  130. FrgTensorO const& tdVrdV,
  131. SharedStorage& shared_storage,
  132. TiledMma tiled_mma,
  133. int thread_idx,
  134. cute::tuple<int32_t, int32_t, int32_t> const& block_coord
  135. ) {
  136. auto [n_block, bidh, bidb] = block_coord;
  137. Tensor sdK = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKV{}));
  138. Tensor sdV = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKV{}));
  139. Tensor sdKt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKVt{}));
  140. Tensor sdVt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKVt{}));
  141. auto smem_tiled_copy_dKV = make_tiled_copy_C(SmemCopyAtomdKV{}, tiled_mma);
  142. auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(thread_idx);
  143. Tensor tdVrdV_out = flash::convert_type<Element>(tdVrdV);
  144. Tensor tdKrdK_out = flash::convert_type<Element>(tdKrdK);
  145. Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(tdKrdK_out); // ((Atom,AtomNum), MMA_M, MMA_N)
  146. Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(tdVrdV_out); // ((Atom,AtomNum), MMA_M, MMA_N)
  147. // if (blockIdx.x == 0 && threadIdx.x == 128) { print(smem_thr_copy_dKV); print(sdK); printf("\n"); print(sdKt); printf("\n"); }
  148. Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(cute::conditional_return<!dKV_swapAB>(sdK, sdKt)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
  149. Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(cute::conditional_return<!dKV_swapAB>(sdV, sdVt)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
  150. // Make sure all WGs have finished reading K and V
  151. cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  152. cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
  153. cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
  154. if constexpr (!Varlen) {
  155. cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
  156. cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
  157. cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  158. Tensor mdK = params.tma_store_dK.get_tma_tensor(params.shape_dK);
  159. Tensor mdV = params.tma_store_dV.get_tma_tensor(params.shape_dK);
  160. Tensor gdK = local_tile(mdK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
  161. Tensor gdV = local_tile(mdV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
  162. auto block_tma_dK = params.tma_store_dK.get_slice(_0{});
  163. auto block_tma_dV = params.tma_store_dV.get_slice(_0{});
  164. Tensor tdKgdK = block_tma_dK.partition_D(gdK); // (TMA, TMA_M, TMA_K)
  165. Tensor tdKsdK = block_tma_dK.partition_S(sdK); // (TMA, TMA_M, TMA_K)
  166. Tensor tdVgdV = block_tma_dV.partition_D(gdV); // (TMA, TMA_M, TMA_K)
  167. Tensor tdVsdV = block_tma_dV.partition_S(sdV); // (TMA, TMA_M, TMA_K)
  168. int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0);
  169. if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) {
  170. cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
  171. cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  172. int const lane_predicate = cute::elect_one_sync();
  173. if (lane_predicate) {
  174. cute::copy(params.tma_store_dV, tdVsdV, tdVgdV);
  175. cute::copy(params.tma_store_dK, tdKsdK, tdKgdK);
  176. tma_store_arrive();
  177. }
  178. }
  179. tma_store_wait<0>();
  180. // // Tell warp 0 that smem_k and smem_v are ready
  181. // cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::KVEmpty) /*id*/);
  182. } else {
  183. cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  184. int const offset = !Varlen ? 0 : params.cu_seqlens[bidb];
  185. int const seqlen = !Varlen ? get<0>(params.shape_dK) : (params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb]);
  186. Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !Varlen ? bidb : 0);
  187. Tensor gdK = local_tile(cute::domain_offset(make_coord(offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
  188. Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dK, params.stride_dV)(_, _, bidh, !Varlen ? bidb : 0);
  189. Tensor gdV = local_tile(cute::domain_offset(make_coord(offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
  190. GmemTiledCopydKV gmem_tiled_copy_dKV;
  191. auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx);
  192. Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV);
  193. Tensor tdKVsdV = gmem_thr_copy_dKV.partition_S(sdV); // (TMA, TMA_M, TMA_K)
  194. Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK);
  195. Tensor tdKVsdK = gmem_thr_copy_dKV.partition_S(sdK); // (TMA, TMA_M, TMA_K)
  196. Tensor tdKVrdV = make_fragment_like(tdKVgdV);
  197. Tensor tdKVrdK = make_fragment_like(tdKVgdK);
  198. cute::copy(gmem_tiled_copy_dKV, tdKVsdV, tdKVrdV);
  199. cute::copy(gmem_tiled_copy_dKV, tdKVsdK, tdKVrdK);
  200. // Construct identity layout for gdKV
  201. Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k)
  202. // Repeat the partitioning with identity layouts
  203. Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
  204. Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKVgdV)));
  205. #pragma unroll
  206. for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); }
  207. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  208. // Clear_OOB_K must be false since we don't want to write zeros to gmem
  209. flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
  210. gmem_tiled_copy_dKV, tdKVrdV, tdKVgdV, tdKVcdKV, tdKVpdKV, seqlen - n_block * kBlockN
  211. );
  212. flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
  213. gmem_tiled_copy_dKV, tdKVrdK, tdKVgdK, tdKVcdKV, tdKVpdKV, seqlen - n_block * kBlockN
  214. );
  215. }
  216. }
  217. CUTLASS_DEVICE void
  218. store_tail() {
  219. // if constexpr (!Varlen) { tma_store_wait<0>(); }
  220. }
  221. // Write 0 to dK and dV
  222. CUTLASS_DEVICE void
  223. store_zero(
  224. Params const& params,
  225. int thread_idx,
  226. cute::tuple<int32_t, int32_t, int32_t> const& block_coord
  227. ) {
  228. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  229. auto [n_block, bidh, bidb] = block_coord;
  230. int const offset = !Varlen ? 0 : params.cu_seqlens[bidb];
  231. int const seqlen = !Varlen ? get<0>(params.shape_dK) : (params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - offset);
  232. Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !Varlen ? bidb : 0);
  233. Tensor gdK = local_tile(cute::domain_offset(make_coord(offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
  234. Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dK, params.stride_dV)(_, _, bidh, !Varlen ? bidb : 0);
  235. Tensor gdV = local_tile(cute::domain_offset(make_coord(offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
  236. GmemTiledCopydKV gmem_tiled_copy_dKV;
  237. auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx);
  238. Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK);
  239. Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV);
  240. Tensor tdKVrdKV = make_fragment_like(tdKVgdK);
  241. clear(tdKVrdKV);
  242. // Construct identity layout for gdKV
  243. Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k)
  244. // Repeat the partitioning with identity layouts
  245. Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
  246. Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKVgdK)));
  247. #pragma unroll
  248. for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); }
  249. // Clear_OOB_K must be false since we don't want to write zeros to gmem
  250. flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
  251. gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdK, tdKVcdKV, tdKVpdKV, seqlen - n_block * kBlockN
  252. );
  253. flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
  254. gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdV, tdKVcdKV, tdKVpdKV, seqlen - n_block * kBlockN
  255. );
  256. }
  257. };
  258. template <class TileShape_MNK_, class ElementAccum, int NumEpilogueThreads_, bool Varlen_, bool Deterministic>
  259. struct CollectiveEpilogueBwdGQA {
  260. using TileShape_MNK = TileShape_MNK_;
  261. using Element = ElementAccum;
  262. static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
  263. static constexpr bool Varlen = Varlen_;
  264. using GmemTiledCopydKVTMA = cute::SM90_TMA_REDUCE_ADD;
  265. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  266. static constexpr int kHeadDim = get<2>(TileShape_MNK{});
  267. using SmemLayoutAtomdKVaccumTMA =
  268. decltype(composition(Swizzle<0, 4, 3>{}, // We don't want any swizzle
  269. Layout<Shape<Int<8>, Int<kHeadDim>>,
  270. Stride<Int<kHeadDim>, _1>>{}));
  271. using SmemLayoutdKVaccumTMA = decltype(tile_to_shape(SmemLayoutAtomdKVaccumTMA{}, select<1, 2>(TileShape_MNK{})));
  272. // Thread layout, 256 threads per row
  273. using R2SLayoutAtomdKVaccum = Layout<Shape<Int<NumEpilogueThreads>>, Stride<_1>>;
  274. using R2STiledCopydKVaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2SLayoutAtomdKVaccum{},
  275. Layout<Shape < _4>>{})); // Val layout, 4 vals per store
  276. using SmemLayoutdKVaccum = Layout<Shape<Int<kBlockN * kHeadDim>>, Stride<_1>>;
  277. static_assert(size(SmemLayoutdKVaccumTMA{}) == size(SmemLayoutdKVaccum{}), "SmemLayoutdKVaccumTMA and SmemLayoutdKVaccum must have the same size");
  278. struct TensorStorage : cute::aligned_struct<128> {
  279. cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdKVaccumTMA>> smem_dkv;
  280. };
  281. using ShapedKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch)
  282. using StridedKV = cute::Stride<int64_t, _1, int64_t, int64_t>;
  283. using TMA_add_dKV = decltype(make_tma_copy(
  284. GmemTiledCopydKVTMA{},
  285. make_tensor(make_gmem_ptr(static_cast<ElementAccum*>(nullptr)), ShapedKV{}, StridedKV{}),
  286. SmemLayoutdKVaccumTMA{},
  287. select<1, 2>(TileShape_MNK{}),
  288. _1{})); // no mcast for dKV
  289. // Host side kernel arguments
  290. struct Arguments {
  291. ElementAccum* ptr_dKaccum;
  292. ShapedKV const shape_dKaccum;
  293. StridedKV const stride_dKaccum;
  294. ElementAccum* ptr_dVaccum;
  295. StridedKV const stride_dVaccum;
  296. int num_heads_q;
  297. int* dk_semaphore;
  298. int* dv_semaphore;
  299. int const* cu_seqlens;
  300. int const* seqused;
  301. };
  302. // Device side kernel params
  303. struct Params {
  304. ShapedKV const shape_dKaccum;
  305. TMA_add_dKV tma_add_dK, tma_add_dV;
  306. cutlass::FastDivmod qhead_per_khead_divmod;
  307. int* dk_semaphore;
  308. int* dv_semaphore;
  309. int const* cu_seqlens = nullptr;
  310. int const* seqused = nullptr;
  311. };
  312. static Params
  313. to_underlying_arguments(Arguments const& args) {
  314. if constexpr (Varlen) {
  315. assert (args.cu_seqlens != nullptr);
  316. }
  317. Tensor mdKaccum = make_tensor(make_gmem_ptr(args.ptr_dKaccum), args.shape_dKaccum, args.stride_dKaccum);
  318. Tensor mdVaccum = make_tensor(make_gmem_ptr(args.ptr_dVaccum), args.shape_dKaccum, args.stride_dVaccum);
  319. TMA_add_dKV tma_add_dK = make_tma_copy(
  320. GmemTiledCopydKVTMA{},
  321. mdKaccum,
  322. SmemLayoutdKVaccumTMA{},
  323. select<1, 2>(TileShape_MNK{}),
  324. _1{}); // no mcast for dKV
  325. TMA_add_dKV tma_add_dV = make_tma_copy(
  326. GmemTiledCopydKVTMA{},
  327. mdVaccum,
  328. SmemLayoutdKVaccumTMA{},
  329. select<1, 2>(TileShape_MNK{}),
  330. _1{}); // no mcast for dKV
  331. if constexpr (Deterministic) {
  332. assert(args.dk_semaphore != nullptr);
  333. assert(args.dv_semaphore != nullptr);
  334. }
  335. if constexpr (Varlen) {
  336. assert(args.cu_seqlens != nullptr);
  337. }
  338. return {args.shape_dKaccum, tma_add_dK, tma_add_dV,
  339. cutlass::FastDivmod(cute::ceil_div(args.num_heads_q, get<2>(args.shape_dKaccum))),
  340. args.dk_semaphore, args.dv_semaphore,
  341. args.cu_seqlens, args.seqused};
  342. }
  343. /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
  344. CUTLASS_DEVICE
  345. static void prefetch_tma_descriptors(Params const& params) {
  346. cute::prefetch_tma_descriptor(params.tma_add_dK.get_tma_descriptor());
  347. cute::prefetch_tma_descriptor(params.tma_add_dV.get_tma_descriptor());
  348. }
  349. template <typename SharedStorage, typename FrgTensorO, typename TiledMma>
  350. CUTLASS_DEVICE void
  351. store(Params const& params,
  352. FrgTensorO const& tdKrdK,
  353. FrgTensorO const& tdVrdV,
  354. SharedStorage& shared_storage,
  355. TiledMma tiled_mma,
  356. int thread_idx,
  357. cute::tuple<int32_t, int32_t, int32_t> const& block_coord
  358. ) {
  359. auto [n_block, bidh, bidb] = block_coord;
  360. int bidh_idx_in_group;
  361. // int bidh_kv = params.qhead_per_khead_divmod.divide(bidh);
  362. int bidh_kv = params.qhead_per_khead_divmod.divmod(bidh_idx_in_group, bidh);
  363. Tensor sdKV = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccum{});
  364. Tensor sdKVTMA = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccumTMA{});
  365. int const offset_padded = !Varlen ? 0 : (params.cu_seqlens[bidb] + bidb * kBlockN) / kBlockN * kBlockN;
  366. Tensor mdKaccum = params.tma_add_dK.get_tma_tensor(params.shape_dKaccum)(_, _, bidh_kv, !Varlen ? bidb : 0);
  367. Tensor mdVaccum = params.tma_add_dV.get_tma_tensor(params.shape_dKaccum)(_, _, bidh_kv, !Varlen ? bidb : 0);
  368. Tensor gdKaccum = local_tile(domain_offset(make_coord(offset_padded, _0{}), mdKaccum), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
  369. Tensor gdVaccum = local_tile(domain_offset(make_coord(offset_padded, _0{}), mdVaccum), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
  370. auto block_tma_dK = params.tma_add_dK.get_slice(_0{});
  371. auto block_tma_dV = params.tma_add_dV.get_slice(_0{});
  372. Tensor tdKgdK = block_tma_dK.partition_D(gdKaccum); // (TMA, TMA_M, TMA_K)
  373. Tensor tdKsdK = block_tma_dK.partition_S(sdKVTMA); // (TMA, TMA_M, TMA_K)
  374. Tensor tdVgdV = block_tma_dV.partition_D(gdVaccum); // (TMA, TMA_M, TMA_K)
  375. Tensor tdVsdV = block_tma_dV.partition_S(sdKVTMA); // (TMA, TMA_M, TMA_K)
  376. R2STiledCopydKVaccum r2s_tiled_copy_dKVaccum;
  377. auto r2s_thr_copy_dKVaccum = r2s_tiled_copy_dKVaccum.get_thread_slice(thread_idx);
  378. Tensor tdKVsdKVaccum = r2s_thr_copy_dKVaccum.partition_D(sdKV);
  379. Tensor taccdKVrdV = r2s_thr_copy_dKVaccum.retile_S(tdVrdV); // ((Atom,AtomNum), MMA_M, MMA_N)
  380. cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdV, tdKVsdKVaccum);
  381. // int const num_batch = params.num_batch;
  382. int const num_batch = get<3>(params.shape_dKaccum);
  383. int const num_head_kv = get<2>(params.shape_dKaccum);
  384. int *lock_ptr = !Deterministic ? nullptr : params.dv_semaphore + bidb * num_head_kv + bidh_kv;
  385. using Barrier = cutlass::GenericBarrier<cutlass::detail::SyncwarpSync>;
  386. // if (thread_idx == 0) { printf("blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dv_semaphore = %p, num_batch = %d, num_head_kv = %d, n_block = %d, bihd_idx_in_group = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dv_semaphore, num_batch, num_head_kv, n_block, bidh_idx_in_group);}
  387. if constexpr (Deterministic) {
  388. Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group);
  389. }
  390. // if (thread_idx == 0) { printf("After barrier blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dv_semaphore = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dv_semaphore);}
  391. cutlass::arch::fence_view_async_shared();
  392. cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  393. if (thread_idx == 0) {
  394. cute::copy(params.tma_add_dV, tdVsdV, tdVgdV);
  395. tma_store_arrive();
  396. }
  397. tma_store_wait<0>();
  398. if constexpr (Deterministic) {
  399. Barrier::arrive_inc(lock_ptr, thread_idx, n_block * num_batch * num_head_kv);
  400. }
  401. cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  402. Tensor taccdKVrdK = r2s_thr_copy_dKVaccum.retile_S(tdKrdK); // ((Atom,AtomNum), MMA_M, MMA_N)
  403. cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdK, tdKVsdKVaccum);
  404. lock_ptr = !Deterministic ? nullptr : params.dk_semaphore + bidb * num_head_kv + bidh_kv;
  405. // if (thread_idx == 0) { printf("blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dk_semaphore = %p, num_batch = %d, num_head_kv = %d, n_block = %d, bihd_idx_in_group = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dk_semaphore, num_batch, num_head_kv, n_block, bidh_idx_in_group);}
  406. if constexpr (Deterministic) {
  407. Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group);
  408. }
  409. // if (thread_idx == 0) { printf("After barrier blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dk_semaphore = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dk_semaphore);}
  410. cutlass::arch::fence_view_async_shared();
  411. cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  412. if (thread_idx == 0) {
  413. cute::copy(params.tma_add_dK, tdKsdK, tdKgdK);
  414. tma_store_arrive();
  415. }
  416. tma_store_wait<0>();
  417. if constexpr (Deterministic) {
  418. Barrier::arrive_inc(lock_ptr, thread_idx, n_block * num_batch * num_head_kv);
  419. }
  420. }
  421. CUTLASS_DEVICE void
  422. store_tail() {
  423. }
  424. // Write 0 to dK and dV
  425. CUTLASS_DEVICE void
  426. store_zero(
  427. Params const& params,
  428. int thread_idx,
  429. cute::tuple<int32_t, int32_t, int32_t> const& block_coord
  430. ) {
  431. // Don't need to do anything since dKaccum and dVaccum are already zero-initialized
  432. }
  433. };
  434. } // namespace flash