epilogue_fwd_sm90_tma.hpp

/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cutlass/cutlass.h>
#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"

#include "named_barrier.hpp"
#include "utils.h"

namespace flash {

using namespace cute;

// template <int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename Element_>
template <typename Ktraits, typename Seqlen_traits>
struct CollectiveEpilogueFwd {

    using Element = typename Ktraits::OutputType;
    static constexpr int kBlockM = Ktraits::kBlockM;
    static constexpr int kBlockN = Ktraits::kBlockN;
    static constexpr int kHeadDim = Ktraits::kHeadDim;
    using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;

    static constexpr int kNWarps = Ktraits::kNWarps;
    static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
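
    // Added commentary: with 12 or more warps the kernel is presumed to run warp-specialized,
    // dedicating one full warp group (cutlass::NumThreadsPerWarpGroup = 128 threads) to
    // producer/copy work. Only the remaining warps hold MMA results, so the epilogue below
    // synchronizes and writes O with NumMmaThreads rather than kNThreads.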
    static constexpr bool Is_WS = kNWarps >= 12;

    static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup;
    static constexpr int NumMmaThreads = kNThreads - NumCopyThreads;

    using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));

    using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, Element>;
    using SharedStorage = cute::array_aligned<Element, cute::cosize_v<SmemLayoutO>>;
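
    // TMA store object for O. The type is deduced here from a tensor built on a null gmem
    // pointer; the real TMA descriptor (with the actual O pointer and layout) is constructed
    // in to_underlying_arguments() below.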
    using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;
    using TMA_O = decltype(make_tma_copy(
        GmemTiledCopyOTMA{},
        make_tensor(
            make_gmem_ptr(static_cast<Element*>(nullptr)),
            typename Seqlen_traits::ShapeT{},
            typename Seqlen_traits::StrideT{}
        ),
        SmemLayoutO{},
        select<0, 2>(TileShape_MNK{}),
        _1{}));  // no mcast for O
    // These are for storing the output tensor without TMA (e.g., for setting output to zero and var-seq-len)
    static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<Element>);
    static_assert(kHeadDim % kNumVecElem == 0);
    static constexpr int kNumThreadsPerRow = kHeadDim / kNumVecElem;
    static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
    static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
    using TiledCopyOAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, Element>;
    using TiledCopyOThrLayout = decltype(cute::make_layout(
        cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
        LayoutRight{}));
    using TiledCopyOValLayout = decltype(cute::make_layout(
        cute::make_shape(_1{}, Int<kNumVecElem>{}),
        LayoutRight{}));
    using TiledCopyO = decltype(make_tiled_copy(
        TiledCopyOAtom{},
        TiledCopyOThrLayout{},  // Thr layout
        TiledCopyOValLayout{}   // Val layout
    ));
    // used for rmem -> smem O copy in fp8 kernel to undo column permutation
    using ThreadLayoutrO = Layout<Shape<_8, Int<kBlockM/16>, _4, _1>,
                                  Stride<_4, _32, _1, _0>>;
    using ValueLayoutrO = Layout<Shape<_1, _2, Shape<_2, _2>, Int<kHeadDim/16>>,
                                 Stride<_0, _2, Stride<_4, _1>, _8>>;
    using TiledCopyrO = decltype(make_tiled_copy(Copy_Atom<UniversalCopy<uint16_t>, Element>{},
                                                 ThreadLayoutrO{}, ValueLayoutrO{}));
    using TiledCopyShaperO = Shape<_8, Int<kBlockM/8>, _16, Int<kHeadDim/16>>;
    using SmemLayoutrO = decltype(composition(SmemLayoutO{}, Layout<TiledCopyShaperO>{}));
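    // Added commentary: these thread/value layouts view smem O in the element order produced by
    // the fp8 MMA accumulator, so the plain uint16_t register->smem copy in store_fp8() lands
    // each value back in its unpermuted column position (see the "undo column permutation"
    // note above).
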
    // Host side kernel arguments
    struct Arguments {
        Element* ptr_O;
        typename Seqlen_traits::LayoutT const layout_O;
        float* ptr_LSE;
        typename Seqlen_traits::LayoutLseT const layout_LSE;
    };

    // Device side kernel params
    struct Params {
        Element* ptr_O;
        typename Seqlen_traits::LayoutT const layout_O;
        float* ptr_LSE;
        typename Seqlen_traits::LayoutLseT const layout_LSE;
        TMA_O tma_store_O;
    };
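
    // Convert host-side Arguments into device-side Params, building the real TMA store
    // descriptor from the user-provided O pointer and layout.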
    static Params
    to_underlying_arguments(Arguments const& args) {
        Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.layout_O);
        TMA_O tma_store_O = make_tma_copy(
            GmemTiledCopyOTMA{},
            mO,
            SmemLayoutO{},
            select<0, 2>(TileShape_MNK{}),
            _1{});  // no mcast for O
        return {args.ptr_O, args.layout_O, args.ptr_LSE, args.layout_LSE, tma_store_O};
    }
    /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
    CUTLASS_DEVICE
    static void prefetch_tma_descriptors(Params const& epilogue_params) {
        if constexpr (!Seqlen_traits::kUseVarSeqLen) {
            cute::prefetch_tma_descriptor(epilogue_params.tma_store_O.get_tma_descriptor());
        }
    }
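
    // Added commentary: epilogue for the non-fp8 path. Convert the accumulator to the output
    // type, stage it in shared memory, write the per-row LSE values to global memory, then
    // write O via flash::write_O (TMA store for fixed sequence lengths, the vectorized
    // TiledCopyO path for variable sequence lengths).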
    template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE, typename TiledMma>
    CUTLASS_DEVICE void
    store(Params const& epilogue_params,
          FrgTensorO const& tOrO,
          FrgTensorLSE const& lse,
          SharedStorage& shared_storage,
          TiledMma tiled_mma,
          int thread_idx,
          cute::tuple<int32_t, int32_t, int32_t> const& block_coord,
          const Seqlen_traits& seqlen_traits_q
          ) {

        auto [m_block, bidh, bidb] = block_coord;
        Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
        auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
        auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);

        Tensor tOrO_out = flash::convert_type<Element>(tOrO);
        Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out);  // ((Atom,AtomNum), MMA_M, MMA_N)
        Tensor taccOsO = smem_thr_copy_O.partition_D(sO);     // ((Atom,AtomNum), PIPE_M, PIPE_N)

        // Make sure all WGs have finished reading V
        cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(FwdNamedBarriers::ValueEmpty) /*id*/);
        cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
        cutlass::arch::fence_view_async_shared();  // ensure smem writes are visible to TMA
        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp,
                                            cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
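
        // Write LSE to gmem: the threads owning the first column of the accumulator tile write
        // one LSE value per MMA row, skipping rows past the end of the actual sequence length.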
        Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), epilogue_params.layout_LSE);
        Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor(
            mLSE, Shape<Int<kBlockM>>{}, bidh, bidb)(_, m_block);
        Tensor caccO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}));
        auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
        Tensor taccOcO = thread_mma.partition_C(caccO);  // (MMA, MMA_M, MMA_K)
        static_assert(decltype(size<0, 0>(taccOcO))::value == 2);
        static_assert(decltype(size<0, 1>(taccOcO))::value == 2);
        // taccOcO has shape ((2, 2, V), MMA_M, MMA_K); we only take the row indices.
        Tensor taccOcO_row = taccOcO(make_coord(_0{}, _, _0{}), _, _0{});
        CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));  // MMA_M
        if (get<1>(taccOcO_row(_0{})) == 0) {
            #pragma unroll
            for (int mi = 0; mi < size(lse); ++mi) {
                const int row = get<0>(taccOcO_row(mi));
                if (row < seqlen_traits_q.actual_seq_len - m_block * kBlockM) { gLSE(row) = lse(mi); }
            }
        }
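
        // Only the last warp blocks on the epilogue barrier here, so all MMA threads have
        // finished writing O to shared memory before flash::write_O issues the store.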
        int write_warp_idx = kNWarps - 1;
        if (cutlass::canonical_warp_idx_sync() == write_warp_idx) {
            cutlass::arch::NamedBarrier::sync(
                NumMmaThreads + cutlass::NumThreadsPerWarp,
                cutlass::arch::ReservedNamedBarriers::EpilogueBarrier
            );
        }
        TiledCopyO gmem_tiled_copy_O;
        flash::write_O<!Seqlen_traits::kUseVarSeqLen, NumCopyThreads>(
            epilogue_params.ptr_O, epilogue_params.tma_store_O, gmem_tiled_copy_O,
            epilogue_params.layout_O, select<0, 2>(TileShape_MNK{}), sO,
            m_block, bidh, bidb, seqlen_traits_q, write_warp_idx
        );
    }
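
    // Added commentary: epilogue for the fp8 path. Same overall flow as store(), except the
    // accumulator is copied to shared memory through TiledCopyrO / SmemLayoutrO to undo the
    // column permutation of the fp8 MMA before the O tile is written out.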
    template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE, typename TiledMma>
    CUTLASS_DEVICE void
    store_fp8(Params const& epilogue_params,
              FrgTensorO const& tOrO,
              FrgTensorLSE const& lse,
              SharedStorage& shared_storage,
              TiledMma tiled_mma,
              int thread_idx,
              cute::tuple<int32_t, int32_t, int32_t> const& block_coord,
              const Seqlen_traits& seqlen_traits_q
              ) {
        // using SmemLayoutrO = typename Ktraits::SmemLayoutrO;
        // using TiledCopyrO = typename Ktraits::TiledCopyrO;
        auto [m_block, bidh, bidb] = block_coord;

        TiledCopyrO rmem_tiled_copy_O;
        Tensor sOacc = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutrO{});
        auto rmem_thr_copy_O = rmem_tiled_copy_O.get_thread_slice(thread_idx);

        Tensor taccOsO = rmem_thr_copy_O.partition_D(sOacc);
        Tensor tOrO_out = flash::convert_type<Element>(tOrO);  // Element is Ktraits::OutputType
        Tensor taccOrO = make_tensor(tOrO_out.data(), shape(taccOsO));

        // Make sure all WGs have finished reading V
        cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(FwdNamedBarriers::ValueEmpty) /*id*/);
        cute::copy(rmem_tiled_copy_O, taccOrO, taccOsO);
        cutlass::arch::fence_view_async_shared();  // ensure smem writes are visible to TMA
        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp,
                                            cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
        Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), epilogue_params.layout_LSE);
        Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor(
            mLSE, Shape<Int<kBlockM>>{}, bidh, bidb)(_, m_block);
        Tensor caccO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}));
        auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
        Tensor taccOcO = thread_mma.partition_C(caccO);  // (MMA, MMA_M, MMA_K)
        static_assert(decltype(size<0, 0>(taccOcO))::value == 2);
        static_assert(decltype(size<0, 1>(taccOcO))::value == 2);
        // taccOcO has shape ((2, 2, V), MMA_M, MMA_K); we only take the row indices.
        Tensor taccOcO_row = taccOcO(make_coord(_0{}, _, _0{}), _, _0{});
        CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));  // MMA_M
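
        // For variable sequence lengths use the per-batch actual length; otherwise take the
        // fixed sequence length from the LSE layout.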
        int const seqlen_q = [&] {
            if constexpr (Seqlen_traits::kUseVarSeqLen) { return seqlen_traits_q.actual_seq_len; }
            else { return shape<2>(epilogue_params.layout_LSE); }
        }();
        if (get<1>(taccOcO_row(_0{})) == 0) {
            #pragma unroll
            for (int mi = 0; mi < size(lse); ++mi) {
                const int row = get<0>(taccOcO_row(mi));
                if (row < seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); }
            }
        }
        int write_warp_idx = kNWarps - 1;
        if (cutlass::canonical_warp_idx_sync() == write_warp_idx) {
            cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp,
                                              cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
        }
        TiledCopyO gmem_tiled_copy_O;
        Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
        flash::write_O<!Seqlen_traits::kUseVarSeqLen, NumCopyThreads>(
            epilogue_params.ptr_O, epilogue_params.tma_store_O, gmem_tiled_copy_O,
            epilogue_params.layout_O, select<0, 2>(TileShape_MNK{}), sO,
            m_block, bidh, bidb, seqlen_traits_q, write_warp_idx
        );
    }
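
    // Wait for all outstanding TMA store operations to complete before the kernel exits.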
    CUTLASS_DEVICE void
    store_tail() {
        tma_store_wait<0>();
    }
    // Write 0 to output and -inf to LSE
    template <typename SharedStorage>
    CUTLASS_DEVICE void
    store_zero(
        Params const& epilogue_params,
        SharedStorage& shared_storage,
        int thread_idx,
        cute::tuple<int32_t, int32_t, int32_t> const& block_coord,
        const Seqlen_traits& seqlen_traits_q
        ) {
        auto [m_block, bidh, bidb] = block_coord;
        Tensor mO = make_tensor(make_gmem_ptr(epilogue_params.ptr_O), epilogue_params.layout_O);
        Tensor gO = seqlen_traits_q.get_local_tile_tensor(
            mO, select<0, 2>(TileShape_MNK{}), bidh, bidb
        )(_, _, m_block);  // (M, K)
        Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), epilogue_params.layout_LSE);
        Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor(
            mLSE, Shape<Int<kBlockM>>{}, bidh, bidb)(_, m_block);

        TiledCopyO gmem_tiled_copy_O;
        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
        Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
        Tensor tOrO = make_fragment_like(tOgO);
        clear(tOrO);
        // Construct identity layout for sO
        Tensor cO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
        // Repeat the partitioning with identity layouts
        Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
        Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
        #pragma unroll
        for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(epilogue_params.layout_O.shape()); }
        // Clear_OOB_K must be false since we don't want to write zeros to gmem
        flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
            gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_traits_q.actual_seq_len - m_block * kBlockM
        );
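
        // Each of the first kBlockM threads clears one LSE entry; the static_assert below
        // guarantees there are enough MMA threads to cover every row of the block.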
        static_assert(kBlockM <= NumMmaThreads);
        if (thread_idx < seqlen_traits_q.actual_seq_len - m_block * kBlockM) { gLSE(thread_idx) = -INFINITY; }
    }

};

}  // namespace flash