flash_bwd_postprocess_kernel.h 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include "cute/tensor.hpp"
  6. #include <cutlass/cutlass.h>
  7. #include <cutlass/array.h>
  8. #include <cutlass/numeric_types.h>
  9. #include <cutlass/numeric_conversion.h>
  10. #include "cutlass/arch/barrier.h"
  11. #include "utils.h"
  namespace flash {

  using namespace cute;

  // Post-processing step of the attention backward pass: turns the ElementAccum dQ
  // accumulator buffer (dQaccum) into the final dQ tensor of type Element.
  //
  // Each CTA handles one (kBlockM, kHeadDim) tile: blockIdx.x = M tile, blockIdx.y = head,
  // blockIdx.z = batch. operator() performs:
  //   Step 1. TMA load of the dQaccum tile (SmemLayoutdQaccumTMA box) into shared memory.
  //   Step 2. Shared -> registers, arranged as the TiledMma C-fragment; scale by
  //           softmax_scale and convert ElementAccum -> Element.
  //   Step 3. Registers -> shared memory via STSM (through a transposed view when dQ_swapAB).
  //   Step 4. Shared -> registers in a row-contiguous layout for the global store.
  //   Step 5. Registers -> global memory, predicated on sequence length (rows) and
  //           head dimension (columns).
  //
  // Template parameters:
  //   TileShape_MK_        (kBlockM, kHeadDim) tile processed by one CTA.
  //   Element              output dQ element type (fp16/bf16 according to the Step 2 comment).
  //   ElementAccum         dQaccum element type (fp32 according to the Step 2 comment).
  //   ArchTag_             must report kMinComputeCapability >= 90 (SM90 TMA and STSM are used).
  //   kNThreads            number of threads per CTA.
  //   SmemLayoutdQaccumTMA shared-memory layout of the TMA box; must hold size(TileShape_MK) elements.
  //   TiledMma             MMA whose accumulator (C) fragment layout the register tile follows
  //                        (presumably the MMA that produced dQaccum in the main backward kernel -- confirm).
  //   dQ_swapAB            when true the MMA tile is (K, M) instead of (M, K) and dQ is written
  //                        to shared memory through the transposed view SmemLayoutdQt.
  template <class TileShape_MK_, class Element, class ElementAccum, class ArchTag_, int kNThreads, class SmemLayoutdQaccumTMA,
            class TiledMma, bool dQ_swapAB>
  class FlashAttnBwdPostprocessConvertdQ {

  public:

      // Type Aliases
      using TileShape_MK = TileShape_MK_;
      using ArchTag = ArchTag_;

      // The kernel relies on SM90-only features (TMA loads, STSM stores, transaction barriers).
      static_assert(ArchTag::kMinComputeCapability >= 90);

      static constexpr uint32_t MaxThreadsPerBlock = kNThreads;
      static constexpr uint32_t MinBlocksPerMultiprocessor = 2;

      static constexpr int kHeadDim = get<1>(TileShape_MK{});
      // Shared -> register copy of dQaccum: kNThreads threads laid out in 1-D, 4 contiguous
      // ElementAccum values per thread per copy.
      using R2SLayoutAtomdQaccum = Layout<Shape<Int<kNThreads>>, Stride<_1>>;
      using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{}, R2SLayoutAtomdQaccum{},
                                                           Layout<Shape < _4>>{}));  // Val layout, 4 vals per read
      static constexpr int SmemdQaccumSize = size(TileShape_MK{});
      static_assert(size(TileShape_MK{}) == size(SmemLayoutdQaccumTMA{}), "TileShape_MK and SmemLayoutdQaccumTMA must have the same size");
      // Flat 1-D view of the same shared buffer the TMA writes into (used for the Step 2 read).
      using SmemLayoutdQaccum = Layout<Shape<Int<SmemdQaccumSize>>, Stride<_1>>;

      // We can't just use kHeadDim here. E.g. if MMA shape is 64 x 96 but split across 2 WGs,
      // then setting kBlockKSmem to 32 will cause "Static shape_div failure".
      // We want to treat it as 64 x 48, so kBlockKSmem should be 16.
      static constexpr int MmaShapeN = get<1>(typename TiledMma::AtomShape_MNK{});
      static constexpr int kBlockKSmem = MmaShapeN % 64 == 0 ? 64 : (MmaShapeN % 32 == 0 ? 32 : 16);
      // Swizzle bits follow the atom width: 64 -> 3, 32 -> 2, 16 -> 1.
      static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 2 : 1);
      // Swizzled (8, kBlockKSmem) row-major atom for the converted dQ tile in shared memory.
      using SmemLayoutAtomdQ =
          decltype(composition(Swizzle<kSwizzle, 3, 3>{},
                               Layout<Shape<Int<8>, Int<kBlockKSmem>>,
                                      Stride<Int<kBlockKSmem>, _1>>{}));
      using SmemLayoutdQ = decltype(tile_to_shape(SmemLayoutAtomdQ{}, TileShape_MK{}));
      // Transposed (K, M) view of SmemLayoutdQ; STSM destination when dQ_swapAB is true.
      using SmemLayoutdQt =
          decltype(cute::composition(SmemLayoutdQ{},
                                     make_layout(make_shape(get<1>(TileShape_MK{}), get<0>(TileShape_MK{})),
                                                 make_stride(Int<get<0>(TileShape_MK{})>{}, _1{}))));

      // Register -> shared store atom: non-transposing STSM normally, transposing STSM
      // when the MMA tile is (K, M) (dQ_swapAB).
      using SmemCopyAtomdQ = Copy_Atom<
          std::conditional_t<!dQ_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
          Element>;

      // Shared -> global path: each thread moves 128 bits (kGmemElemsPerLoad Elements) per copy.
      static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
      static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
      // Threads cooperating on one row; the gcd keeps it a divisor of both the per-row
      // vector count and the CTA size.
      static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, int(MaxThreadsPerBlock));
      static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
      using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                    Stride<Int<kGmemThreadsPerRow>, _1>>;
      using GmemTiledCopy = decltype(
          make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                          GmemLayoutAtom{},
                          Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 8 or 16 vals per load
      using GmemTiledCopydQaccum = cute::SM90_TMA_LOAD;

      // Shared-memory layout for one CTA.
      struct SharedStorage : cute::aligned_struct<128> {
          // TMA destination for the dQaccum tile (alignment argument 1024; presumably required
          // by the TMA swizzle of SmemLayoutdQaccumTMA -- confirm). Also read through the flat
          // SmemLayoutdQaccum view in Step 2.
          cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdQaccumTMA>, 1024> smem_dqacc;
          // Staging buffer for the converted Element dQ tile (Steps 3-4).
          cute::array_aligned<Element, cute::cosize_v<SmemLayoutdQ>> smem_dq;
          // Transaction barrier completed by the dQaccum TMA load.
          alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_dQaccum;
      };

      static constexpr int SharedStorageSize = sizeof(SharedStorage);

      using ShapedQ = cute::Shape<int32_t, int32_t, int32_t, int32_t>;  // (seqlen_q, d, head, batch)
      using StridedQ = cute::Stride<int64_t, _1, int64_t, int64_t>;

      // Type of the TMA copy object, deduced from a placeholder (null-pointer) tensor with the
      // same shape/stride types as the real dQaccum tensor.
      using TMA_dQaccum = decltype(make_tma_copy(
          GmemTiledCopydQaccum{},
          make_tensor(make_gmem_ptr(static_cast<ElementAccum*>(nullptr)), ShapedQ{}, StridedQ{}),
          SmemLayoutdQaccumTMA{},
          SmemLayoutdQaccumTMA{}.shape(),
          _1{})); // no mcast for dQ

      // Device side arguments
      struct Arguments {
          ElementAccum const* ptr_dQaccum;   // dQaccum input buffer
          ShapedQ const shape_dQaccum;       // shape of dQaccum as seen by the TMA (may differ from shape_dQ, see Step 1)
          StridedQ const stride_dQaccum;
          Element* ptr_dQ;                   // dQ output buffer
          ShapedQ const shape_dQ;            // (seqlen_q, d, head, batch)
          StridedQ const stride_dQ;
          float const softmax_scale;         // multiplies every dQaccum value before conversion
          // Non-null selects variable-length mode: sequences of all batches are packed along
          // dim 0 and sequence b starts at row cu_seqlens[b] (batch mode index 0 is used).
          int const* cu_seqlens = nullptr;
          // Optional per-batch sequence length; only read in variable-length mode, where it
          // replaces cu_seqlens[b + 1] - cu_seqlens[b].
          int const* seqused = nullptr;
      };

      // Kernel parameters consumed by operator(). stride_dQaccum is not stored: it is encoded
      // in the TMA descriptor.
      struct Params {
          TMA_dQaccum tma_load_dQaccum;
          ShapedQ const shape_dQaccum;
          Element* ptr_dQ;
          ShapedQ const shape_dQ;
          StridedQ const stride_dQ;
          float const softmax_scale;
          int const* cu_seqlens = nullptr;
          int const* seqused = nullptr;
      };

      // Host-side conversion Arguments -> Params: builds the TMA descriptor for dQaccum from
      // (ptr_dQaccum, shape_dQaccum, stride_dQaccum) and copies the remaining fields through.
      static
      Params
      to_underlying_arguments(Arguments const& args) {
          Tensor mdQaccum = make_tensor(make_gmem_ptr(args.ptr_dQaccum), args.shape_dQaccum, args.stride_dQaccum);
          TMA_dQaccum tma_load_dQaccum = make_tma_copy(
              GmemTiledCopydQaccum{},
              mdQaccum,
              SmemLayoutdQaccumTMA{},
              SmemLayoutdQaccumTMA{}.shape(),
              _1{}); // no mcast for dQaccum
          return {
              tma_load_dQaccum,
              args.shape_dQaccum,
              args.ptr_dQ,
              args.shape_dQ,
              args.stride_dQ,
              args.softmax_scale,
              args.cu_seqlens,
              args.seqused
          };
      }

      // Device entry point. smem_buf must provide at least SharedStorageSize bytes of shared memory.
      CUTLASS_DEVICE
      void
      operator()(Params const& params, char* smem_buf) {

          static constexpr int kBlockM = get<0>(TileShape_MK{});

          SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);

          // Three views of the same dQaccum buffer / two views of the dQ buffer.
          Tensor sdQaccumTMA = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccumTMA{});
          // Tensor sdQaccumTMAnoswizzle = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccumTMANoSwizzle{});
          Tensor sdQaccum = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccum{});
          Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQ{});
          Tensor sdQt = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQt{});

          int const thread_idx = threadIdx.x;
          int const m_block = blockIdx.x;
          int const bidh = blockIdx.y;
          int const bidb = blockIdx.z;

          bool const is_varlen = params.cu_seqlens != nullptr;
          int const seqlen = !is_varlen ? get<0>(params.shape_dQ) : (params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb]);
          // In variable-length mode the grid is sized for the longest sequence; tiles entirely
          // past this batch's length have nothing to do. The whole CTA returns together,
          // before any barrier or __syncthreads.
          if (is_varlen && m_block * kBlockM >= seqlen) { return; }

          int lane_predicate = cute::elect_one_sync();
          int warp_idx = cutlass::canonical_warp_idx_sync();
          // Issue Tma Descriptor Prefetch from a single thread
          // (the same elected thread also initializes the barrier with an arrival count of 1).
          if (warp_idx == 0 && lane_predicate) {
              cute::prefetch_tma_descriptor(params.tma_load_dQaccum.get_tma_descriptor());
              shared_storage.barrier_dQaccum.init(1 /*numThreads*/);
          }
          // Make the barrier initialization visible to all threads before anyone waits on it.
          __syncthreads();

          // Step 1: TMA to load dQaccum from gmem to smem
          // We reshaped dQaccum to have last dimension 32, so the offset needs to be multiplied by kHeadDim / 32
          // (the divisor used is get<1> of SmemLayoutdQaccumTMA's shape). In variable-length mode each
          // batch's start row is cu_seqlens[b] + b * 128 rounded down to a multiple of 128; presumably this
          // mirrors the padding used when dQaccum was allocated/written -- confirm against the producer kernel.
          int const offset_padded = !is_varlen ? 0 : ((params.cu_seqlens[bidb] + bidb * 128) / 128 * 128) * (kHeadDim / get<1>(SmemLayoutdQaccumTMA{}.shape()));
          Tensor mdQaccum = params.tma_load_dQaccum.get_tma_tensor(params.shape_dQaccum)(_, _, bidh, !is_varlen ? bidb : 0);
          Tensor gdQaccum = local_tile(domain_offset(make_coord(offset_padded, _0{}), mdQaccum), SmemLayoutdQaccumTMA{}.shape(), make_coord(m_block, _0{}));  // (M, K)
          auto block_tma_dQ = params.tma_load_dQaccum.get_slice(_0{});
          // NOTE(review): partition_D is applied to the gmem source and partition_S to the smem
          // destination, the reverse of the copy direction below; this assumes the TMA atom
          // partitions source and destination identically -- confirm.
          Tensor tdQgdQaccumTMA = block_tma_dQ.partition_D(gdQaccum);  // (TMA, TMA_M, TMA_K)
          Tensor tdQsdQaccumTMA = block_tma_dQ.partition_S(sdQaccumTMA);  // (TMA, TMA_M, TMA_K)
          // Bytes the TMA will deliver for one full box; the barrier completes once all arrive.
          static constexpr uint32_t TmaTransactionBytesdQaccum = static_cast<uint32_t>(size(SmemLayoutdQaccumTMA{}) * cute::sizeof_bits_v<ElementAccum> / 8);
          if (warp_idx == 0 && lane_predicate) {
              shared_storage.barrier_dQaccum.arrive_and_expect_tx(TmaTransactionBytesdQaccum);
              copy(params.tma_load_dQaccum.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_dQaccum), 0 /*mcast_mask*/), tdQgdQaccumTMA, tdQsdQaccumTMA);
          }
          // Every thread blocks until the single-use barrier's phase 0 completes (TMA data landed).
          shared_storage.barrier_dQaccum.wait(0);
          // __syncthreads(); if (cute::thread0()) { print_tensor(sdQaccumTMA); }
          // __syncthreads(); if (cute::thread0()) { print_tensor(sdQaccumTMAnoswizzle); }
          // __syncthreads(); if (cute::thread0()) { print_tensor(sdQaccum); }

          // Step 2: Load dQaccum from smem to register, then convert fp32 -> fp16/bf16
          R2STiledCopydQaccum s2r_tiled_copy_dQaccum;
          auto s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_thread_slice(thread_idx);
          Tensor tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum);
          TiledMma tiled_mma_dQ;
          // Accumulator fragment over (M, K), or over (K, M) when dQ_swapAB.
          Tensor taccdQrdQaccum = partition_fragment_C(tiled_mma_dQ, select<!dQ_swapAB ? 0 : 1, !dQ_swapAB ? 1 : 0>(TileShape_MK{}));
          // if (cute::thread0()) { print(tiled_mma_dQ); printf("\n"); }
          // if (cute::thread0()) { print(tdQsdQaccum); }
          // if (cute::thread0()) { print(taccdQrdQaccum); }
          // Each thread must read exactly as many values as its accumulator fragment holds.
          CUTE_STATIC_ASSERT_V(size(taccdQrdQaccum) == size(tdQsdQaccum));
          Tensor tdQrdQaccum = s2r_thr_copy_dQaccum.retile_D(taccdQrdQaccum);
          cute::copy(s2r_tiled_copy_dQaccum, tdQsdQaccum, tdQrdQaccum);
          #pragma unroll
          for (int i = 0; i < size(taccdQrdQaccum); ++i) { taccdQrdQaccum(i) *= params.softmax_scale; }
          // Convert the scaled accumulator from ElementAccum (fp32) to Element (fp16/bf16)
          Tensor rdQ = flash::convert_type<Element>(taccdQrdQaccum);

          // Step 3: Copy dQ from register to smem
          auto smem_tiled_copy_dQ = make_tiled_copy_C(SmemCopyAtomdQ{}, tiled_mma_dQ);
          auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(thread_idx);
          Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);  // ((Atom,AtomNum), MMA_M, MMA_N)
          // if (cute::thread0()) { print(smem_tiled_copy_dQ); }
          // if (cute::thread0()) { print(smem_thr_copy_dQ); }
          // if (cute::thread0()) { print(sdQ); }
          if constexpr (!dQ_swapAB) {
              Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
              cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
          } else {
              // (K, M) fragment stored through the transposed view so sdQ ends up (M, K).
              Tensor taccdQsdQt = smem_thr_copy_dQ.partition_D(sdQt);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
              cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQt);
          }
          // Step 4 reads sdQ with a different thread mapping than Step 3 wrote it.
          __syncthreads();

          // Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem
          // dQ itself is not padded: in variable-length mode the tile starts at cu_seqlens[b].
          int const offset = !is_varlen ? 0 : params.cu_seqlens[bidb];
          Tensor mdQ = make_tensor(make_gmem_ptr(params.ptr_dQ), params.shape_dQ, params.stride_dQ)(_, _, bidh, !is_varlen ? bidb : 0);
          Tensor gdQ = local_tile(domain_offset(make_coord(offset, _0{}), mdQ), TileShape_MK{}, make_coord(m_block, _0{}));  // (M, K)
          GmemTiledCopy gmem_tiled_copy_dQ;
          auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(thread_idx);
          Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ);  // ((Atom,AtomNum),ATOM_M,ATOM_N)
          Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);

          Tensor tdQrdQ = make_fragment_like(tdQsdQ);
          cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);

          // Step 5: Copy dQ from register to gmem
          // Construct identity layout for gdQ
          Tensor cdQ = cute::make_identity_tensor(TileShape_MK{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
          // Repeat the partitioning with identity layouts
          Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
          // Column predicate: true for K-chunks whose starting column is inside the real head dim.
          Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));
          #pragma unroll
          for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(_0{}, _0{}, k)) < get<1>(params.shape_dQ); }
          // Clear_OOB_K must be false since we don't want to write zeros to gmem
          // The last argument is presumably the number of rows still valid in this tile
          // (row bound for the MN predicate inside flash::copy -- confirm in utils.h).
          flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
              gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, seqlen - m_block * kBlockM
          );
      }

  };

  } // namespace flash