// flash_bwd_postprocess_kernel.h
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include "cutlass/arch/barrier.h"

#include "seqlen.h"
#include "utils.h"

namespace flash {

using namespace cute;
template <class TileShape_MK_, class Element, class ElementAccum, class ArchTag_, int kNThreads, class TiledMma, bool dQ_swapAB>
class FlashAttnBwdPostprocessConvertdQ {

public:

    // Type Aliases
    using TileShape_MK = TileShape_MK_;
    using ArchTag = ArchTag_;

    static_assert(ArchTag::kMinComputeCapability >= 75);
    static constexpr bool IsSm90 = ArchTag::kMinComputeCapability >= 90;

    static constexpr uint32_t MaxThreadsPerBlock = kNThreads;
    static constexpr uint32_t MinBlocksPerMultiprocessor = 2;

    static constexpr int kBlockM = get<0>(TileShape_MK{});
    static constexpr int kHeadDim = get<1>(TileShape_MK{});

    static_assert(!IsSm90 || kNThreads % cutlass::NumThreadsPerWarpGroup == 0, "kNThreads must be a multiple of NumThreadsPerWarpGroup");
    static constexpr int NumdQWarpGroups = kNThreads / cutlass::NumThreadsPerWarpGroup;
    using R2SLayoutAtomdQaccum = std::conditional_t<
        IsSm90,
        Layout<Shape<Int<cutlass::NumThreadsPerWarpGroup>, Int<NumdQWarpGroups>>>,
        Layout<Shape<Int<kNThreads>>>
    >;
    using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2SLayoutAtomdQaccum{},
                                                         Layout<Shape<Int<IsSm90 ? 4 : 1>>>{}));  // Val layout, 1 or 4 vals per read

    using G2SLayoutAtomdQaccum = Layout<Shape<Int<kNThreads>>>;
    // UniversalCopy instead of AutoVectorizingCopyWithAssumedAlignment as the latter generates cp.async instructions
    using G2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, ElementAccum>{}, G2SLayoutAtomdQaccum{},
                                                         Layout<Shape<_4>>{}));  // Val layout, 4 vals per read
    // We don't do bound checking for the gmem -> smem load so we just assert here.
    static_assert(IsSm90 || (kBlockM * kHeadDim) % (kNThreads * 4) == 0);
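    // Sanity-check example (no new constraint, just illustrative numbers): with kBlockM = 64,
    // kHeadDim = 128 and kNThreads = 256, the 4-element (128-bit) loads of G2STiledCopydQaccum
    // move 256 * 4 = 1024 floats per pass, and 64 * 128 = 8192 is divisible by 1024, so the
    // dQaccum tile is covered exactly without predication.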
    static constexpr int SmemdQaccumSize = size(TileShape_MK{});
    using SmemLayoutdQaccumFlat = Layout<Shape<Int<SmemdQaccumSize>>>;
    using SmemLayoutdQaccum = std::conditional_t<
        IsSm90,
        Layout<Shape<Int<kBlockM * kHeadDim / NumdQWarpGroups>, Int<NumdQWarpGroups>>>,
        Layout<Shape<Int<kBlockM * kHeadDim>>>
    >;

    // We can't just use kHeadDim here. E.g. if MMA shape is 64 x 96 but split across 2 WGs,
    // then setting kBlockKSmem to 32 will cause "Static shape_div failure".
    // We want to treat it as 64 x 48, so kBlockKSmem should be 16.
    static constexpr int MmaShapeN = get<1>(typename TiledMma::AtomShape_MNK{});
    static constexpr int kBlockKSmem = MmaShapeN % 64 == 0 ? 64 : (MmaShapeN % 32 == 0 ? 32 : 16);
    static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 2 : 1);
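    // For instance, an MMA atom N that is a multiple of 64 yields kBlockKSmem = 64 and kSwizzle = 3,
    // while the 64 x 48-per-warpgroup case described above falls through to kBlockKSmem = 16 and
    // kSwizzle = 1, so the swizzle width tracks the narrower smem row.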
    using SmemLayoutAtomdQ =
        decltype(composition(Swizzle<kSwizzle, 3, 3>{},
                             Layout<Shape<Int<8>, Int<kBlockKSmem>>,
                                    Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutdQ = decltype(tile_to_shape(SmemLayoutAtomdQ{}, TileShape_MK{}));
    using SmemLayoutdQt =
        decltype(cute::composition(SmemLayoutdQ{},
                                   make_layout(make_shape(get<1>(TileShape_MK{}), get<0>(TileShape_MK{})),
                                               make_stride(Int<get<0>(TileShape_MK{})>{}, _1{}))));

    using SmemCopyAtomdQ = Copy_Atom<
        std::conditional_t<
            IsSm90,
            std::conditional_t<!dQ_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
            AutoVectorizingCopyWithAssumedAlignment<128>
        >,
        Element>;

    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
    static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, int(MaxThreadsPerBlock));
    static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
    using GmemTiledCopy = decltype(
        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 8 or 16 vals per load
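    // For example, with a 16-bit Element and kHeadDim = 128: kGmemElemsPerLoad = 16 / 2 = 8 and
    // kGmemThreadsPerRow = gcd(128 / 8, kNThreads) = 16 when kNThreads is a multiple of 16, so each
    // thread stores 8 elements per 128-bit access and 16 threads cover one full row of dQ.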
    struct SharedStorage : cute::aligned_struct<128> {
        cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdQaccum>> smem_dqacc;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdQ>> smem_dq;
        alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_dQaccum;
    };

    static constexpr int SharedStorageSize = sizeof(SharedStorage);
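    // Shared memory holds both the fp32 accumulator tile and the converted output tile at once:
    // for a 64 x 128 tile this is roughly 64 * 128 * 4 B = 32 KB of dQaccum plus 64 * 128 * 2 B
    // = 16 KB of dQ (for a 16-bit Element), plus the transaction barrier and alignment padding.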
    using ShapedQ = cute::Shape<int32_t, int32_t, int32_t, int32_t>;  // (seqlen_q, d, head, batch)
    using StridedQ = cute::Stride<int64_t, _1, int64_t, int64_t>;
    using ShapedQaccum = cute::Shape<int32_t, int32_t, int32_t>;  // (seqlen_q * d, head, batch)
    using StridedQaccum = cute::Stride<_1, int64_t, int64_t>;

    // Device side arguments
    struct Arguments {
        ElementAccum const* ptr_dQaccum;
        ShapedQaccum const shape_dQaccum;
        StridedQaccum const stride_dQaccum;
        Element* ptr_dQ;
        ShapedQ const shape_dQ;
        StridedQ const stride_dQ;
        float const softmax_scale;
        int const* cu_seqlens = nullptr;
        int const* seqused = nullptr;
    };

    // Kernel entry point API
    struct Params {
        ElementAccum const* ptr_dQaccum;
        ShapedQaccum const shape_dQaccum;
        StridedQaccum const stride_dQaccum;
        Element* ptr_dQ;
        ShapedQ const shape_dQ;
        StridedQ const stride_dQ;
        float const softmax_scale;
        int const* cu_seqlens = nullptr;
        int const* seqused = nullptr;
    };

    // Convert to underlying arguments. In this case, a simple copy for the aliased type.
    static
    Params
    to_underlying_arguments(Arguments const& args) {
        return {
            args.ptr_dQaccum,
            args.shape_dQaccum,
            args.stride_dQaccum,
            args.ptr_dQ,
            args.shape_dQ,
            args.stride_dQ,
            args.softmax_scale,
            args.cu_seqlens,
            args.seqused
        };
    }
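    // A minimal host-side sketch of how this class might be instantiated and launched. The tile
    // size, thread count, TiledMma placeholder, grid shape, and launch_postprocess_kernel wrapper
    // are illustrative assumptions, not part of this header:
    //
    //   using Kernel = FlashAttnBwdPostprocessConvertdQ<Shape<_64, _128>, cutlass::bfloat16_t,
    //                                                   float, cutlass::arch::Sm90, 256,
    //                                                   TiledMma /*same MMA as the bwd kernel*/, false>;
    //   typename Kernel::Arguments args{dqaccum_ptr, shape_dqaccum, stride_dqaccum,
    //                                   dq_ptr, shape_dq, stride_dq, softmax_scale,
    //                                   cu_seqlens /*or nullptr*/, seqused /*or nullptr*/};
    //   typename Kernel::Params params = Kernel::to_underlying_arguments(args);
    //   // One block per kBlockM rows of queries, per head, per batch:
    //   dim3 grid(cute::ceil_div(seqlen_q, 64), num_heads, batch_size);
    //   launch_postprocess_kernel<Kernel><<<grid, Kernel::MaxThreadsPerBlock,
    //                                       Kernel::SharedStorageSize>>>(params);  // hypothetical wrapper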
    CUTLASS_DEVICE
    void
    operator()(Params const& params, char* smem_buf) {

        static constexpr int kBlockM = get<0>(TileShape_MK{});
        SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);

        Tensor sdQaccum = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccum{});
        Tensor sdQaccum_flat = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccumFlat{});
        Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQ{});
        Tensor sdQt = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQt{});

        int const thread_idx = threadIdx.x;
        int const m_block = blockIdx.x;
        int const bidh = blockIdx.y;
        int const bidb = blockIdx.z;

        flash::SeqlenInfo<true /*Varlen*/, kBlockM> seqlen_info(bidb, size<0>(params.shape_dQ), params.cu_seqlens, params.seqused);
        bool const is_varlen = params.cu_seqlens;
        if (is_varlen && m_block * kBlockM >= seqlen_info.seqlen) { return; }
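        // Overall flow: dQaccum is the fp32 accumulation buffer for dQ, laid out flat as
        // (seqlen_q * d) per head; this kernel reloads one kBlockM x kHeadDim tile of it, applies
        // the softmax scale, converts to Element, and writes the corresponding tile of dQ.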
        // Step 1: load dQaccum from gmem to smem
        Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum const*>(params.ptr_dQaccum)),
                                      params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0);
        Tensor gdQaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(m_block));  // (M * K)
        if constexpr (IsSm90) {  // Use BulkCopy
            static constexpr uint32_t TmaTransactionBytesdQaccum = static_cast<uint32_t>(size(SmemLayoutdQaccumFlat{}) * cute::sizeof_bits_v<ElementAccum> / 8);
            auto bulk_copy = Copy_Traits<SM90_BULK_COPY_AUTO>{};
            // if (thread0()) { print(gdQaccum); printf("\n"); print(sdQaccum_flat); printf("\n"); }
            if (thread_idx == 0) {
                shared_storage.barrier_dQaccum.init(1 /*numThreads*/);
                shared_storage.barrier_dQaccum.arrive_and_expect_tx(TmaTransactionBytesdQaccum);
                copy(bulk_copy.with(*reinterpret_cast<uint64_t*>(&shared_storage.barrier_dQaccum)), gdQaccum, sdQaccum_flat);
            }
            __syncthreads();
            shared_storage.barrier_dQaccum.wait(0);
        } else {
            G2STiledCopydQaccum g2s_tiled_copy_dQaccum;
            auto g2s_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_thread_slice(thread_idx);
            Tensor tdQgdQaccumg2s = g2s_thr_copy_dQaccum.partition_S(gdQaccum);
            Tensor tdQsdQaccumg2s = g2s_thr_copy_dQaccum.partition_D(sdQaccum);
            cute::copy(g2s_tiled_copy_dQaccum, tdQgdQaccumg2s, tdQsdQaccumg2s);
            __syncthreads();
        }
        // __syncthreads(); if (cute::thread0()) { print_tensor(sdQaccum); }

        // Step 2: Load dQaccum from smem to register, then convert fp32 -> fp16/bf16
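        // Each thread copies its slice of sdQaccum directly into a fragment shaped like the
        // TiledMma accumulator (partition_fragment_C below), so no cross-thread shuffle is needed;
        // softmax_scale is folded in element-wise here before the narrowing fp32 -> Element
        // conversion.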
        R2STiledCopydQaccum s2r_tiled_copy_dQaccum;
        auto s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_thread_slice(thread_idx);
        Tensor tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum);
        TiledMma tiled_mma_dQ;
        Tensor taccdQrdQaccum = partition_fragment_C(tiled_mma_dQ, select<!dQ_swapAB ? 0 : 1, !dQ_swapAB ? 1 : 0>(TileShape_MK{}));
        // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(tiled_mma_dQ); printf("\n"); }
        // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(tdQsdQaccum); }
        // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(taccdQrdQaccum); }
        CUTE_STATIC_ASSERT_V(size(taccdQrdQaccum) == size(tdQsdQaccum));
        Tensor tdQrdQaccum = s2r_thr_copy_dQaccum.retile_D(taccdQrdQaccum);
        cute::copy(s2r_tiled_copy_dQaccum, tdQsdQaccum, tdQrdQaccum);
        #pragma unroll
        for (int i = 0; i < size(taccdQrdQaccum); ++i) { taccdQrdQaccum(i) *= params.softmax_scale; }
        // Convert tdQrdQ from fp32 to fp16
        Tensor rdQ = make_tensor_like<Element>(taccdQrdQaccum);
        flash::convert_type_out(taccdQrdQaccum, rdQ);

        // Step 3: Copy dQ from register to smem
        auto smem_tiled_copy_dQ = make_tiled_copy_C(SmemCopyAtomdQ{}, tiled_mma_dQ);
        auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(thread_idx);
        Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);  // ((Atom,AtomNum), MMA_M, MMA_N)
        // if (cute::thread0()) { print(smem_tiled_copy_dQ); }
        // if (cute::thread0()) { print(smem_thr_copy_dQ); }
        // if (cute::thread0()) { print(sdQ); }
        Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(cute::conditional_return<!dQ_swapAB>(sdQ, sdQt));  // ((Atom,AtomNum),PIPE_M,PIPE_N)
        cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
        __syncthreads();

        // Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem
        Tensor mdQ = make_tensor(make_gmem_ptr(params.ptr_dQ), params.shape_dQ, params.stride_dQ)(_, _, bidh, !is_varlen ? bidb : 0);
        Tensor gdQ = local_tile(domain_offset(make_coord(seqlen_info.offset, _0{}), mdQ), TileShape_MK{}, make_coord(m_block, _0{}));  // (M, K)
        GmemTiledCopy gmem_tiled_copy_dQ;
        auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(thread_idx);
        Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ);  // ((Atom,AtomNum),ATOM_M,ATOM_N)
        Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
        Tensor tdQrdQ = make_fragment_like(tdQsdQ);
        Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cute::make_identity_tensor(TileShape_MK{}));
        Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));
        #pragma unroll
        for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(_0{}, _0{}, k)) < get<1>(params.shape_dQ); }
        // Need to check OOB when reading from smem if kBlockM isn't evenly tiled
        static constexpr bool EvenM = kBlockM % CUTE_STATIC_V(size<0>(GmemLayoutAtom{})) == 0;
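        // For example, with kNThreads = 256 and kGmemThreadsPerRow = 16, GmemLayoutAtom spans 16
        // rows per iteration, so EvenM holds for kBlockM = 64 and the smem -> register copy below
        // skips M-bound checks; tdQpdQ still predicates K in case the actual head dimension is
        // smaller than the kHeadDim tile.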
        flash::copy</*Is_even_MN=*/EvenM, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false>(
            gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ, tdQcdQ, tdQpdQ, kBlockM);

        // Step 5: Copy dQ from register to gmem
        // Clear_OOB_K must be false since we don't want to write zeros to gmem
        flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
            gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, std::min(seqlen_info.seqlen - m_block * kBlockM, kBlockM)
        );
    }

};

} // namespace flash