epilogue_bwd.hpp 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include <cutlass/cutlass.h>
  6. #include <cutlass/barrier.h>
  7. #include "cute/tensor.hpp"
  8. #include "cutlass/gemm/collective/builders/sm90_common.inl"
  9. #include "seqlen.h"
  10. #include "named_barrier.hpp"
  11. #include "utils.h"
  12. namespace flash {
  13. using namespace cute;
  14. template <class TileShape_MNK_, class Element_, class ArchTag_,
  15. int NumEpilogueThreads_, bool Varlen_, bool dKV_swapAB_, int AtomLayoutKdKV=1>
  16. struct CollectiveEpilogueBwd {
  17. using TileShape_MNK = TileShape_MNK_;
  18. using Element = Element_;
  19. using ArchTag = ArchTag_;
  20. static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
  21. static constexpr bool Varlen = Varlen_;
  22. static constexpr bool dKV_swapAB = dKV_swapAB_;
  23. static constexpr bool Use_TMA = !Varlen && ArchTag::kMinComputeCapability >= 90;
  24. static_assert(ArchTag::kMinComputeCapability >= 80);
  25. using GmemTiledCopydKVTMA = cute::SM90_TMA_STORE;
  26. // These are for storing the output tensor without TMA (e.g., for setting output to zero)
  27. static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
  28. static_assert(get<2>(TileShape_MNK{}) % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
  29. static constexpr int kHeadDim = get<2>(TileShape_MNK{});
  30. static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, NumEpilogueThreads);
  31. static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow");
  32. using GmemLayoutAtom = Layout<Shape <Int<NumEpilogueThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
  33. Stride<Int<kGmemThreadsPerRow>, _1>>;
  34. using GmemTiledCopydKV = decltype(
  35. make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
  36. GmemLayoutAtom{},
  37. Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per store
  38. using SmemLayoutAtomdKVTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  39. // TODO: do we have to change this if dKV_swapAB is true?
  40. decltype(cute::get<1>(TileShape_MNK{})), Int<CUTE_STATIC_V(cute::get<2>(TileShape_MNK{})) / AtomLayoutKdKV>>());
  41. using SmemLayoutdKVTMA = decltype(tile_to_shape(SmemLayoutAtomdKVTMA{}, select<1, 2>(TileShape_MNK{})));
  42. using SmemLayoutdKVtTMA =
  43. decltype(cute::composition(SmemLayoutdKVTMA{},
  44. make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
  45. make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{}))));
  46. // If we don't use TMA
  47. static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : (kHeadDim % 32 == 0 ? 32 : 16);
  48. static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 2 : 1);
  49. using SmemLayoutAtomdKVSTG =
  50. decltype(composition(Swizzle<kSwizzle, 3, 3>{},
  51. Layout<Shape<Int<8>, Int<kBlockKSmem>>,
  52. Stride<Int<kBlockKSmem>, _1>>{}));
  53. using SmemLayoutAtomdKV = std::conditional_t<Use_TMA, SmemLayoutAtomdKVTMA, SmemLayoutAtomdKVSTG>;
  54. using SmemLayoutdKV = decltype(tile_to_shape(SmemLayoutAtomdKV{}, select<1, 2>(TileShape_MNK{})));
  55. using SmemLayoutdKVt =
  56. decltype(cute::composition(SmemLayoutdKV{},
  57. make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
  58. make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{}))));
  59. using SmemCopyAtomdKV = Copy_Atom<
  60. std::conditional_t<
  61. ArchTag::kMinComputeCapability >= 90,
  62. std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
  63. AutoVectorizingCopyWithAssumedAlignment<128>
  64. >,
  65. Element>;
  66. static constexpr size_t SmemAlignmentdKV = ArchTag::kMinComputeCapability >= 90 ? cutlass::detail::alignment_for_swizzle(SmemLayoutdKV{}) : 128;
  67. static_assert(SmemAlignmentdKV >= 128, "Require at least 128B alignment");
  68. struct TensorStorage : cute::aligned_struct<SmemAlignmentdKV> {
  69. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdKV>, SmemAlignmentdKV> smem_dk;
  70. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdKV>, SmemAlignmentdKV> smem_dv;
  71. };
  72. using ShapedKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_k, d, head, batch)
  73. using StridedKV = cute::Stride<int64_t, _1, int64_t, int64_t>;
  74. using TMA_dKV = std::conditional_t<
  75. Use_TMA,
  76. decltype(make_tma_copy(
  77. GmemTiledCopydKVTMA{},
  78. make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapedKV{}, StridedKV{}),
  79. SmemLayoutdKVTMA{},
  80. select<1, 2>(TileShape_MNK{}),
  81. _1{})), // no mcast for dKV
  82. std::nullptr_t
  83. >;
  84. // Host side kernel arguments
  85. struct Arguments {
  86. Element* ptr_dK;
  87. ShapedKV const shape_dK;
  88. StridedKV const stride_dK;
  89. Element* ptr_dV;
  90. StridedKV const stride_dV;
  91. int const num_heads_q;
  92. int* dk_semaphore;
  93. int* dv_semaphore;
  94. int const* cu_seqlens;
  95. int const* seqused;
  96. };
  97. // Device side kernel params
  98. struct Params {
  99. Element* ptr_dK;
  100. ShapedKV const shape_dK;
  101. StridedKV const stride_dK;
  102. Element* ptr_dV;
  103. StridedKV const stride_dV;
  104. TMA_dKV tma_store_dK, tma_store_dV;
  105. int const* cu_seqlens = nullptr;
  106. int const* seqused = nullptr;
  107. };
  108. static Params
  109. to_underlying_arguments(Arguments const& args) {
  110. Tensor mdK = make_tensor(make_gmem_ptr(args.ptr_dK), args.shape_dK, args.stride_dK);
  111. Tensor mdV = make_tensor(make_gmem_ptr(args.ptr_dV), args.shape_dK, args.stride_dV);
  112. TMA_dKV tma_store_dK = [&] {
  113. if constexpr (Use_TMA) {
  114. return make_tma_copy(GmemTiledCopydKVTMA{}, mdK, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV
  115. } else {
  116. return nullptr;
  117. }
  118. }();
  119. TMA_dKV tma_store_dV = [&] {
  120. if constexpr (Use_TMA) {
  121. return make_tma_copy(GmemTiledCopydKVTMA{}, mdV, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV
  122. } else {
  123. return nullptr;
  124. }
  125. }();
  126. return {args.ptr_dK, args.shape_dK, args.stride_dK, args.ptr_dV, args.stride_dV,
  127. tma_store_dK, tma_store_dV, args.cu_seqlens, args.seqused};
  128. }
  129. /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
  130. CUTLASS_DEVICE
  131. static void prefetch_tma_descriptors(Params const& params) {
  132. if constexpr (Use_TMA) {
  133. cute::prefetch_tma_descriptor(params.tma_store_dK.get_tma_descriptor());
  134. cute::prefetch_tma_descriptor(params.tma_store_dV.get_tma_descriptor());
  135. }
  136. }
  137. template <typename SharedStorage, typename FrgTensorO, typename TiledMma>
  138. CUTLASS_DEVICE void
  139. store(Params const& params,
  140. FrgTensorO const& tdKrdK,
  141. FrgTensorO const& tdVrdV,
  142. SharedStorage& shared_storage,
  143. TiledMma tiled_mma,
  144. int thread_idx,
  145. cute::tuple<int32_t, int32_t, int32_t> const& block_coord
  146. ) {
  147. auto [n_block, bidh, bidb] = block_coord;
  148. Tensor sdK = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKV{}));
  149. Tensor sdV = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKV{}));
  150. Tensor sdKt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKVt{}));
  151. Tensor sdVt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKVt{}));
  152. auto smem_tiled_copy_dKV = make_tiled_copy_C(SmemCopyAtomdKV{}, tiled_mma);
  153. auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(thread_idx);
  154. Tensor tdVrdV_out = make_tensor_like<Element>(tdVrdV);
  155. flash::convert_type_out(tdVrdV, tdVrdV_out);
  156. Tensor tdKrdK_out = make_tensor_like<Element>(tdKrdK);
  157. flash::convert_type_out(tdKrdK, tdKrdK_out);
  158. Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(tdKrdK_out); // ((Atom,AtomNum), MMA_M, MMA_N)
  159. Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(tdVrdV_out); // ((Atom,AtomNum), MMA_M, MMA_N)
  160. // if (blockIdx.x == 0 && threadIdx.x == 128) { print(smem_thr_copy_dKV); print(sdK); printf("\n"); print(sdKt); printf("\n"); }
  161. Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(cute::conditional_return<!dKV_swapAB>(sdK, sdKt)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
  162. Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(cute::conditional_return<!dKV_swapAB>(sdV, sdVt)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
  163. // Make sure all WGs have finished reading K and V
  164. cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  165. cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
  166. cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
  167. if constexpr (Use_TMA) {
  168. cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
  169. cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
  170. cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  171. Tensor mdK = params.tma_store_dK.get_tma_tensor(params.shape_dK);
  172. Tensor mdV = params.tma_store_dV.get_tma_tensor(params.shape_dK);
  173. Tensor gdK = local_tile(mdK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
  174. Tensor gdV = local_tile(mdV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
  175. auto block_tma_dK = params.tma_store_dK.get_slice(_0{});
  176. auto block_tma_dV = params.tma_store_dV.get_slice(_0{});
  177. Tensor tdKgdK = block_tma_dK.partition_D(gdK); // (TMA, TMA_M, TMA_K)
  178. Tensor tdKsdK = block_tma_dK.partition_S(sdK); // (TMA, TMA_M, TMA_K)
  179. Tensor tdVgdV = block_tma_dV.partition_D(gdV); // (TMA, TMA_M, TMA_K)
  180. Tensor tdVsdV = block_tma_dV.partition_S(sdV); // (TMA, TMA_M, TMA_K)
  181. int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0);
  182. if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) {
  183. cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
  184. cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  185. if (cute::elect_one_sync()) {
  186. cute::copy(params.tma_store_dV, tdVsdV, tdVgdV);
  187. cute::copy(params.tma_store_dK, tdKsdK, tdKgdK);
  188. tma_store_arrive();
  189. }
  190. }
  191. tma_store_wait<0>();
  192. // // Tell warp 0 that smem_k and smem_v are ready
  193. // cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::KVEmpty) /*id*/);
  194. } else {
  195. cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  196. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  197. flash::SeqlenInfo<Varlen, kBlockN> seqlen_info{bidb, size<0>(params.shape_dK), params.cu_seqlens, params.seqused};
  198. bool const is_varlen = Varlen && params.cu_seqlens;
  199. Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);
  200. Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
  201. Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dK, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0);
  202. Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
  203. GmemTiledCopydKV gmem_tiled_copy_dKV;
  204. auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx);
  205. Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV);
  206. Tensor tdKVsdV = gmem_thr_copy_dKV.partition_S(sdV); // (TMA, TMA_M, TMA_K)
  207. Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK);
  208. Tensor tdKVsdK = gmem_thr_copy_dKV.partition_S(sdK); // (TMA, TMA_M, TMA_K)
  209. Tensor tdKVrdV = make_fragment_like(tdKVgdV);
  210. Tensor tdKVrdK = make_fragment_like(tdKVgdK);
  211. Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k)
  212. // Repeat the partitioning with identity layouts
  213. Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
  214. Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKVgdV)));
  215. #pragma unroll
  216. for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); }
  217. // Need to check OOB when reading from smem if kBlockN isn't evenly tiled
  218. static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(size<0>(GmemLayoutAtom{})) == 0;
  219. flash::copy</*Is_even_MN=*/EvenN, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false>(
  220. gmem_tiled_copy_dKV, tdKVsdV, tdKVrdV, tdKVcdKV, tdKVpdKV, kBlockN);
  221. flash::copy</*Is_even_MN=*/EvenN, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false>(
  222. gmem_tiled_copy_dKV, tdKVsdK, tdKVrdK, tdKVcdKV, tdKVpdKV, kBlockN);
  223. // // Tell warp 0 that smem_k and smem_v are ready
  224. // cutlass::arch::fence_view_async_shared(); // ensure smem reads are done before next TMA to smem_k/v
  225. // cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::KVEmpty) /*id*/);
  226. // Construct identity layout for gdKV
  227. // Clear_OOB_K must be false since we don't want to write zeros to gmem
  228. flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
  229. gmem_tiled_copy_dKV, tdKVrdV, tdKVgdV, tdKVcdKV, tdKVpdKV, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN)
  230. );
  231. flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
  232. gmem_tiled_copy_dKV, tdKVrdK, tdKVgdK, tdKVcdKV, tdKVpdKV, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN)
  233. );
  234. }
  235. }
  236. CUTLASS_DEVICE void
  237. store_tail() {
  238. // if constexpr (Use_TMA) { tma_store_wait<0>(); }
  239. }
  240. // Write 0 to dK and dV
  241. CUTLASS_DEVICE void
  242. store_zero(
  243. Params const& params,
  244. int thread_idx,
  245. cute::tuple<int32_t, int32_t, int32_t> const& block_coord
  246. ) {
  247. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  248. auto [n_block, bidh, bidb] = block_coord;
  249. flash::SeqlenInfo<Varlen, kBlockN> seqlen_info{bidb, size<0>(params.shape_dK), params.cu_seqlens, params.seqused};
  250. bool const is_varlen = Varlen && params.cu_seqlens;
  251. Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);
  252. Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
  253. Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dK, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0);
  254. Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
  255. GmemTiledCopydKV gmem_tiled_copy_dKV;
  256. auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx);
  257. Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK);
  258. Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV);
  259. Tensor tdKVrdKV = make_fragment_like(tdKVgdK);
  260. clear(tdKVrdKV);
  261. // Construct identity layout for gdKV
  262. Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k)
  263. // Repeat the partitioning with identity layouts
  264. Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
  265. Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKVgdK)));
  266. #pragma unroll
  267. for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); }
  268. // Clear_OOB_K must be false since we don't want to write zeros to gmem
  269. flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
  270. gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdK, tdKVcdKV, tdKVpdKV, seqlen_info.seqlen - n_block * kBlockN
  271. );
  272. flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
  273. gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdV, tdKVcdKV, tdKVpdKV, seqlen_info.seqlen - n_block * kBlockN
  274. );
  275. }
  276. };
  277. template <class TileShape_MNK_, class ElementAccum, class ArchTag_,
  278. int NumEpilogueThreads_, bool Varlen_, bool Deterministic>
  279. struct CollectiveEpilogueBwdGQA {
  280. using TileShape_MNK = TileShape_MNK_;
  281. using Element = ElementAccum;
  282. using ArchTag = ArchTag_;
  283. static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
  284. static constexpr bool Varlen = Varlen_;
  285. static constexpr bool Use_TMA = ArchTag::kMinComputeCapability >= 90;
  286. static_assert(ArchTag::kMinComputeCapability >= 80);
  287. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  288. static constexpr int kHeadDim = get<2>(TileShape_MNK{});
  289. static_assert(NumEpilogueThreads % cutlass::NumThreadsPerWarp == 0, "NumEpilogueThreads must be a multiple of NumThreadsPerWarp");
  290. static constexpr int NumWarpGroups = NumEpilogueThreads / cutlass::NumThreadsPerWarpGroup;
  291. // Thread layout, 256 or 384 threads per row
  292. // We split into NumWarpGroups so that we can use the same postprocessing kernel as dQ
  293. using R2SLayoutAtomdKVaccum = Layout<Shape<Int<cutlass::NumThreadsPerWarpGroup>, Int<NumWarpGroups>>>;
  294. using R2STiledCopydKVaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2SLayoutAtomdKVaccum{},
  295. Layout<Shape < _4>>{})); // Val layout, 4 vals per store
  296. // For Sm80
  297. using R2GLayoutAtomdKVaccum = Layout<Shape<Int<NumEpilogueThreads>>>;
  298. using R2GTiledCopydKVaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2GLayoutAtomdKVaccum{},
  299. Layout<Shape < _1>>{})); // Val layout, 1 vals per store
  300. using SmemLayoutdKVaccum = Layout<Shape<Int<kBlockN * kHeadDim / NumWarpGroups>, Int<NumWarpGroups>>>;
  301. using SmemLayoutdKVaccumFlat = Layout<Shape<Int<kBlockN * kHeadDim>>>;
  302. // Strangely without this SmemAlignment, the total smem for hdim 128 (80 x 128) is 228KB even though we
  303. // only need 227KB. We use the same alignment as the non-GQA epilogue to avoid this issue.
  304. static constexpr int SmemAlignment = kHeadDim % 64 == 0 ? 1024 : (kHeadDim % 32 == 0 ? 512 : 256);
  305. struct TensorStorageTMA : cute::aligned_struct<SmemAlignment> {
  306. cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdKVaccum>, SmemAlignment> smem_dkv;
  307. };
  308. struct TensorStorageSTG {
  309. cute::array<ElementAccum, 0> smem_dkv;
  310. };
  311. using TensorStorage = std::conditional_t<Use_TMA, TensorStorageTMA, TensorStorageSTG>;
  312. using ShapedKV = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_k_rounded * d, head, batch)
  313. using StridedKV = cute::Stride<_1, int64_t, int64_t>;
  314. // Host side kernel arguments
  315. struct Arguments {
  316. ElementAccum* ptr_dKaccum;
  317. ShapedKV const shape_dKaccum;
  318. StridedKV const stride_dKaccum;
  319. ElementAccum* ptr_dVaccum;
  320. StridedKV const stride_dVaccum;
  321. int num_heads_q;
  322. int* dk_semaphore;
  323. int* dv_semaphore;
  324. int const* cu_seqlens;
  325. int const* seqused;
  326. };
  327. // Device side kernel params
  328. struct Params {
  329. ElementAccum* ptr_dKaccum;
  330. ShapedKV const shape_dKaccum;
  331. StridedKV const stride_dKaccum;
  332. ElementAccum* ptr_dVaccum;
  333. StridedKV const stride_dVaccum;
  334. cutlass::FastDivmod qhead_per_khead_divmod;
  335. int* dk_semaphore;
  336. int* dv_semaphore;
  337. int const* cu_seqlens = nullptr;
  338. int const* seqused = nullptr;
  339. };
  340. static Params
  341. to_underlying_arguments(Arguments const& args) {
  342. if constexpr (Deterministic) {
  343. assert(args.dk_semaphore != nullptr);
  344. assert(args.dv_semaphore != nullptr);
  345. }
  346. return {args.ptr_dKaccum, args.shape_dKaccum, args.stride_dKaccum, args.ptr_dVaccum, args.stride_dVaccum,
  347. cutlass::FastDivmod(cute::ceil_div(args.num_heads_q, get<1>(args.shape_dKaccum))),
  348. args.dk_semaphore, args.dv_semaphore,
  349. args.cu_seqlens, args.seqused};
  350. }
  351. /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
  352. CUTLASS_DEVICE
  353. static void prefetch_tma_descriptors(Params const& params) {
  354. }
  355. template <typename SharedStorage, typename FrgTensorO, typename TiledMma>
  356. CUTLASS_DEVICE void
  357. store(Params const& params,
  358. FrgTensorO const& tdKrdK,
  359. FrgTensorO const& tdVrdV,
  360. SharedStorage& shared_storage,
  361. TiledMma tiled_mma,
  362. int thread_idx,
  363. cute::tuple<int32_t, int32_t, int32_t> const& block_coord
  364. ) {
  365. auto [n_block, bidh, bidb] = block_coord;
  366. int bidh_idx_in_group;
  367. int bidh_kv = params.qhead_per_khead_divmod.divmod(bidh_idx_in_group, bidh);
  368. Tensor sdKV = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccum{});
  369. Tensor sdKV_flat = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccumFlat{});
  370. static constexpr int dKV_TMA_num_bytes = CUTE_STATIC_V(size(sdKV_flat)) * sizeof(ElementAccum);
  371. flash::SeqlenInfo<Varlen, kBlockN> seqlen_info{bidb, size<0>(params.shape_dKaccum), params.cu_seqlens, params.seqused};
  372. bool const is_varlen = Varlen && params.cu_seqlens;
  373. Tensor mdKaccum = make_tensor(make_gmem_ptr(params.ptr_dKaccum), params.shape_dKaccum, params.stride_dKaccum)(_, bidh_kv, !is_varlen ? bidb : 0);
  374. Tensor mdVaccum = make_tensor(make_gmem_ptr(params.ptr_dVaccum), params.shape_dKaccum, params.stride_dVaccum)(_, bidh_kv, !is_varlen ? bidb : 0);
  375. Tensor gdKaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdKaccum), Shape<Int<kBlockN * kHeadDim>>{}, make_coord(n_block)); // (M * K)
  376. Tensor gdVaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdVaccum), Shape<Int<kBlockN * kHeadDim>>{}, make_coord(n_block)); // (M * K)
  377. R2STiledCopydKVaccum r2s_tiled_copy_dKVaccum;
  378. auto r2s_thr_copy_dKVaccum = r2s_tiled_copy_dKVaccum.get_thread_slice(thread_idx);
  379. Tensor tdKVsdKVaccum = r2s_thr_copy_dKVaccum.partition_D(sdKV);
  380. // Only used if !Use_TMA
  381. R2GTiledCopydKVaccum r2g_tiled_copy_dKVaccum;
  382. auto r2g_thr_copy_dKVaccum = r2g_tiled_copy_dKVaccum.get_thread_slice(thread_idx);
  383. // Make sure all WGs have finished reading K and V, otherwise we get racy dQ
  384. // because smem_q could be changed.
  385. cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  386. if constexpr (Use_TMA) {
  387. Tensor taccdKVrdV = r2s_thr_copy_dKVaccum.retile_S(tdVrdV); // ((Atom,AtomNum), MMA_M, MMA_N)
  388. cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdV, tdKVsdKVaccum);
  389. }
  390. // int const num_batch = params.num_batch;
  391. int const num_batch = get<2>(params.shape_dKaccum);
  392. int const num_head_kv = get<1>(params.shape_dKaccum);
  393. int *lock_ptr = !Deterministic ? nullptr : params.dv_semaphore + bidb * num_head_kv + bidh_kv;
  394. using Barrier = cutlass::GenericBarrier<cutlass::detail::SyncwarpSync>;
  395. // if (thread_idx == 0) { printf("blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dv_semaphore = %p, num_batch = %d, num_head_kv = %d, n_block = %d, bihd_idx_in_group = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dv_semaphore, num_batch, num_head_kv, n_block, bidh_idx_in_group);}
  396. if constexpr (Deterministic) {
  397. Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group);
  398. }
  399. // if (thread_idx == 0) { printf("After barrier blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dv_semaphore = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dv_semaphore);}
  400. if constexpr (Use_TMA) {
  401. cutlass::arch::fence_view_async_shared();
  402. cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  403. if (thread_idx == 0) {
  404. SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdKV_flat.data()), raw_pointer_cast(gdVaccum.data()), dKV_TMA_num_bytes, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_LAST));
  405. tma_store_arrive();
  406. tma_store_wait<0>();
  407. }
  408. } else {
  409. Tensor tdVrdV_atomic = r2g_thr_copy_dKVaccum.retile_S(tdVrdV);
  410. Tensor tdVgdV_atomic = r2g_thr_copy_dKVaccum.partition_D(gdVaccum);
  411. static_assert(CUTE_STATIC_V(size(tdVrdV_atomic)) == CUTE_STATIC_V(size(tdVgdV_atomic)));
  412. #pragma unroll
  413. for (int i = 0; i < size(tdVrdV_atomic); ++i) { atomicAdd(&tdVgdV_atomic(i), tdVrdV_atomic(i)); }
  414. }
  415. if constexpr (Deterministic) {
  416. Barrier::arrive_inc(lock_ptr, thread_idx, n_block * num_batch * num_head_kv);
  417. }
  418. if constexpr (Use_TMA) {
  419. cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  420. Tensor taccdKVrdK = r2s_thr_copy_dKVaccum.retile_S(tdKrdK); // ((Atom,AtomNum), MMA_M, MMA_N)
  421. cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdK, tdKVsdKVaccum);
  422. }
  423. lock_ptr = !Deterministic ? nullptr : params.dk_semaphore + bidb * num_head_kv + bidh_kv;
  424. // if (thread_idx == 0) { printf("blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dk_semaphore = %p, num_batch = %d, num_head_kv = %d, n_block = %d, bihd_idx_in_group = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dk_semaphore, num_batch, num_head_kv, n_block, bidh_idx_in_group);}
  425. if constexpr (Deterministic) {
  426. Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group);
  427. }
  428. // if (thread_idx == 0) { printf("After barrier blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dk_semaphore = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dk_semaphore);}
  429. if constexpr (Use_TMA) {
  430. cutlass::arch::fence_view_async_shared();
  431. cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  432. if (thread_idx == 0) {
  433. SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdKV_flat.data()), raw_pointer_cast(gdKaccum.data()), dKV_TMA_num_bytes, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_LAST));
  434. tma_store_arrive();
  435. tma_store_wait<0>();
  436. }
  437. } else {
  438. Tensor tdKrdK_atomic = r2g_thr_copy_dKVaccum.retile_S(tdKrdK);
  439. Tensor tdKgdK_atomic = r2g_thr_copy_dKVaccum.partition_D(gdKaccum);
  440. static_assert(CUTE_STATIC_V(size(tdKrdK_atomic)) == CUTE_STATIC_V(size(tdKgdK_atomic)));
  441. #pragma unroll
  442. for (int i = 0; i < size(tdKrdK_atomic); ++i) { atomicAdd(&tdKgdK_atomic(i), tdKrdK_atomic(i)); }
  443. }
  444. if constexpr (Deterministic) {
  445. Barrier::arrive_inc(lock_ptr, thread_idx, n_block * num_batch * num_head_kv);
  446. }
  447. // // Tell warp 0 that smem_k and smem_v are ready
  448. // cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::KVEmpty) /*id*/);
  449. }
  450. CUTLASS_DEVICE void
  451. store_tail() {
  452. }
  453. // Write 0 to dK and dV
  454. CUTLASS_DEVICE void
  455. store_zero(
  456. Params const& params,
  457. int thread_idx,
  458. cute::tuple<int32_t, int32_t, int32_t> const& block_coord
  459. ) {
  460. // Don't need to do anything since dKaccum and dVaccum are already zero-initialized
  461. }
  462. };
  463. } // namespace flash