// epilogue_bwd_sm90_tma.hpp
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/
  4. #pragma once
  5. #include <cutlass/cutlass.h>
  6. #include "cute/tensor.hpp"
  7. #include "cutlass/gemm/collective/collective_builder.hpp"
  8. #include "named_barrier.hpp"
  9. #include "utils.h"
  10. namespace flash {
  11. using namespace cute;
  12. template <class TileShape_MNK_, class Element_, int NumEpilogueThreads_, bool Varlen_>
  13. struct CollectiveEpilogueBwd {
  14. using TileShape_MNK = TileShape_MNK_;
  15. using Element = Element_;
  16. static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
  17. static constexpr bool Varlen = Varlen_;
  18. using GmemTiledCopydKVTMA = cute::SM90_TMA_STORE;
  19. // These are for storing the output tensor without TMA (e.g., for setting output to zero)
  20. static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
  21. static_assert(get<2>(TileShape_MNK{}) % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
  22. static constexpr int kHeadDim = get<2>(TileShape_MNK{});
  23. static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, NumEpilogueThreads);
  24. static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow");
  25. using GmemLayoutAtom = Layout<Shape <Int<NumEpilogueThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
  26. Stride<Int<kGmemThreadsPerRow>, _1>>;
  27. using GmemTiledCopydKV = decltype(
  28. make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
  29. GmemLayoutAtom{},
  30. Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per store
  31. using SmemLayoutAtomdKVTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  32. decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  33. using SmemLayoutdKVTMA = decltype(tile_to_shape(SmemLayoutAtomdKVTMA{}, select<1, 2>(TileShape_MNK{})));
  34. // If we don't use TMA
  35. static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : (kHeadDim % 32 == 0 ? 32 : 16);
  36. static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 2 : 1);
  37. using SmemLayoutAtomdKVSTG =
  38. decltype(composition(Swizzle<kSwizzle, 3, 3>{},
  39. Layout<Shape<Int<8>, Int<kBlockKSmem>>,
  40. Stride<Int<kBlockKSmem>, _1>>{}));
  41. using SmemLayoutAtomdKV = std::conditional_t<!Varlen, SmemLayoutAtomdKVTMA, SmemLayoutAtomdKVSTG>;
  42. using SmemLayoutdKV = decltype(tile_to_shape(SmemLayoutAtomdKV{}, select<1, 2>(TileShape_MNK{})));
  43. using SmemCopyAtomdKV = Copy_Atom<cute::SM90_U32x4_STSM_N, Element>;
  44. struct TensorStorage : cute::aligned_struct<128> {
  45. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdKV>> smem_dk;
  46. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdKV>> smem_dv;
  47. };
  48. using ShapedKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch)
  49. using StridedKV = cute::Stride<int64_t, _1, int64_t, int64_t>;
  50. using LayoutdKV = cute::Layout<ShapedKV, StridedKV>;
  51. using TMA_dKV = decltype(make_tma_copy(
  52. GmemTiledCopydKVTMA{},
  53. make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapedKV{}, StridedKV{}),
  54. SmemLayoutdKVTMA{},
  55. select<1, 2>(TileShape_MNK{}),
  56. _1{})); // no mcast for dKV
  57. // Host side kernel arguments
  58. struct Arguments {
  59. Element* ptr_dK;
  60. ShapedKV const shape_dK;
  61. StridedKV const stride_dK;
  62. Element* ptr_dV;
  63. StridedKV const stride_dV;
  64. int const* cu_seqlens = nullptr;
  65. int const* seqused = nullptr;
  66. };
  67. // Device side kernel params
  68. struct Params {
  69. Element* ptr_dK;
  70. ShapedKV const shape_dK;
  71. StridedKV const stride_dK;
  72. Element* ptr_dV;
  73. StridedKV const stride_dV;
  74. TMA_dKV tma_store_dK, tma_store_dV;
  75. int const* cu_seqlens = nullptr;
  76. int const* seqused = nullptr;
  77. };
  78. static Params
  79. to_underlying_arguments(Arguments const& args) {
  80. if constexpr (Varlen) {
  81. assert (args.cu_seqlens != nullptr);
  82. }
  83. Tensor mdK = make_tensor(make_gmem_ptr(args.ptr_dK), args.shape_dK, args.stride_dK);
  84. Tensor mdV = make_tensor(make_gmem_ptr(args.ptr_dV), args.shape_dK, args.stride_dV);
  85. TMA_dKV tma_store_dK = make_tma_copy(
  86. GmemTiledCopydKVTMA{},
  87. mdK,
  88. SmemLayoutdKVTMA{},
  89. select<1, 2>(TileShape_MNK{}),
  90. _1{}); // no mcast for dKV
  91. TMA_dKV tma_store_dV = make_tma_copy(
  92. GmemTiledCopydKVTMA{},
  93. mdV,
  94. SmemLayoutdKVTMA{},
  95. select<1, 2>(TileShape_MNK{}),
  96. _1{}); // no mcast for dKV
  97. return {args.ptr_dK, args.shape_dK, args.stride_dK, args.ptr_dV, args.stride_dV,
  98. tma_store_dK, tma_store_dV, args.cu_seqlens, args.seqused};
  99. }
  100. /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
  101. CUTLASS_DEVICE
  102. static void prefetch_tma_descriptors(Params const& params) {
  103. if constexpr (!Varlen) {
  104. cute::prefetch_tma_descriptor(params.tma_store_dK.get_tma_descriptor());
  105. cute::prefetch_tma_descriptor(params.tma_store_dV.get_tma_descriptor());
  106. }
  107. }
  108. template <typename SharedStorage, typename FrgTensorO, typename TiledMma>
  109. CUTLASS_DEVICE void
  110. store(Params const& params,
  111. FrgTensorO const& tdKrdK,
  112. FrgTensorO const& tdVrdV,
  113. SharedStorage& shared_storage,
  114. TiledMma tiled_mma,
  115. int thread_idx,
  116. cute::tuple<int32_t, int32_t, int32_t> const& block_coord
  117. ) {
  118. auto [n_block, bidh, bidb] = block_coord;
  119. Tensor sdK = make_tensor(make_smem_ptr(shared_storage.epilogue.smem_dk.data()), SmemLayoutdKV{});
  120. Tensor sdV = make_tensor(make_smem_ptr(shared_storage.epilogue.smem_dv.data()), SmemLayoutdKV{});
  121. auto smem_tiled_copy_dKV = make_tiled_copy_C(SmemCopyAtomdKV{}, tiled_mma);
  122. auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(thread_idx);
  123. Tensor tdVrdV_out = flash::convert_type<Element>(tdVrdV);
  124. Tensor tdKrdK_out = flash::convert_type<Element>(tdKrdK);
  125. Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(tdKrdK_out); // ((Atom,AtomNum), MMA_M, MMA_N)
  126. Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(tdVrdV_out); // ((Atom,AtomNum), MMA_M, MMA_N)
  127. Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N)
  128. Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N)
  129. // Make sure all WGs have finished reading K and V
  130. cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, static_cast<int>(BwdNamedBarriers::KVEmpty) /*id*/);
  131. cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
  132. cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
  133. if constexpr (!Varlen) {
  134. cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
  135. cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
  136. cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  137. Tensor mdK = params.tma_store_dK.get_tma_tensor(params.shape_dK);
  138. Tensor mdV = params.tma_store_dV.get_tma_tensor(params.shape_dK);
  139. Tensor gdK = local_tile(mdK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
  140. Tensor gdV = local_tile(mdV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
  141. auto block_tma_dK = params.tma_store_dK.get_slice(_0{});
  142. auto block_tma_dV = params.tma_store_dV.get_slice(_0{});
  143. Tensor tdKgdK = block_tma_dK.partition_D(gdK); // (TMA, TMA_M, TMA_K)
  144. Tensor tdKsdK = block_tma_dK.partition_S(sdK); // (TMA, TMA_M, TMA_K)
  145. Tensor tdVgdV = block_tma_dV.partition_D(gdV); // (TMA, TMA_M, TMA_K)
  146. Tensor tdVsdV = block_tma_dV.partition_S(sdV); // (TMA, TMA_M, TMA_K)
  147. int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0);
  148. if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) {
  149. cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
  150. cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  151. int const lane_predicate = cute::elect_one_sync();
  152. if (lane_predicate) {
  153. cute::copy(params.tma_store_dV, tdVsdV, tdVgdV);
  154. cute::copy(params.tma_store_dK, tdKsdK, tdKgdK);
  155. tma_store_arrive();
  156. }
  157. }
  158. } else {
  159. cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  160. bool const is_varlen = params.cu_seqlens != nullptr;
  161. int const offset = !is_varlen ? 0 : params.cu_seqlens[bidb];
  162. int const seqlen = !is_varlen ? get<0>(params.shape_dK) : (
  163. params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb]
  164. );
  165. Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);
  166. Tensor gdK = local_tile(cute::domain_offset(make_coord(offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
  167. Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dK, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0);
  168. Tensor gdV = local_tile(cute::domain_offset(make_coord(offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
  169. GmemTiledCopydKV gmem_tiled_copy_dKV;
  170. auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx);
  171. Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV);
  172. Tensor tdKVsdV = gmem_thr_copy_dKV.partition_S(sdV); // (TMA, TMA_M, TMA_K)
  173. Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK);
  174. Tensor tdKVsdK = gmem_thr_copy_dKV.partition_S(sdK); // (TMA, TMA_M, TMA_K)
  175. Tensor tdKVrdV = make_fragment_like(tdKVgdV);
  176. Tensor tdKVrdK = make_fragment_like(tdKVgdK);
  177. cute::copy(gmem_tiled_copy_dKV, tdKVsdV, tdKVrdV);
  178. cute::copy(gmem_tiled_copy_dKV, tdKVsdK, tdKVrdK);
  179. // Construct identity layout for gdKV
  180. Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k)
  181. // Repeat the partitioning with identity layouts
  182. Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
  183. Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKVgdV)));
  184. #pragma unroll
  185. for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); }
  186. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  187. // Clear_OOB_K must be false since we don't want to write zeros to gmem
  188. flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
  189. gmem_tiled_copy_dKV, tdKVrdV, tdKVgdV, tdKVcdKV, tdKVpdKV, seqlen - n_block * kBlockN
  190. );
  191. flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
  192. gmem_tiled_copy_dKV, tdKVrdK, tdKVgdK, tdKVcdKV, tdKVpdKV, seqlen - n_block * kBlockN
  193. );
  194. }
  195. }
  196. CUTLASS_DEVICE void
  197. store_tail() {
  198. if constexpr (!Varlen) { tma_store_wait<0>(); }
  199. }
  200. // Write 0 to dK and dV
  201. CUTLASS_DEVICE void
  202. store_zero(
  203. Params const& params,
  204. int thread_idx,
  205. cute::tuple<int32_t, int32_t, int32_t> const& block_coord
  206. ) {
  207. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  208. auto [n_block, bidh, bidb] = block_coord;
  209. bool const is_varlen = Varlen && params.cu_seqlens != nullptr;
  210. int const offset = !is_varlen ? 0 : params.cu_seqlens[bidb];
  211. int const seqlen = !is_varlen ? get<0>(params.shape_dK) : (params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - offset);
  212. Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);
  213. Tensor gdK = local_tile(cute::domain_offset(make_coord(offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
  214. Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dK, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0);
  215. Tensor gdV = local_tile(cute::domain_offset(make_coord(offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
  216. GmemTiledCopydKV gmem_tiled_copy_dKV;
  217. auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx);
  218. Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK);
  219. Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV);
  220. Tensor tdKVrdKV = make_fragment_like(tdKVgdK);
  221. clear(tdKVrdKV);
  222. // Construct identity layout for gdKV
  223. Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k)
  224. // Repeat the partitioning with identity layouts
  225. Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
  226. Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKVgdK)));
  227. #pragma unroll
  228. for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); }
  229. // Clear_OOB_K must be false since we don't want to write zeros to gmem
  230. flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
  231. gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdK, tdKVcdKV, tdKVpdKV, seqlen - n_block * kBlockN
  232. );
  233. flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
  234. gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdV, tdKVcdKV, tdKVpdKV, seqlen - n_block * kBlockN
  235. );
  236. }
  237. };
  238. } // namespace flash