epilogue_fwd_sm90_tma.hpp

/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cutlass/cutlass.h>
#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"

#include "named_barrier.hpp"
#include "utils.h"

namespace flash {

using namespace cute;

// template <int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename Element_>
template <typename Ktraits, typename Seqlen_traits>
struct CollectiveEpilogueFwd {

    using PrecType = typename Ktraits::Element;
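    // For FP8 (e4m3) inputs the output is written in fp16; otherwise the output keeps the input element type.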
    using Element = decltype(cute::conditional_return<is_same_v<PrecType, cutlass::float_e4m3_t>>(cutlass::half_t{}, PrecType{}));

    static constexpr int kBlockM = Ktraits::kBlockM;
    static constexpr int kBlockN = Ktraits::kBlockN;
    static constexpr int kHeadDim = Ktraits::kHeadDim;
    using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;

    static constexpr int kNWarps = Ktraits::kNWarps;
    static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
    static constexpr bool Is_WS = kNWarps >= 12;
    static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup;
    static constexpr int NumMmaThreads = kNThreads - NumCopyThreads;

    using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));

    using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, Element>;
    using SharedStorage = cute::array_aligned<Element, cute::cosize_v<SmemLayoutO>>;
    using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;
    using TMA_O = decltype(make_tma_copy(
        GmemTiledCopyOTMA{},
        make_tensor(
            make_gmem_ptr(static_cast<Element*>(nullptr)),
            typename Seqlen_traits::ShapeT{},
            typename Seqlen_traits::StrideT{}
        ),
        SmemLayoutO{},
        select<0, 2>(TileShape_MNK{}),
        _1{}));  // no mcast for O
    // These are for storing the output tensor without TMA (e.g., for setting output to zero and var-seq-len)
    static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<Element>);
    static_assert(kHeadDim % kNumVecElem == 0);
    static constexpr int kNumThreadsPerRow = kHeadDim / kNumVecElem;
    static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
    static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
    using TiledCopyOAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, Element>;
    using TiledCopyOThrLayout = decltype(cute::make_layout(
        cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
        LayoutRight{}));
    using TiledCopyOValLayout = decltype(cute::make_layout(
        cute::make_shape(_1{}, Int<kNumVecElem>{}),
        LayoutRight{}));
    using TiledCopyO = decltype(make_tiled_copy(
        TiledCopyOAtom{},
        TiledCopyOThrLayout{}, // Thr layout
        TiledCopyOValLayout{} // Val layout
    ));
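    // With this tiling, each thread copies one 128-bit vector (kNumVecElem elements) along the head dimension;
    // kNumThreadsPerRow threads cover a full row, so NumMmaThreads cover kNumRows rows per copy iteration.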
    // Host side kernel arguments
    struct Arguments {
        Element* ptr_O;
        typename Seqlen_traits::LayoutT const layout_O;
        float* ptr_LSE;
        typename Seqlen_traits::LayoutLseT const layout_LSE;
    };

    // Device side kernel params
    struct Params {
        Element* ptr_O;
        typename Seqlen_traits::LayoutT const layout_O;
        float* ptr_LSE;
        typename Seqlen_traits::LayoutLseT const layout_LSE;
        TMA_O tma_store_O;
    };
    static Params
    to_underlying_arguments(Arguments const& args) {
        Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.layout_O);
        TMA_O tma_store_O = make_tma_copy(
            GmemTiledCopyOTMA{},
            mO,
            SmemLayoutO{},
            select<0, 2>(TileShape_MNK{}),
            _1{}); // no mcast for O
        return {args.ptr_O, args.layout_O, args.ptr_LSE, args.layout_LSE, tma_store_O};
    }
    /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
    CUTLASS_DEVICE
    static void prefetch_tma_descriptors(Params const& epilogue_params) {
        if constexpr (!Seqlen_traits::kUseVarSeqLen) {
            cute::prefetch_tma_descriptor(epilogue_params.tma_store_O.get_tma_descriptor());
        }
    }
    template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE, typename TiledMma>
    CUTLASS_DEVICE void
    store(Params const& epilogue_params,
          FrgTensorO const& tOrO,
          FrgTensorLSE const& lse,
          SharedStorage& shared_storage,
          TiledMma tiled_mma,
          int thread_idx,
          cute::tuple<int32_t, int32_t, int32_t> const& block_coord,
          const Seqlen_traits& seqlen_traits_q
          ) {

        auto [m_block, bidh, bidb] = block_coord;
        Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
        auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
        auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);

        Tensor tOrO_out = flash::convert_type<Element>(tOrO);
        Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out);   // ((Atom,AtomNum), MMA_M, MMA_N)
        Tensor taccOsO = smem_thr_copy_O.partition_D(sO);      // ((Atom,AtomNum),PIPE_M,PIPE_N)

        // Make sure all WGs have finished reading V
        cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(FwdNamedBarriers::ValueEmpty) /*id*/);
        cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
        cutlass::arch::fence_view_async_shared();  // ensure smem writes are visible to TMA
        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp,
                                            cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);

        Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), epilogue_params.layout_LSE);
        Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor(
            mLSE, Shape<Int<kBlockM>>{}, bidh, bidb)(_, m_block);

        Tensor caccO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}));
        auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
        Tensor taccOcO = thread_mma.partition_C(caccO);        // (MMA,MMA_M,MMA_K)
        static_assert(decltype(size<0, 0>(taccOcO))::value == 2);
        static_assert(decltype(size<0, 1>(taccOcO))::value == 2);
        // taccOcO has shape ((2, 2, V), MMA_M, MMA_K); we only take the row indices.
        Tensor taccOcO_row = taccOcO(make_coord(_0{}, _, _0{}), _, _0{});
        CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));  // MMA_M
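        // Only threads whose accumulator fragment owns column 0 write LSE, one value per row
        // that falls within the actual sequence length.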
        if (get<1>(taccOcO_row(_0{})) == 0) {
            #pragma unroll
            for (int mi = 0; mi < size(lse); ++mi) {
                const int row = get<0>(taccOcO_row(mi));
                if (row < seqlen_traits_q.actual_seq_len - m_block * kBlockM) { gLSE(row) = lse(mi); }
            }
        }
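        // The last warp waits on the epilogue barrier until all MMA threads have arrived
        // (i.e., have finished writing sO) before the output store in flash::write_O below.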
        int write_warp_idx = kNWarps - 1;
        if (cutlass::canonical_warp_idx_sync() == write_warp_idx) {
            cutlass::arch::NamedBarrier::sync(
                NumMmaThreads + cutlass::NumThreadsPerWarp,
                cutlass::arch::ReservedNamedBarriers::EpilogueBarrier
            );
        }
        TiledCopyO gmem_tiled_copy_O;
        flash::write_O<!Seqlen_traits::kUseVarSeqLen, NumCopyThreads>(
            epilogue_params.ptr_O, epilogue_params.tma_store_O, gmem_tiled_copy_O,
            epilogue_params.layout_O, select<0, 2>(TileShape_MNK{}), sO,
            m_block, bidh, bidb, seqlen_traits_q, write_warp_idx
        );
    }
    CUTLASS_DEVICE void
    store_tail() {
        tma_store_wait<0>();
    }
    // Write 0 to output and -inf to LSE
    template<typename SharedStorage>
    CUTLASS_DEVICE void
    store_zero(
        Params const& epilogue_params,
        SharedStorage& shared_storage,
        int thread_idx,
        cute::tuple<int32_t, int32_t, int32_t> const& block_coord,
        const Seqlen_traits& seqlen_traits_q
        ) {
        auto [m_block, bidh, bidb] = block_coord;

        Tensor mO = make_tensor(make_gmem_ptr(epilogue_params.ptr_O), epilogue_params.layout_O);
        Tensor gO = seqlen_traits_q.get_local_tile_tensor(
            mO, select<0, 2>(TileShape_MNK{}), bidh, bidb
        )(_, _, m_block);  // (M, K)
        Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), epilogue_params.layout_LSE);
        Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor(
            mLSE, Shape<Int<kBlockM>>{}, bidh, bidb)(_, m_block);

        TiledCopyO gmem_tiled_copy_O;
        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
        Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
        Tensor tOrO = make_fragment_like(tOgO);
        clear(tOrO);
        // Construct identity layout for sO
        Tensor cO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
        // Repeat the partitioning with identity layouts
        Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
        Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
        #pragma unroll
        for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(epilogue_params.layout_O.shape()); }
        // Clear_OOB_K must be false since we don't want to write zeros to gmem
        flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
            gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, get<0>(epilogue_params.layout_O.shape()) - m_block * kBlockM
        );
        static_assert(kBlockM <= NumMmaThreads);
        if (thread_idx < get<0>(epilogue_params.layout_LSE.shape()) - m_block * kBlockM) { gLSE(thread_idx) = INFINITY; }
    }

};

}  // namespace flash