flash_bwd_postprocess_kernel.h 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/
  4. #pragma once
  5. #include "cute/tensor.hpp"
  6. #include <cutlass/cutlass.h>
  7. #include <cutlass/array.h>
  8. #include <cutlass/numeric_types.h>
  9. #include <cutlass/numeric_conversion.h>
  10. #include "cutlass/arch/barrier.h"
  11. #include "seqlen.h"
  12. #include "utils.h"
  13. namespace flash {
  14. using namespace cute;
  15. template <class TileShape_MK_, class Element, class ElementAccum, class ArchTag_, int kNThreads, class TiledMma, bool dQ_swapAB>
  16. class FlashAttnBwdPostprocessConvertdQ {
  17. public:
  18. // Type Aliases
  19. using TileShape_MK = TileShape_MK_;
  20. using ArchTag = ArchTag_;
  21. static_assert(ArchTag::kMinComputeCapability >= 90);
  22. static constexpr uint32_t MaxThreadsPerBlock = kNThreads;
  23. static constexpr uint32_t MinBlocksPerMultiprocessor = 2;
  24. static constexpr int kBlockM = get<0>(TileShape_MK{});
  25. static constexpr int kHeadDim = get<1>(TileShape_MK{});
  26. static_assert(kNThreads % cutlass::NumThreadsPerWarpGroup == 0, "kNThreads must be a multiple of NumThreadsPerWarpGroup");
  27. static constexpr int NumdQWarpGgroups = kNThreads / cutlass::NumThreadsPerWarpGroup;
  28. using R2SLayoutAtomdQaccum = Layout<Shape<Int<cutlass::NumThreadsPerWarpGroup>, Int<NumdQWarpGgroups>>>;
  29. using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2SLayoutAtomdQaccum{},
  30. Layout<Shape < _4>>{})); // Val layout, 4 vals per read
  31. static constexpr int SmemdQaccumSize = size(TileShape_MK{});
  32. using SmemLayoutdQaccumFlat = Layout<Shape<Int<SmemdQaccumSize>>>;
  33. using SmemLayoutdQaccum = Layout<Shape<Int<kBlockM * kHeadDim / NumdQWarpGgroups>, Int<NumdQWarpGgroups>>>;
  34. // We can't just use kHeadDim here. E.g. if MMA shape is 64 x 96 but split across 2 WGs,
  35. // then setting kBlockKSmem to 32 will cause "Static shape_div failure".
  36. // We want to treat it as 64 x 48, so kBlockKSmem should be 16.
  37. static constexpr int MmaShapeN = get<1>(typename TiledMma::AtomShape_MNK{});
  38. static constexpr int kBlockKSmem = MmaShapeN % 64 == 0 ? 64 : (MmaShapeN % 32 == 0 ? 32 : 16);
  39. static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 2 : 1);
  40. using SmemLayoutAtomdQ =
  41. decltype(composition(Swizzle<kSwizzle, 3, 3>{},
  42. Layout<Shape<Int<8>, Int<kBlockKSmem>>,
  43. Stride<Int<kBlockKSmem>, _1>>{}));
  44. using SmemLayoutdQ = decltype(tile_to_shape(SmemLayoutAtomdQ{}, TileShape_MK{}));
  45. using SmemLayoutdQt =
  46. decltype(cute::composition(SmemLayoutdQ{},
  47. make_layout(make_shape(get<1>(TileShape_MK{}), get<0>(TileShape_MK{})),
  48. make_stride(Int<get<0>(TileShape_MK{})>{}, _1{}))));
  49. using SmemCopyAtomdQ = Copy_Atom<
  50. std::conditional_t<!dQ_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
  51. Element>;
  52. static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
  53. static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
  54. static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, int(MaxThreadsPerBlock));
  55. static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
  56. using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
  57. Stride<Int<kGmemThreadsPerRow>, _1>>;
  58. using GmemTiledCopy = decltype(
  59. make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
  60. GmemLayoutAtom{},
  61. Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per load
  62. struct SharedStorage : cute::aligned_struct<128> {
  63. cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdQaccum>> smem_dqacc;
  64. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdQ>> smem_dq;
  65. alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_dQaccum;
  66. };
  67. static constexpr int SharedStorageSize = sizeof(SharedStorage);
  68. using ShapedQ = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch)
  69. using StridedQ = cute::Stride<int64_t, _1, int64_t, int64_t>;
  70. using ShapedQaccum = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_q * d, head, batch)
  71. using StridedQaccum = cute::Stride<_1, int64_t, int64_t>;
  72. // Device side arguments
  73. struct Arguments {
  74. ElementAccum const* ptr_dQaccum;
  75. ShapedQaccum const shape_dQaccum;
  76. StridedQaccum const stride_dQaccum;
  77. Element* ptr_dQ;
  78. ShapedQ const shape_dQ;
  79. StridedQ const stride_dQ;
  80. float const softmax_scale;
  81. int const* cu_seqlens = nullptr;
  82. int const* seqused = nullptr;
  83. };
  84. // Kernel entry point API
  85. struct Params {
  86. ElementAccum const* ptr_dQaccum;
  87. ShapedQaccum const shape_dQaccum;
  88. StridedQaccum const stride_dQaccum;
  89. Element* ptr_dQ;
  90. ShapedQ const shape_dQ;
  91. StridedQ const stride_dQ;
  92. float const softmax_scale;
  93. int const* cu_seqlens = nullptr;
  94. int const* seqused = nullptr;
  95. };
  96. // Convert to underlying arguments. In this case, a simple copy for the aliased type.
  97. static
  98. Params
  99. to_underlying_arguments(Arguments const& args) {
  100. return {
  101. args.ptr_dQaccum,
  102. args.shape_dQaccum,
  103. args.stride_dQaccum,
  104. args.ptr_dQ,
  105. args.shape_dQ,
  106. args.stride_dQ,
  107. args.softmax_scale,
  108. args.cu_seqlens,
  109. args.seqused
  110. };
  111. }
  112. CUTLASS_DEVICE
  113. void
  114. operator()(Params const& params, char* smem_buf) {
  115. static constexpr int kBlockM = get<0>(TileShape_MK{});
  116. SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
  117. Tensor sdQaccum = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccum{});
  118. Tensor sdQaccum_flat = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccumFlat{});
  119. Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQ{});
  120. Tensor sdQt = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQt{});
  121. int const thread_idx = threadIdx.x;
  122. int const m_block = blockIdx.x;
  123. int const bidh = blockIdx.y;
  124. int const bidb = blockIdx.z;
  125. flash::SeqlenInfo<true /*Varlen*/, kBlockM> seqlen_info(bidb, size<0>(params.shape_dQ), params.cu_seqlens, params.seqused);
  126. bool const is_varlen = params.cu_seqlens;
  127. if (is_varlen && m_block * kBlockM >= seqlen_info.seqlen) { return; }
  128. // Step 1: Bulk copy to load dQaccum from gmem to smem
  129. Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum const*>(params.ptr_dQaccum)),
  130. params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0);
  131. Tensor gdQaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(m_block)); // (M * K)
  132. static constexpr uint32_t TmaTransactionBytesdQaccum = static_cast<uint32_t>(size(SmemLayoutdQaccumFlat{}) * cute::sizeof_bits_v<ElementAccum> / 8);
  133. auto bulk_copy = Copy_Traits<SM90_BULK_COPY_AUTO>{};
  134. // if (thread0()) { print(gdQaccum); printf("\n"); print(sdQaccum_flat); printf("\n"); }
  135. if (thread_idx == 0) {
  136. shared_storage.barrier_dQaccum.init(1 /*numThreads*/);
  137. shared_storage.barrier_dQaccum.arrive_and_expect_tx(TmaTransactionBytesdQaccum);
  138. copy(bulk_copy.with(*reinterpret_cast<uint64_t*>(&shared_storage.barrier_dQaccum)), gdQaccum, sdQaccum_flat);
  139. }
  140. __syncthreads();
  141. shared_storage.barrier_dQaccum.wait(0);
  142. // __syncthreads(); if (cute::thread0()) { print_tensor(sdQaccum); }
  143. // Step 2: Load dQaccum from smem to register, then convert fp32 -> fp16/bf16
  144. R2STiledCopydQaccum s2r_tiled_copy_dQaccum;
  145. auto s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_thread_slice(thread_idx);
  146. Tensor tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum);
  147. TiledMma tiled_mma_dQ;
  148. Tensor taccdQrdQaccum = partition_fragment_C(tiled_mma_dQ, select<!dQ_swapAB ? 0 : 1, !dQ_swapAB ? 1 : 0>(TileShape_MK{}));
  149. // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(tiled_mma_dQ); printf("\n"); }
  150. // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(tdQsdQaccum); }
  151. // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(taccdQrdQaccum); }
  152. CUTE_STATIC_ASSERT_V(size(taccdQrdQaccum) == size(tdQsdQaccum));
  153. Tensor tdQrdQaccum = s2r_thr_copy_dQaccum.retile_D(taccdQrdQaccum);
  154. cute::copy(s2r_tiled_copy_dQaccum, tdQsdQaccum, tdQrdQaccum);
  155. #pragma unroll
  156. for (int i = 0; i < size(taccdQrdQaccum); ++i) { taccdQrdQaccum(i) *= params.softmax_scale; }
  157. // Convert tdQrdQ from fp32 to fp16
  158. Tensor rdQ = make_tensor_like<Element>(taccdQrdQaccum);
  159. flash::convert_type_out(taccdQrdQaccum, rdQ);
  160. // Step 3: Copy dQ from register to smem
  161. auto smem_tiled_copy_dQ = make_tiled_copy_C(SmemCopyAtomdQ{}, tiled_mma_dQ);
  162. auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(thread_idx);
  163. Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N)
  164. // if (cute::thread0()) { print(smem_tiled_copy_dQ); }
  165. // if (cute::thread0()) { print(smem_thr_copy_dQ); }
  166. // if (cute::thread0()) { print(sdQ); }
  167. Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(cute::conditional_return<!dQ_swapAB>(sdQ, sdQt)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
  168. cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
  169. __syncthreads();
  170. // Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem
  171. Tensor mdQ = make_tensor(make_gmem_ptr(params.ptr_dQ), params.shape_dQ, params.stride_dQ)(_, _, bidh, !is_varlen ? bidb : 0);
  172. Tensor gdQ = local_tile(domain_offset(make_coord(seqlen_info.offset, _0{}), mdQ), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K)
  173. GmemTiledCopy gmem_tiled_copy_dQ;
  174. auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(thread_idx);
  175. Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N)
  176. Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
  177. Tensor tdQrdQ = make_fragment_like(tdQsdQ);
  178. cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);
  179. // Step 5: Copy dQ from register to gmem
  180. Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cute::make_identity_tensor(TileShape_MK{}));
  181. Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));
  182. #pragma unroll
  183. for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(_0{}, _0{}, k)) < get<1>(params.shape_dQ); }
  184. // Clear_OOB_K must be false since we don't want to write zeros to gmem
  185. flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
  186. gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, seqlen_info.seqlen - m_block * kBlockM
  187. );
  188. }
  189. };
  190. } // namespace flash