/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cutlass/cutlass.h>
#include "cute/tensor.hpp"

#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"

#include "named_barrier.hpp"
#include "utils.h"

namespace flash {

using namespace cute;
template <class TileShape_MNK_, class Element_, int NumEpilogueThreads_, bool Varlen_, bool FP8PermuteCol=false>
struct CollectiveEpilogueFwd {

    using TileShape_MNK = TileShape_MNK_;
    using Element = Element_;
    static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
    static constexpr bool Varlen = Varlen_;

    static constexpr int kHeadDim = get<2>(TileShape_MNK{});
    static constexpr int kBlockM = get<0>(TileShape_MNK{});

    using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;

    // These are for storing the output tensor without TMA (e.g., for setting output to zero)
    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
    static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, NumEpilogueThreads);
    static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout<Shape <Int<NumEpilogueThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
    using GmemTiledCopyO = decltype(
        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 8 or 16 vals per store
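    // Worked example (illustrative; the concrete values below are assumptions, not fixed by this file):
    // with Element = cutlass::half_t (2 bytes), kGmemElemsPerLoad = 16 / 2 = 8. Assuming kHeadDim = 128
    // and NumEpilogueThreads = 256, kGmemThreadsPerRow = gcd(128 / 8, 256) = 16, so GmemLayoutAtom is a
    // (16, 16) thread layout and each thread stores 8 contiguous elements, i.e. one atom row covers the
    // full 128-element head dimension.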
    using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));

    using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t>;  // (seqlen_q, d, head, batch)
    using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t>;
    using StrideLSE = cute::Stride<_1, int64_t, int64_t>;            // (seqlen_q, head, batch)

    // cute::SM90_U32x4_STSM_N if Element size is 2 bytes (fp16, bf16)
    using CopyOpR2S = decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator<StrideO, Element>());
    using SmemCopyAtomO = Copy_Atom<CopyOpR2S, Element>;

    struct TensorStorage : cute::aligned_struct<128> {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutO>> smem_o;
    };

    using TMA_O = decltype(make_tma_copy(
        GmemTiledCopyOTMA{},
        make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapeO{}, StrideO{}),
        SmemLayoutO{},
        select<0, 2>(TileShape_MNK{}),
        _1{}));  // no mcast for O
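    // Note: TMA_O is deduced from a dummy tensor built on a null pointer; this only fixes the type.
    // The TMA descriptor for the real output tensor is constructed in to_underlying_arguments() below.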
    // Host side kernel arguments
    struct Arguments {
        Element* ptr_O;
        ShapeO const shape_O;
        StrideO const stride_O;
        float* ptr_LSE;
        StrideLSE const stride_LSE;
        int const* cu_seqlens = nullptr;
        int const* seqused = nullptr;
    };

    // Device side kernel params
    struct Params {
        Element* ptr_O;
        ShapeO const shape_O;
        StrideO const stride_O;
        float* ptr_LSE;
        StrideLSE const stride_LSE;
        TMA_O tma_store_O;
        int const* cu_seqlens = nullptr;
        int const* seqused = nullptr;
    };

    static Params
    to_underlying_arguments(Arguments const& args) {
        Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O);
        TMA_O tma_store_O = make_tma_copy(
            GmemTiledCopyOTMA{},
            mO,
            SmemLayoutO{},
            select<0, 2>(TileShape_MNK{}),
            _1{});  // no mcast for O
        if constexpr (Varlen) {
            assert(args.cu_seqlens != nullptr);
        }
        return {args.ptr_O, args.shape_O, args.stride_O, args.ptr_LSE, args.stride_LSE, tma_store_O, args.cu_seqlens, args.seqused};
    }
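    // Host-side usage sketch (illustrative only; the tile shape, element type, thread count, and the
    // variable names below are assumptions, not defined in this file):
    //
    //   using Epilogue = flash::CollectiveEpilogueFwd<cute::Shape<cute::_128, cute::_128, cute::_64>,
    //                                                 cutlass::half_t, /*NumEpilogueThreads=*/256, /*Varlen=*/false>;
    //   Epilogue::Arguments args {
    //       ptr_O, {seqlen_q, headdim, num_heads, batch},               // shape_O: (seqlen_q, d, head, batch)
    //       {stride_O_seq, cute::_1{}, stride_O_head, stride_O_batch},  // stride_O
    //       ptr_LSE, {cute::_1{}, stride_LSE_head, stride_LSE_batch}    // stride_LSE
    //   };
    //   Epilogue::Params params = Epilogue::to_underlying_arguments(args);  // builds the TMA descriptor
    //   // `params` is then passed to the kernel, which calls prefetch_tma_descriptors(), store(), store_tail().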
    /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
    CUTLASS_DEVICE
    static void prefetch_tma_descriptors(Params const& params) {
        if constexpr (!Varlen) {
            cute::prefetch_tma_descriptor(params.tma_store_O.get_tma_descriptor());
        }
    }
    template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE, typename TiledMma>
    CUTLASS_DEVICE void
    store(Params const& params,
          FrgTensorO const& tOrO,
          FrgTensorLSE const& lse,
          SharedStorage& shared_storage,
          TiledMma tiled_mma,
          int thread_idx,
          cute::tuple<int32_t, int32_t, int32_t> const& block_coord
          ) {

        auto [m_block, bidh, bidb] = block_coord;
        Tensor sO = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_o.data()), SmemLayoutO{});
        // Tensor tOrO_out = flash::convert_type<Element>(tOrO);
        Tensor tOrO_out = flash::convert_type_safe<Element>(tOrO);
        if constexpr (FP8PermuteCol) { flash::permute_output_fp8_fp16(tOrO_out); }
        // Make sure all WGs have finished reading V
        cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, static_cast<int>(FwdNamedBarriers::ValueEmpty) /*id*/);
        auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
        auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
        Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out);  // ((Atom,AtomNum), MMA_M, MMA_N)
        Tensor taccOsO = smem_thr_copy_O.partition_D(sO);     // ((Atom,AtomNum),PIPE_M,PIPE_N)
        cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
        if constexpr (!Varlen) {
            cutlass::arch::fence_view_async_shared();  // ensure smem writes are visible to TMA
            cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
                                                cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
        } else {
            cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
        }

        int offset_o = !Varlen ? 0 : params.cu_seqlens[bidb];
        int seqlen_o = !Varlen ? get<0>(params.shape_O) : (params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - offset_o);
        auto shape_LSE = select<0, 2, 3>(params.shape_O);
        Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE), shape_LSE, params.stride_LSE)(_, bidh, !Varlen ? bidb : 0);
        Tensor gLSE = local_tile(cute::domain_offset(make_coord(offset_o), mLSE), Shape<Int<kBlockM>>{}, make_coord(m_block));
        Tensor caccO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}));
        auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
        Tensor taccOcO = thread_mma.partition_C(caccO);  // (MMA,MMA_M,MMA_K)
        static_assert(decltype(size<0, 0>(taccOcO))::value == 2);
        static_assert(decltype(size<0, 1>(taccOcO))::value == 2);
        // taccOcO has shape ((2, 2, V), MMA_M, MMA_K); we only take the row indices.
        Tensor taccOcO_row = taccOcO(make_coord(_0{}, _, _0{}), _, _0{});
        CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));  // MMA_M
        // Only threads whose accumulator fragment sits in column 0 write LSE, so each row is written once.
        if (get<1>(taccOcO_row(_0{})) == 0) {
            #pragma unroll
            for (int mi = 0; mi < size(lse); ++mi) {
                const int row = get<0>(taccOcO_row(mi));
                if (row < seqlen_o - m_block * kBlockM) { gLSE(row) = lse(mi); }
            }
        }
        if constexpr (!Varlen) {
            Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O);
            Tensor gO = local_tile(mO(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{}));  // (M, K)
            auto block_tma_O = params.tma_store_O.get_slice(_0{});
            Tensor tOgO = block_tma_O.partition_D(gO);  // (TMA, TMA_M, TMA_K)
            Tensor tOsO = block_tma_O.partition_S(sO);  // (TMA, TMA_M, TMA_K)
            // The last warp waits on the epilogue barrier (all epilogue threads arrived above after
            // writing smem), then one elected lane issues the TMA store.
            int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0);
            if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) {
                cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
                                                  cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
                int lane_predicate = cute::elect_one_sync();
                if (lane_predicate) {
                    cute::copy(params.tma_store_O, tOsO, tOgO);
                    tma_store_arrive();
                }
            }
        } else {  // Don't use TMA since we don't want to overwrite the output of another sequence
            Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O), params.shape_O, params.stride_O)(_, _, bidh, params.cu_seqlens == nullptr ? bidb : 0);
            Tensor gO = local_tile(cute::domain_offset(make_coord(offset_o, _0{}), mO), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{}));  // (M, K)
            GmemTiledCopyO gmem_tiled_copy_O;
            auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
            Tensor tOsO = gmem_thr_copy_O.partition_S(sO);  // ((Atom,AtomNum),ATOM_M,ATOM_N)
            Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
            Tensor tOrO = make_fragment_like(tOsO);
            cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
            // Construct identity layout for sO
            Tensor cO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
            // Repeat the partitioning with identity layouts
            Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
            Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
            #pragma unroll
            for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
            // Clear_OOB_K must be false since we don't want to write zeros to gmem
            flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
                gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM
            );
        }
    }
    CUTLASS_DEVICE void
    store_tail() {
        if constexpr (!Varlen) { tma_store_wait<0>(); }
    }
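    // Usage note (an inference from the calls above, not a documented contract): when !Varlen, store()
    // issues the TMA store with tma_store_arrive(), so callers are expected to pair it with store_tail(),
    // which performs tma_store_wait<0>(), before the smem_o buffer in TensorStorage is reused.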
    // Write 0 to output and -inf to LSE
    CUTLASS_DEVICE void
    store_zero(
        Params const& params,
        int thread_idx,
        cute::tuple<int32_t, int32_t, int32_t> const& block_coord
        ) {
        static constexpr int kBlockM = get<0>(TileShape_MNK{});
        auto [m_block, bidh, bidb] = block_coord;
        int offset_o = !Varlen ? 0 : params.cu_seqlens[bidb];
        int seqlen_o = !Varlen ? get<0>(params.shape_O) : (params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - offset_o);
        Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O), params.shape_O, params.stride_O)(_, _, bidh, !Varlen ? bidb : 0);
        Tensor gO = local_tile(cute::domain_offset(make_coord(offset_o, _0{}), mO), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{}));  // (M, K)
        auto shape_LSE = select<0, 2, 3>(params.shape_O);
        Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE), shape_LSE, params.stride_LSE)(_, bidh, !Varlen ? bidb : 0);
        Tensor gLSE = local_tile(cute::domain_offset(make_coord(offset_o), mLSE), Shape<Int<kBlockM>>{}, make_coord(m_block));
        GmemTiledCopyO gmem_tiled_copy_O;
        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
        Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
        Tensor tOrO = make_fragment_like(tOgO);
        clear(tOrO);
        // Construct identity layout for gO
        Tensor cO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
        // Repeat the partitioning with identity layouts
        Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
        Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
        #pragma unroll
        for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
        // Clear_OOB_K must be false since we don't want to write zeros to gmem
        flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
            gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM
        );
        static_assert(kBlockM <= NumEpilogueThreads);
        if (thread_idx < seqlen_o - m_block * kBlockM && thread_idx < kBlockM) { gLSE(thread_idx) = -INFINITY; }
    }

};

}  // namespace flash