mainloop_bwd_sm80.hpp 55 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916
  1. /******************************************************************************
  2. * Copyright (c) 2024, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include <cutlass/cutlass.h>
  6. #include <cutlass/array.h>
  7. #include <cutlass/numeric_types.h>
  8. #include <cutlass/numeric_conversion.h>
  9. #include "cute/tensor.hpp"
  10. #include "seqlen.h"
  11. #include "mask.h"
  12. #include "softmax.h"
  13. #include "utils.h"
  14. namespace flash {
  15. using namespace cute;
  16. template <int Stages, int Stages_dO, class TileShape_MNK_, class Element_, class ElementAccum_, class ArchTag_,
  17. bool Is_causal_, bool Is_local_, bool Has_softcap_, bool Varlen_, bool Deterministic,
  18. bool SdP_swapAB_, bool dKV_swapAB_, bool dQ_swapAB_,
  19. int NumMmaWarpGroups=2, int AtomLayoutMSdP=1, int AtomLayoutNdKV=8, int AtomLayoutMdQ=1,
  20. bool V_in_regs=false>
  21. struct CollectiveMainloopBwdSm80 {
  // ---- Compile-time configuration forwarded from the template parameters ----
  // Pipeline depths: kStages is the stage count used by SmemLayoutQ and SmemLayoutLSE below,
  // kStages_dO is the stage count used by SmemLayoutdO.
  22. static constexpr int kStages = Stages;
  23. static constexpr int kStages_dO = Stages_dO;
  // dO may not use more stages than Q.
  24. static_assert(kStages >= kStages_dO);
  25. using TileShape_MNK = TileShape_MNK_;
  26. using Element = Element_;
  27. using ElementAccum = ElementAccum_;
  28. using ArchTag = ArchTag_;
  // Feature switches re-exported as static members.
  29. static constexpr bool Is_causal = Is_causal_;
  30. static constexpr bool Is_local = Is_local_;
  31. static constexpr bool Has_softcap = Has_softcap_;
  32. static constexpr bool Varlen = Varlen_;
  // Sequence-length helper type, parameterized on Varlen and on mode 0 of the tile shape.
  33. using SeqlenInfo_t = flash::SeqlenInfoQK<Varlen, CUTE_STATIC_V(get<0>(TileShape_MNK{}))>;
  34. static constexpr int NumMmaWarps = NumMmaWarpGroups * cutlass::NumWarpsPerWarpGroup;
  // These flags flip the M/N warp arrangement of the corresponding tiled MMA
  // (see AtomLayoutSdP / AtomLayoutdKV / AtomLayoutdQ below).
  35. static constexpr bool SdP_swapAB = SdP_swapAB_;
  36. static constexpr bool dKV_swapAB = dKV_swapAB_;
  37. static constexpr bool dQ_swapAB = dQ_swapAB_;
  38. static constexpr bool Q_dO_same_stages = kStages == kStages_dO;
  // Tile extents: modes 0, 1 and 2 of TileShape_MNK.
  39. static constexpr int kBlockM = get<0>(TileShape_MNK{});
  40. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  41. static constexpr int kHeadDim = get<2>(TileShape_MNK{});
  // Only compute capability >= 80 is accepted, so Has_cp_async on the next line always evaluates to true.
  42. static_assert(ArchTag::kMinComputeCapability >= 80);
  43. static constexpr bool Has_cp_async = ArchTag::kMinComputeCapability >= 80;
  44. static constexpr int NumMmaThreads = NumMmaWarps * cutlass::NumThreadsPerWarp;
  45. static constexpr int NumProducerThreads = NumMmaThreads; // For compatibility with TileScheduler
  // MMA instruction selection: the FP16 atom when Element is cutlass::half_t, otherwise the BF16 atom.
  // The SM75 alternative is never selected here because of the static_assert on
  // ArchTag::kMinComputeCapability >= 80 above.
  46. using MMA_Atom_Arch = std::conditional_t<
  47. ArchTag::kMinComputeCapability >= 80,
  48. std::conditional_t<
  49. std::is_same_v<Element, cutlass::half_t>,
  50. MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
  51. MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
  52. >,
  53. MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>
  54. >;
  // Each atom-layout parameter must evenly divide the number of MMA warps, since the other
  // dimension of each warp layout below is NumMmaWarps / AtomLayout*.
  55. static_assert(NumMmaWarps % AtomLayoutMSdP == 0);
  56. static_assert(NumMmaWarps % AtomLayoutNdKV == 0);
  57. static_assert(NumMmaWarps % AtomLayoutMdQ == 0);
  // When Mma_dKV_is_RS is true, SmemP_t (declared further down) becomes a zero-length array,
  // i.e. no shared-memory buffer is reserved for P.
  // NOTE(review): "RS" presumably means the A operand is taken from registers rather than smem — confirm.
  58. static constexpr bool Mma_dKV_is_RS = AtomLayoutMSdP == 1 && AtomLayoutNdKV == NumMmaWarps && SdP_swapAB && !dKV_swapAB;
  59. static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == NumMmaWarps && AtomLayoutMdQ == NumMmaWarps && !SdP_swapAB && !dQ_swapAB; // If dQ_swapAB we can't use RS
  // Warp arrangement for the S / dP MMA; with SdP_swapAB the first two modes are exchanged.
  60. using AtomLayoutSdP = std::conditional_t<
  61. !SdP_swapAB,
  62. Layout<Shape<Int<AtomLayoutMSdP>, Int<NumMmaWarps / AtomLayoutMSdP>, _1>>,
  63. Layout<Shape<Int<NumMmaWarps / AtomLayoutMSdP>, Int<AtomLayoutMSdP>, _1>>
  64. >;
  // True when the N extent handled per warp column is a multiple of 16; otherwise the tile below
  // uses an N granularity of 8 per warp instead of 16.
  65. static constexpr bool MmaSdPEvenN = ((!SdP_swapAB ? kBlockN : kBlockM) / size<1>(AtomLayoutSdP{})) % 16 == 0;
  66. using TiledMmaSdP = TiledMMA<
  67. MMA_Atom_Arch,
  68. AtomLayoutSdP,
  69. Tile<Int<16 * CUTE_STATIC_V(size<0>(AtomLayoutSdP{}))>, Int<(MmaSdPEvenN ? 16 : 8) * CUTE_STATIC_V(size<1>(AtomLayoutSdP{}))>, _16>>;
  // Warp arrangement and tiled MMA for dK / dV, built the same way from AtomLayoutNdKV and dKV_swapAB.
  70. using AtomLayoutdKV = std::conditional_t<
  71. !dKV_swapAB,
  72. Layout<Shape<Int<AtomLayoutNdKV>, Int<NumMmaWarps / AtomLayoutNdKV>, _1>>,
  73. Layout<Shape<Int<NumMmaWarps / AtomLayoutNdKV>, Int<AtomLayoutNdKV>, _1>>
  74. >;
  75. static constexpr bool MmadKVEvenN = ((!dKV_swapAB ? kHeadDim : kBlockN) / size<1>(AtomLayoutdKV{})) % 16 == 0;
  76. using TiledMmadKV = TiledMMA<
  77. MMA_Atom_Arch,
  78. AtomLayoutdKV,
  79. Tile<Int<16 * CUTE_STATIC_V(size<0>(AtomLayoutdKV{}))>, Int<(MmadKVEvenN ? 16 : 8) * CUTE_STATIC_V(size<1>(AtomLayoutdKV{}))>, _16>>;
  // Warp arrangement and tiled MMA for dQ, built the same way from AtomLayoutMdQ and dQ_swapAB.
  80. using AtomLayoutdQ = std::conditional_t<
  81. !dQ_swapAB,
  82. Layout<Shape<Int<AtomLayoutMdQ>, Int<NumMmaWarps / AtomLayoutMdQ>, _1>>,
  83. Layout<Shape<Int<NumMmaWarps / AtomLayoutMdQ>, Int<AtomLayoutMdQ>, _1>>
  84. >;
  85. static constexpr bool MmadQEvenN = ((!dQ_swapAB ? kHeadDim : kBlockM) / size<1>(AtomLayoutdQ{})) % 16 == 0;
  86. using TiledMmadQ = TiledMMA<
  87. MMA_Atom_Arch,
  88. AtomLayoutdQ,
  89. Tile<Int<16 * CUTE_STATIC_V(size<0>(AtomLayoutdQ{}))>, Int<(MmadQEvenN ? 16 : 8) * CUTE_STATIC_V(size<1>(AtomLayoutdQ{}))>, _16>>;
  // Number of Elements moved by one 128-bit (cute::uint128_t) access.
  90. static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
  91. static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
  92. // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each
  93. // thread to have 4 loads in the M direction and 2 vectorized load in the K direction.
  94. static constexpr int kBytePerRow = kHeadDim * sizeof(Element);
  // Row width in elements: 128 bytes if kBytePerRow is a multiple of 128, else 64 bytes if it is a
  // multiple of 64, else 32 bytes; the byte count is then divided by sizeof(Element).
  95. static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element);
  // Swizzle parameters, selected from the row width (in elements) and the element size (in bytes).
  96. static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1));
  97. static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 3 : 4);
  98. // We need to accommodate both Q and Q^T (and dO and dO^T) in shared memory.
  99. // Q & dO are used in the SdP Mma and Q^T and dO^T are used in the dKV Mma.
  100. // Since this is GMMA::Major::K, the M dimension (kBlockM) doesn't matter for the layout, only the K dimension
  101. // changes the layout.
  // NOTE(review): "GMMA::Major::K" is SM90 terminology and nothing GMMA-related appears in the visible
  // SM80 code; presumably this simply means the atom is K-major (stride 1 along K) — confirm.
  102. using SmemLayoutAtomQdO = decltype(
  103. composition(Swizzle<kSwizzle, kSwizzleBase, kSwizzleBase>{},
  104. Layout<Shape<_8, Int<kBlockKGmem>>,
  105. Stride<Int<kBlockKGmem>, _1>>{}));
  // Q: the atom tiled to (kBlockM, kHeadDim, kStages).
  106. using SmemLayoutQ =
  107. decltype(tile_to_shape(SmemLayoutAtomQdO{},
  108. make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
  // dO: same atom, tiled to (kBlockM, kHeadDim, kStages_dO).
  109. using SmemLayoutdO =
  110. decltype(tile_to_shape(SmemLayoutAtomQdO{},
  111. make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages_dO>{})));
  // K and V use an atom with the same swizzle and shape as the Q/dO atom.
  112. using SmemLayoutAtomKV = decltype(
  113. composition(Swizzle<kSwizzle, kSwizzleBase, kSwizzleBase>{},
  114. // TODO: FA2 has a slightly different layout, does it matter?
  115. Layout<Shape<_8, Int<kBlockKGmem>>,
  116. Stride<Int<kBlockKGmem>, _1>>{}));
  // K and V are single-stage: tiled to (kBlockN, kHeadDim) with no pipeline mode.
  117. using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomKV{}, select<1, 2>(TileShape_MNK{})));
  118. using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomKV{}, select<1, 2>(TileShape_MNK{})));
  119. // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest.
  // NOTE(review): kPBlockN below is derived from kBlockN (64, 32 or 16) rather than fixed at 16, so the
  // remark above may be stale; kSwizzlePdS is still fixed at 3 — confirm.
  120. static constexpr int kPBlockN = kBlockN % 64 == 0 ? 64 : (kBlockN % 32 == 0 ? 32 : 16);
  121. static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64);
  122. // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3);
  123. static constexpr int kSwizzlePdS = 3;
  // P and dS share one layout: a swizzled (kBlockM, kPBlockN) atom tiled to (kBlockM, kBlockN).
  124. using SmemLayoutAtomPdS = decltype(
  125. composition(Swizzle<kSwizzlePdS, kSwizzleBase, kSwizzleBase>{},
  126. Layout<Shape<Int<kBlockM>, Int<kPBlockN>>,
  127. Stride<Int<kPBlockN>, _1>>{}));
  128. using SmemLayoutPdS = decltype(tile_to_shape(
  129. SmemLayoutAtomPdS{},
  130. make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
  131. // We set stride to be multiple of 64 so that if ShuffleLSE, even if threads read from sLSE but out of bounds,
  132. // it's still a valid smem address.
  133. using SmemLayoutLSE = cute::Layout<cute::Shape<Int<kBlockM>, Int<kStages>>, cute::Stride<_1, Int<cute::round_up(kBlockM, 64)>>>;
  // Broadcast view of the LSE / dPsum buffer shaped like the S / dP accumulator tile: the kBlockN mode
  // has stride 0, so every column of a row maps to the same per-row statistic.
  134. using SmemLayoutLSEMma = std::conditional_t<
  135. SdP_swapAB,
  136. cute::Layout<cute::Shape<Int<kBlockN>, Int<kBlockM>, Int<kStages>>, cute::Stride<_0, _1, Int<cute::round_up(kBlockM, 64)>>>,
  137. cute::Layout<cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kStages>>, cute::Stride<_1, _0, Int<cute::round_up(kBlockM, 64)>>>
  138. >;
  139. // Note this is the transpose in terms of the view, not in terms of memory.
  // Each *t layout below composes the base layout with a layout whose first two modes are exchanged,
  // so it addresses the same storage with the coordinates swapped.
  140. using SmemLayoutQt =
  141. decltype(cute::composition(SmemLayoutQ{},
  142. make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages>{}),
  143. make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));
  144. using SmemLayoutdOt =
  145. decltype(cute::composition(SmemLayoutdO{},
  146. make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages_dO>{}),
  147. make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));
  148. using SmemLayoutKt =
  149. decltype(cute::composition(SmemLayoutK{},
  150. make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
  151. make_stride(Int<kBlockN>{}, _1{}))));
  152. using SmemLayoutPdSt =
  153. decltype(cute::composition(SmemLayoutPdS{},
  154. make_layout(make_shape(Int<kBlockN>{}, Int<kBlockM>{}),
  155. make_stride(Int<kBlockM>{}, _1{}))));
  156. // Thread layout, 256 or 384 threads per row
  // All NumMmaThreads threads form a single 1-D mode; each thread handles one ElementAccum value per copy.
  157. using R2SLayoutAtomdQaccum = Layout<Shape<Int<NumMmaThreads>>>;
  158. using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2SLayoutAtomdQaccum{},
  159. Layout<Shape < _1>>{})); // Val layout, 1 vals per store
  // Shared-memory -> register copy atoms (LDSM), in plain and transposed flavours.
  160. using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, Element>;
  161. using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, Element>;
  162. // For the case where the N dimension of MmaSdP is divisible by 8 but not by 16
  163. using SmemCopyAtomHalf = Copy_Atom<SM75_U32x2_LDSM_N, Element>;
  164. // For the case where the N dimension of MmadQ is divisible by 8 but not by 16
  165. using SmemCopyAtomTransposedHalf = Copy_Atom<SM75_U16x4_LDSM_T, Element>;
  166. // If !SdP_swapAB, the accum registers hold P / dS, otherwise they hold Pt / dSt.
  167. // If PdS_major is MN, then we need to "transpose" the write.
  168. // TODO: check this write
  169. using R2SCopyAtomPdS = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>;
  170. // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
  171. // from the same address by the same threadblock. This is slightly faster.
  // Has_cp_async is always true in this file, so the cp.async variant is the one selected.
  172. using GmemCopyStruct = std::conditional_t<
  173. Has_cp_async,
  174. SM80_CP_ASYNC_CACHEGLOBAL_ZFILL<cute::uint128_t>,
  175. AutoVectorizingCopyWithAssumedAlignment<128>
  176. >;
  177. using GmemCopyAtom = Copy_Atom<GmemCopyStruct, Element>;
  // Threads needed to cover one kBlockKGmem-wide row with 128-bit accesses.
  178. static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;
  179. static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow");
  // Thread arrangement for the Q/K/V/dO copies: (NumMmaThreads / kGmemThreadsPerRow) rows by
  // kGmemThreadsPerRow columns, with consecutive threads adjacent along the column mode.
  180. using GmemLayoutAtom = Layout<Shape <Int<NumMmaThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
  181. Stride<Int<kGmemThreadsPerRow>, _1>>;
  182. using GmemTiledCopyQKV = decltype(
  183. make_tiled_copy(GmemCopyAtom{},
  184. GmemLayoutAtom{},
  185. Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per read
  // Copy used for the float LSE and dPsum vectors: 1-D thread layout, 4 values per thread.
  186. using GmemCopyAtomLSE = Copy_Atom<GmemCopyStruct, float>;
  187. using GmemLayoutAtomLSE = Layout<Shape<Int<NumMmaThreads>>>;
  188. using GmemTiledCopyLSE = decltype(make_tiled_copy(GmemCopyAtomLSE{}, GmemLayoutAtomLSE{},
  189. Layout<Shape<_4>>{})); // Val layout, 4 vals per store
  190. // So that we don't have to check if we overshot kBlockM when we load Q
  191. // static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0);
  // Global-memory tensor descriptors. A `_1` stride entry marks the contiguous mode.
  192. using ShapeQKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen, d, head, batch)
  193. using StrideQKV = cute::Stride<int64_t, _1, int64_t, int64_t>;
  194. using ShapeLSE = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen, head, batch)
  195. using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen, head, batch)
  196. using ShapedQaccum = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_q * d, head, batch)
  197. using StridedQaccum = cute::Stride<_1, int64_t, int64_t>;
  198. // These are tuned for speed. They don't affect correctness.
  199. // We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64
  200. // this helps quite a bit to not have to do causal masking for most of the iterations.
  201. // For hdim 192, separating masking iterations results in register spills.
  202. // static constexpr bool SeparateMaskingIterations = kHeadDim <= 64;
  // Currently disabled for every head dimension; the commented-out line above is the
  // head-dimension-dependent alternative.
  203. static constexpr bool SeparateMaskingIterations = false;
  204. // Do we keep the LSE and dPsum in each thread, or split them across 8 threads that share them and then
  205. // shuffle to get the value whenever we need? This can reduce register pressure when SdP_swapAB, where each
  206. // thread needs to keep statistics for (kBlockM / 4) rows. If !SdP_swapAB, each thread only needs to keep
  207. // statistic for 2 rows.
  208. // static constexpr bool ShuffleLSE = SdP_swapAB && kHeadDim <= 64;
  209. // static constexpr bool ShuffledPsum = SdP_swapAB && kHeadDim <= 64;
  // Both shuffle options are hard-disabled by the trailing `&& false`, so they evaluate to false
  // regardless of SdP_swapAB.
  210. static constexpr bool ShuffleLSE = SdP_swapAB && false;
  211. static constexpr bool ShuffledPsum = SdP_swapAB && false;
  // V and Q share one shared-memory region exactly when V_in_regs is set.
  212. static constexpr bool Share_QV_Smem = V_in_regs;
  // Storage type for P: a zero-length array when Mma_dKV_is_RS, otherwise a buffer sized by SmemLayoutPdS.
  213. using SmemP_t = std::conditional_t<Mma_dKV_is_RS, cute::array<Element, 0>, cute::array_aligned<Element, cute::cosize_v<SmemLayoutPdS>>>;
  // Shared-memory layout selected when Share_QV_Smem is true. smem_v and smem_q overlap through the
  // anonymous union, so writing one overwrites the other. Member order defines the smem layout.
  214. struct TensorStorageSharedQV : cute::aligned_struct<128> {
  215. cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
  216. union {
  217. cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
  218. cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
  219. };
  220. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
  // LSE and dPsum buffers both use SmemLayoutLSE and are declared with a 128 alignment argument.
  221. cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutLSE>, 128> smem_lse;
  222. cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutLSE>, 128> smem_dpsum;
  223. SmemP_t smem_p;
  224. cute::array_aligned<Element, cute::cosize_v<SmemLayoutPdS>> smem_ds;
  225. };
  // Shared-memory layout selected when Share_QV_Smem is false: identical members, but smem_v and smem_q
  // are distinct buffers.
  226. struct TensorStorageSeparateQV : cute::aligned_struct<128> {
  227. cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
  228. cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
  229. cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
  230. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
  231. cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutLSE>, 128> smem_lse;
  232. cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutLSE>, 128> smem_dpsum;
  233. SmemP_t smem_p;
  234. cute::array_aligned<Element, cute::cosize_v<SmemLayoutPdS>> smem_ds;
  235. };
  // The storage type actually used by the mainloop.
  236. using TensorStorage = std::conditional_t<Share_QV_Smem, TensorStorageSharedQV, TensorStorageSeparateQV>;
  237. // Host side kernel arguments
  238. struct Arguments {
  239. Element const* const ptr_Q;
  240. ShapeQKV const shape_Q;
  241. StrideQKV const stride_Q;
  242. Element const* const ptr_K;
  243. ShapeQKV const shape_K;
  244. StrideQKV const stride_K;
  245. Element const* const ptr_V;
  246. StrideQKV const stride_V;
  247. Element const* const ptr_dO;
  248. StrideQKV const stride_dO;
  249. ElementAccum* const ptr_dQaccum;
  250. ShapedQaccum const shape_dQaccum;
  251. StridedQaccum const stride_dQaccum;
  252. float const* const ptr_LSE_log2;
  253. ShapeLSE const shape_LSE;
  254. StrideLSE const stride_LSE_log2;
  255. float const* const ptr_dPsum;
  256. StrideLSE const stride_dPsum;
  257. float const softmax_scale;
  258. int const window_size_left, window_size_right, sink_token_length;
  259. float const softcap_val;
  260. int const num_batch;
  261. int* const dq_semaphore;
  262. int const* const cu_seqlens_q = nullptr;
  263. int const* const cu_seqlens_k = nullptr;
  264. int const* const seqused_q = nullptr;
  265. int const* const seqused_k = nullptr;
  266. };
  267. // Device side kernel params
  268. struct Params {
  269. Element const* const ptr_Q;
  270. ShapeQKV const shape_Q;
  271. StrideQKV const stride_Q;
  272. Element const* const ptr_K;
  273. ShapeQKV const shape_K;
  274. StrideQKV const stride_K;
  275. Element const* const ptr_V;
  276. StrideQKV const stride_V;
  277. Element const* const ptr_dO;
  278. StrideQKV const stride_dO;
  279. ElementAccum* const ptr_dQaccum;
  280. ShapedQaccum const shape_dQaccum;
  281. StridedQaccum stride_dQaccum;
  282. cutlass::FastDivmod qhead_per_khead_divmod;
  283. float const* const ptr_LSE_log2;
  284. ShapeLSE const shape_LSE;
  285. StrideLSE const stride_LSE_log2;
  286. float const* const ptr_dPsum;
  287. StrideLSE const stride_dPsum;
  288. float const softmax_scale, softmax_scale_log2;
  289. int const window_size_left, window_size_right, sink_token_length;
  290. float const softcap_val;
  291. int const num_batch;
  292. int *const dq_semaphore;
  293. int const *const cu_seqlens_q = nullptr;
  294. int const *const cu_seqlens_k = nullptr;
  295. int const *const seqused_q = nullptr;
  296. int const *const seqused_k = nullptr;
  297. };
  298. static Params
  299. to_underlying_arguments(Arguments const& args) {
  300. if constexpr (Deterministic) { assert(args.dq_semaphore != nullptr); }
  301. // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val.
  302. // Right after this, we multiply by log2(e) before applying exp2.
  303. // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val
  304. // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e)
  305. // (assigning it to params.softmax_scale_log2).
  306. // In the backward, we need to multiply by
  307. // (1 - tanh^2) * softmax_scale / softcap_val * softcap_val = (1 - tanh^2) * softmax_scale.
  308. // Instead we multiply by (1 - tanh^2) and multiply dK and dV by params.softmax_scale
  309. // (the original softmax_scale) at the end.
  310. return {args.ptr_Q, args.shape_Q, args.stride_Q,
  311. args.ptr_K, args.shape_K, args.stride_K,
  312. args.ptr_V, args.stride_V,
  313. args.ptr_dO, args.stride_dO,
  314. args.ptr_dQaccum, args.shape_dQaccum, args.stride_dQaccum,
  315. cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))),
  316. args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum,
  317. args.softmax_scale,
  318. !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E),
  319. args.window_size_left, args.window_size_right, args.sink_token_length,
  320. !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val,
  321. args.num_batch, args.dq_semaphore,
  322. args.cu_seqlens_q, args.cu_seqlens_k, args.seqused_q, args.seqused_k};
  323. }
  324. CUTLASS_DEVICE
  325. cute::tuple<int, int> get_m_block_min_max(Params const& params, SeqlenInfo_t const& seqlen_info,
  326. int n_block, int bidb) {
  327. static constexpr int kBlockM = get<0>(TileShape_MNK{});
  328. int const seqlen_q = seqlen_info.seqlen_q;
  329. int const seqlen_k = seqlen_info.seqlen_k;
  330. int m_block_max = cute::ceil_div(seqlen_q, kBlockM);
  331. if constexpr (Is_local) {
  332. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  333. if (n_block >= cute::ceil_div(params.sink_token_length, kBlockN)) {
  334. m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + params.window_size_left, kBlockM));
  335. }
  336. }
  337. int m_block_min = 0;
  338. if constexpr (Is_causal || Is_local) {
  339. m_block_min = std::max(m_block_min, (n_block * kBlockN + seqlen_q - seqlen_k - params.window_size_right) / kBlockM);
  340. }
  341. return {m_block_min, m_block_max};
  342. }
  343. template <typename SharedStorage, typename FrgTensordKV>
  344. CUTLASS_DEVICE bool
  345. mma(Params const& params,
  346. FrgTensordKV& tdKrdK,
  347. FrgTensordKV& tdVrdV,
  348. int thread_idx,
  349. cute::tuple<int32_t, int32_t, int32_t> block_coord,
  350. SharedStorage& shared_storage
  351. ) {
  352. static_assert(is_rmem<FrgTensordKV>::value, "dK and dV tensor must be rmem resident.");
  353. int n_block = get<0>(block_coord);
  354. int bidh = get<1>(block_coord);
  355. int bidb = get<2>(block_coord);
  356. SeqlenInfo_t seqlen_info{
  357. bidb, get<0>(params.shape_Q), size<0>(params.shape_K),
  358. params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k
  359. };
  360. auto m_block_min_max = get_m_block_min_max(params, seqlen_info, n_block, bidb);
  361. int const m_block_min = get<0>(m_block_min_max);
  362. int const m_block_max = get<1>(m_block_min_max);
  363. // It's possible to have m_block_max <= m_block_min. Exit early
  364. if constexpr (Is_causal || Is_local || Varlen) {
  365. if (m_block_max <= m_block_min) { return false; }
  366. }
  367. Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{});
  368. Tensor sdO = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdO{});
  369. Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{});
  370. Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{});
  371. Tensor sQt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQt{});
  372. Tensor sdOt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdOt{});
  373. Tensor sKt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutKt{});
  374. Tensor sP = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutPdS{});
  375. Tensor sPt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutPdSt{});
  376. Tensor sdS = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_ds.data()), SmemLayoutPdS{});
  377. Tensor sdSt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_ds.data()), SmemLayoutPdSt{});
  378. Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_lse.data()), SmemLayoutLSE{});
  379. Tensor sdPsum = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dpsum.data()), SmemLayoutLSE{});
  380. Tensor sLSEMma = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_lse.data()), SmemLayoutLSEMma{});
  381. Tensor sdPsumMma = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dpsum.data()), SmemLayoutLSEMma{});
  382. bool const is_varlen_q = Varlen && params.cu_seqlens_q;
  383. bool const is_varlen_k = Varlen && params.cu_seqlens_k;
  384. int bidh_kv = params.qhead_per_khead_divmod.divide(bidh);
  385. Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q), params.shape_Q, params.stride_Q)(_, _, bidh, !is_varlen_q ? bidb : 0);
  386. Tensor mdO = make_tensor(make_gmem_ptr(params.ptr_dO), params.shape_Q, params.stride_dO)(_, _, bidh, !is_varlen_q ? bidb : 0);
  387. Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0);
  388. Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), params.shape_K, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb : 0);
  389. Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_LSE, params.stride_LSE_log2)(_, bidh, !is_varlen_q ? bidb : 0);
  390. Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_LSE, params.stride_dPsum)(_, bidh, !is_varlen_q ? bidb : 0);
  391. Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.ptr_dQaccum)),
  392. params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen_q ? bidb : 0);
  393. Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _)
  394. Tensor gdO = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mdO), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _)
  395. Tensor gK = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K)
  396. Tensor gV = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K)
  397. Tensor gLSE = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded), mLSE), select<0>(TileShape_MNK{}), make_coord(_)); // (M, _)
  398. Tensor gdPsum = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded), mdPsum), select<0>(TileShape_MNK{}), make_coord(_)); // (M, _)
  399. Tensor gdQaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(_)); // (M * K, _)
  400. GmemTiledCopyQKV gmem_tiled_copy_QKV;
  401. auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(thread_idx);
  402. auto gmem_thr0_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(_0{}); // For index calculation
  403. GmemTiledCopyLSE gmem_tiled_copy_lse;
  404. auto gmem_thr_copy_lse = gmem_tiled_copy_lse.get_thread_slice(thread_idx);
  405. R2STiledCopydQaccum r2s_tiled_copy_dQaccum;
  406. auto r2s_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_thread_slice(thread_idx);
  407. Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
  408. Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
  409. Tensor tdOgdO = gmem_thr_copy_QKV.partition_S(gdO);
  410. Tensor tdOsdO = gmem_thr_copy_QKV.partition_D(sdO);
  411. Tensor tLSEgLSE = gmem_thr_copy_lse.partition_S(gLSE);
  412. Tensor tLSEsLSE = gmem_thr_copy_lse.partition_D(sLSE);
  413. Tensor tLSEgdPsum = gmem_thr_copy_lse.partition_S(gdPsum);
  414. Tensor tLSEsdPsum = gmem_thr_copy_lse.partition_D(sdPsum);
  415. // We can reuse r2s_thr_copy_dQaccum for this partitioning
  416. Tensor tdQgdQaccum = r2s_thr_copy_dQaccum.partition_D(gdQaccum);
  417. // if (blockIdx.x == 0 && threadIdx.x == 128) { print(mdQaccum); printf("\n"); print(gdQaccum_); printf("\n"); print(gdQaccum); printf("\n"); print(tdQgdQaccum); printf("\n"); }
  418. TiledMmaSdP tiled_mma_SdP;
  419. TiledMmadKV tiled_mma_dKV;
  420. TiledMmadQ tiled_mma_dQ;
  421. auto thr_mma_SdP = tiled_mma_SdP.get_thread_slice(thread_idx);
  422. auto thr_mma_dKV = tiled_mma_dKV.get_thread_slice(thread_idx);
  423. auto thr_mma_dQ = tiled_mma_dQ.get_thread_slice(thread_idx);
  424. // Allocate "fragments/descriptors"
  425. // We have to use the templated mma_partition_fragment_AB instead of cute::conditional_return or lambda,
  426. // because some partition_fragment_A/B don't compile.
  427. // https://stackoverflow.com/questions/50051473/if-constexpr-in-c17-does-not-work-in-a-non-templated-function
  428. Tensor tdPrV = mma_partition_fragment_AB</*A=*/SdP_swapAB>(thr_mma_SdP, sV);
  429. // Copy Atom retiling
  430. auto smem_copy_atom_SdP_B = cute::conditional_return<MmaSdPEvenN>(SmemCopyAtom{}, SmemCopyAtomHalf{});
  431. auto smem_tiled_copy_QdO = cute::conditional_return<!SdP_swapAB>(make_tiled_copy_A(SmemCopyAtom{}, tiled_mma_SdP), make_tiled_copy_B(smem_copy_atom_SdP_B, tiled_mma_SdP));
  432. auto smem_thr_copy_QdO = smem_tiled_copy_QdO.get_thread_slice(thread_idx);
  433. Tensor tSsQ = smem_thr_copy_QdO.partition_S(sQ);
  434. Tensor tdPsdO = smem_thr_copy_QdO.partition_S(sdO);
  435. auto smem_tiled_copy_KV = cute::conditional_return<!SdP_swapAB>(make_tiled_copy_B(smem_copy_atom_SdP_B, tiled_mma_SdP), make_tiled_copy_A(SmemCopyAtom{}, tiled_mma_SdP));
  436. auto smem_thr_copy_KV = smem_tiled_copy_KV.get_thread_slice(thread_idx);
  437. Tensor tSsK = smem_thr_copy_KV.partition_S(sK);
  438. Tensor tdPsV = smem_thr_copy_KV.partition_S(sV);
  439. auto r2s_tiled_copy_PdS = make_tiled_copy_C(R2SCopyAtomPdS{}, tiled_mma_SdP);
  440. auto r2s_thr_copy_PdS = r2s_tiled_copy_PdS.get_thread_slice(thread_idx);
  441. Tensor tPsP = r2s_thr_copy_PdS.partition_D(cute::conditional_return<!SdP_swapAB>(sP, sPt)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
  442. Tensor tdSsdS = r2s_thr_copy_PdS.partition_D(cute::conditional_return<!SdP_swapAB>(sdS, sdSt)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
  443. // if (blockIdx.x == 0 && threadIdx.x == 128) { print(r2s_thr_copy_PdS); print(sP); printf("\n"); print(sPt); printf("\n"); print(tPsP); printf("\n"); print(tdSsdS); printf("\n"); }
  444. auto smem_copy_atom_dKV_B = cute::conditional_return<MmadKVEvenN>(SmemCopyAtomTransposed{}, SmemCopyAtomTransposedHalf{});
  445. auto smem_tiled_copy_PdSt = cute::conditional_return<!dKV_swapAB>(make_tiled_copy_A(SmemCopyAtomTransposed{}, tiled_mma_dKV), make_tiled_copy_B(smem_copy_atom_dKV_B, tiled_mma_dKV));
  446. auto smem_thr_copy_PdSt = smem_tiled_copy_PdSt.get_thread_slice(thread_idx);
  447. Tensor tdVsPt = smem_thr_copy_PdSt.partition_S(sPt);
  448. Tensor tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt);
  449. auto smem_tiled_copy_QdOt = cute::conditional_return<!dKV_swapAB>(make_tiled_copy_B(smem_copy_atom_dKV_B, tiled_mma_dKV), make_tiled_copy_A(SmemCopyAtomTransposed{}, tiled_mma_dKV));
  450. auto smem_thr_copy_QdOt = smem_tiled_copy_QdOt.get_thread_slice(thread_idx);
  451. Tensor tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt);
  452. Tensor tdKsQt = smem_thr_copy_QdOt.partition_S(sQt);
  453. auto smem_tiled_copy_dS = cute::conditional_return<!dQ_swapAB>(
  454. make_tiled_copy_A(SmemCopyAtom{}, tiled_mma_dQ),
  455. make_tiled_copy_B(cute::conditional_return<MmadQEvenN>(SmemCopyAtom{}, SmemCopyAtomHalf{}), tiled_mma_dQ));
  456. auto smem_thr_copy_dS = smem_tiled_copy_dS.get_thread_slice(thread_idx);
  457. Tensor tdQsdS = smem_thr_copy_dS.partition_S(sdS);
  458. auto smem_tiled_copy_Kt = cute::conditional_return<!dQ_swapAB>(
  459. make_tiled_copy_B(cute::conditional_return<MmadQEvenN>(SmemCopyAtomTransposed{}, SmemCopyAtomTransposedHalf{}), tiled_mma_dQ),
  460. make_tiled_copy_A(SmemCopyAtomTransposed{}, tiled_mma_dQ));
  461. auto smem_thr_copy_Kt = smem_tiled_copy_Kt.get_thread_slice(thread_idx);
  462. Tensor tdQsKt = smem_thr_copy_Kt.partition_S(sKt);
  463. // thr_mma_SdP.partition_C(sLSEMma) has shape (MMA=4, MMA_M, MMA_N, PIPE), we only take the col indices
  464. // or row indices, depending on whether SdP_swapAB.
  465. Tensor tSsLSEMma = logical_divide(thr_mma_SdP.partition_C(sLSEMma), Shape<_2>{}); // (2, 2, MMA_M, MMA_N, PIPE)
  466. Tensor tSsLSE = group_modes<0, 2>(cute::conditional_return<!SdP_swapAB>(
  467. tSsLSEMma(make_coord(_0{}, _), _, _0{}, _), // (2, MMA_M, PIPE)
  468. tSsLSEMma(make_coord(_, _0{}), _0{}, _, _))); // (2, MMA_N, PIPE)
  469. Tensor tSsdPsumMma = logical_divide(thr_mma_SdP.partition_C(sdPsumMma), Shape<_2>{});
  470. Tensor tSsdPsum = group_modes<0, 2>(cute::conditional_return<!SdP_swapAB>(
  471. tSsdPsumMma(make_coord(_0{}, _), _, _0{}, _), // (2, MMA_M, PIPE)
  472. tSsdPsumMma(make_coord(_, _0{}), _0{}, _, _))); // (2, MMA_N, PIPE)
  473. // if (blockIdx.x == 0 && threadIdx.x == 128) { print(sLSEMma); printf("\n"); print(tLSEsLSE); printf("\n"); }
  474. // If we want to split the stats among the 8 threads that share the same rows.
  475. static constexpr int kStatsPerThread = cute::ceil_div(decltype(size(tSsLSE))::value, 8);
  476. // Predicates
  477. Tensor cQ = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}));
  478. Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ);
  479. Tensor t0QcQ = gmem_thr0_copy_QKV.partition_S(cQ);
  480. Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
  481. #pragma unroll
  482. for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(_0{}, _0{}, k)) < get<1>(params.shape_Q); }
  483. Tensor cLSE = cute::make_identity_tensor(select<0>(TileShape_MNK{}));
  484. Tensor tLSEcLSE = gmem_thr_copy_lse.partition_S(cLSE);
  485. int const seqlen_q = seqlen_info.seqlen_q;
  486. int const seqlen_k = seqlen_info.seqlen_k;
  487. flash::Mask<kBlockM, kBlockN, false /*PackGQA*/, TiledMmaSdP, SdP_swapAB> mask(
  488. thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, params.sink_token_length,
  489. params.qhead_per_khead_divmod
  490. );
  491. {
  492. Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN)
  493. Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
  494. Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN)
  495. Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
  496. // Predicates
  497. Tensor cKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{}));
  498. Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV);
  499. Tensor t0KVcKV = gmem_thr0_copy_QKV.partition_S(cKV);
  500. Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
  501. #pragma unroll
  502. for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(_0{}, _0{}, k)) < get<1>(params.shape_K); }
  503. // Do we need bound check to make sure the row doesn't go above kBlockN
  504. static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0;
  505. // static_assert(EvenN); // It simplifies the loading of K and V
  506. // Instead of passing in tKVcKV, we pass in t0KVcKV and subtract the offset from the limit
  507. // (seqlen_k - n_block * kBlockN). This is because the entries of t0KVcKV are known at compile time.
  508. // int const seqlenk_row_limit = -int(get<0>(tKVcKV(_0{}, _0{}, _0{}))) + (EvenN
  509. // ? seqlen_info.seqlen_k - n_block * kBlockN
  510. // : std::min(seqlen_info.seqlen_k - n_block * kBlockN, kBlockN));
  511. // // Need Clear_OOB_MN to be true here since the gemm will sum over the kBlockN dimension
  512. // flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true, /*Clear_OOB_K=*/true>(
  513. // gmem_tiled_copy_QKV, tVgV, tVsV, t0KVcKV, tKVpKV, seqlenk_row_limit);
  514. int const seqlenk_row_limit = seqlen_k - n_block * kBlockN - get<0>(tKVcKV(_0{}, _0{}, _0{}));
  515. #pragma unroll
  516. for (int m = 0; m < size<1>(tVsV); ++m) {
  517. // If kBlockN doesn't evenly divide the tiled copy, only the last `m` needs to be checked
  518. if (EvenN || m < size<1>(tVsV) - 1 || get<0>(tKVcKV(_0{}, m, _0{})) < kBlockN) {
  519. bool const predicate_n = get<0>(t0KVcKV(_0{}, m, _0{})) < seqlenk_row_limit;
  520. #pragma unroll
  521. for (int k = 0; k < size<2>(tVsV); ++k) {
  522. cute::copy(gmem_tiled_copy_QKV.with(tKVpKV(k) && predicate_n), tVgV(_, m, k), tVsV(_, m, k));
  523. }
  524. }
  525. }
  526. if constexpr (V_in_regs) { flash::cp_async_fence(); }
  527. // flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true, /*Clear_OOB_K=*/true>(
  528. // gmem_tiled_copy_QKV, tKgK, tKsK, t0KVcKV, tKVpKV, seqlenk_row_limit);
  529. #pragma unroll
  530. for (int m = 0; m < size<1>(tKsK); ++m) {
  531. if (EvenN || m < size<1>(tKsK) - 1 || get<0>(tKVcKV(_0{}, m, _0{})) < kBlockN) {
  532. bool const predicate_n = get<0>(t0KVcKV(_0{}, m, _0{})) < seqlenk_row_limit;
  533. #pragma unroll
  534. for (int k = 0; k < size<2>(tKsK); ++k) {
  535. cute::copy(gmem_tiled_copy_QKV.with(tKVpKV(k) && predicate_n), tKgK(_, m, k), tKsK(_, m, k));
  536. }
  537. }
  538. }
  539. flash::cp_async_fence();
  540. }
  541. if constexpr (V_in_regs) {
  542. flash::cp_async_wait<1>();
  543. __syncthreads();
  544. Tensor tdPrV_copy_view = smem_thr_copy_KV.retile_D(tdPrV);
  545. Tensor tdPsV_copy_view = smem_thr_copy_KV.partition_S(sV);
  546. cute::copy(smem_tiled_copy_KV, tdPsV_copy_view, tdPrV_copy_view);
  547. __syncthreads(); // Sync to avoid loading Q to smem_q, which overlaps with smem_v
  548. }
  549. // Do we need bound check to make sure the row doesn't go above kBlockM
  550. static constexpr int kBlockM = get<0>(TileShape_MNK{});
  551. static constexpr bool EvenM = kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0;
  // Copies one Q tile and the matching LSE slice from global memory into a shared-memory pipeline stage.
  //   m_block          tile index along the query-sequence dimension (selects tQgQ / tLSEgLSE)
  //   smem_pipe_write  pipeline stage to fill (last mode of tQsQ / tLSEsLSE)
  // No cp_async_fence is issued in here; every call site issues the fence itself after the call.
  552. auto load_Q_LSE = [&] (int const m_block, int const smem_pipe_write) {
  553. // if (cute::thread0()) { printf("Inside load_Q_LSE, m_block = %d, smem_pipe_write = %d\n", m_block, smem_pipe_write); }
  554. Tensor tQsQ_cur = tQsQ(_, _, _, smem_pipe_write);
  555. Tensor tQgQ_cur = tQgQ(_, _, _, m_block);
  556. // Instead of passing in tQcQ, we pass in t0QcQ and subtract the offset from the limit
  557. // (seqlen_q - m_block * kBlockM). This is because the entries of t0QcQ are known at compile time.
  558. // int const seqlenq_row_limit = -int(get<0>(tQcQ(_0{}, _0{}, _0{}))) + (EvenM
  559. // ? seqlen_info.seqlen_q - m_block * kBlockM
  560. // : std::min(seqlen_info.seqlen_q - m_block * kBlockM, kBlockM));
  561. // Need Clear_OOB_MN to be true here since the gemm will sum over the kBlockM dimension
  562. // flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true, /*Clear_OOB_K=*/true>(
  563. // gmem_tiled_copy_QKV, tQgQ(_, _, _, m_block), tQsQ_cur, t0QcQ, tQpQ, seqlenq_row_limit);
  // Rows remaining in the sequence for this tile, shifted by this thread's first row coordinate so that
  // the thread-0 coordinates (t0QcQ) can be compared against it below.
  564. int const seqlenq_row_limit = seqlen_info.seqlen_q - m_block * kBlockM - get<0>(tQcQ(_0{}, _0{}, _0{}));
  565. #pragma unroll
  566. for (int m = 0; m < size<1>(tQsQ); ++m) {
  567. // If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked
  568. if (EvenM || m < size<1>(tQsQ) - 1 || get<0>(tQcQ(_0{}, m, _0{})) < kBlockM) {
  // Row predicate: false for rows at or beyond the end of the query sequence.
  569. bool const predicate_m = get<0>(t0QcQ(_0{}, m, _0{})) < seqlenq_row_limit;
  570. #pragma unroll
  571. for (int k = 0; k < size<2>(tQsQ); ++k) {
  // tQpQ(k) is the column predicate computed once from get<1>(params.shape_Q); both must hold to copy.
  572. cute::copy(gmem_tiled_copy_QKV.with(tQpQ(k) && predicate_m), tQgQ_cur(_, m, k), tQsQ_cur(_, m, k));
  573. }
  574. }
  575. }
  576. Tensor tLSEgLSE_cur = tLSEgLSE(_, _, m_block);
  577. Tensor tLSEsLSE_cur = tLSEsLSE(_, _, smem_pipe_write);
  578. // We made sure LSE length is padded so we read `kBlockM` elements so that all
  579. // elements in sLSE are filled. Without this we might have uninitialized sLSE values.
  580. #pragma unroll
  581. for (int m = 0; m < size<1>(tLSEsLSE); ++m) {
  // Only the tile-local bound (kBlockM) is checked here; there is no sequence-length predicate.
  582. if (get<0>(tLSEcLSE(_0{}, m)) < kBlockM) {
  583. cute::copy(gmem_tiled_copy_lse, tLSEgLSE_cur(_, m), tLSEsLSE_cur(_, m));
  584. }
  585. }
  586. };
  // Copies one dO tile and the matching dPsum slice from global memory into a shared-memory pipeline stage.
  //   m_block          tile index along the query-sequence dimension (selects tdOgdO / tLSEgdPsum)
  //   smem_pipe_write  pipeline stage to fill (last mode of tdOsdO / tLSEsdPsum)
  // Reuses the Q coordinate tensors (tQcQ, t0QcQ) and the Q column predicate (tQpQ) for dO.
  // No cp_async_fence is issued in here; every call site issues the fence itself after the call.
  587. auto load_dO_dPsum = [&] (int const m_block, int const smem_pipe_write) {
  588. // if (cute::thread0()) { printf("Inside load_dO_dPsum, m_block = %d, smem_pipe_write = %d\n", m_block, smem_pipe_write); }
  589. Tensor tdOsdO_cur = tdOsdO(_, _, _, smem_pipe_write);
  590. Tensor tdOgdO_cur = tdOgdO(_, _, _, m_block);
  591. // int const seqlenq_row_limit = -int(get<0>(tQcQ(_0{}, _0{}, _0{}))) + (EvenM
  592. // ? seqlen_info.seqlen_q - m_block * kBlockM
  593. // : std::min(seqlen_info.seqlen_q - m_block * kBlockM, kBlockM));
  594. // flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true, /*Clear_OOB_K=*/true>(
  595. // gmem_tiled_copy_QKV, tdOgdO(_, _, _, m_block), tdOsdO_cur, t0QcQ, tQpQ, seqlenq_row_limit);
  // Rows remaining in the sequence for this tile, shifted by this thread's first row coordinate so that
  // the thread-0 coordinates (t0QcQ) can be compared against it below.
  596. int const seqlenq_row_limit = seqlen_info.seqlen_q - m_block * kBlockM - get<0>(tQcQ(_0{}, _0{}, _0{}));
  597. #pragma unroll
  598. for (int m = 0; m < size<1>(tdOsdO); ++m) {
  599. // If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked
  600. if (EvenM || m < size<1>(tdOsdO) - 1 || get<0>(tQcQ(_0{}, m, _0{})) < kBlockM) {
  // Row predicate: false for rows at or beyond the end of the query sequence.
  601. bool const predicate_m = get<0>(t0QcQ(_0{}, m, _0{})) < seqlenq_row_limit;
  602. #pragma unroll
  603. for (int k = 0; k < size<2>(tdOsdO); ++k) {
  604. cute::copy(gmem_tiled_copy_QKV.with(tQpQ(k) && predicate_m), tdOgdO_cur(_, m, k), tdOsdO_cur(_, m, k));
  605. }
  606. }
  607. }
  608. Tensor tLSEgdPsum_cur = tLSEgdPsum(_, _, m_block);
  609. Tensor tLSEsdPsum_cur = tLSEsdPsum(_, _, smem_pipe_write);
  610. #pragma unroll
  611. for (int m = 0; m < size<1>(tLSEsdPsum); ++m) {
  // Same tile-local bound as the LSE copy: only coordinates below kBlockM are copied.
  612. if (get<0>(tLSEcLSE(_0{}, m)) < kBlockM) {
  613. cute::copy(gmem_tiled_copy_lse, tLSEgdPsum_cur(_, m), tLSEsdPsum_cur(_, m));
  614. }
  615. }
  616. };
  617. int m_block = m_block_min;
  618. // Note, using the for_each() function here to ensure `stage` is of type Int<x>.
  619. for_each(make_int_sequence<kStages>{}, [&] (auto stage) {
  620. static constexpr bool Is_first_stage = CUTE_STATIC_V(stage) == 0;
  621. static constexpr bool Is_last_stage = CUTE_STATIC_V(stage) == kStages - 1;
  622. if constexpr (!Is_last_stage || kStages == 1) {
  623. if (Is_first_stage || m_block + stage < m_block_max) {
  624. load_Q_LSE(m_block + stage, stage);
  625. }
  626. }
  627. // We want the fence outside the if statement to have a fixed number of cp.async commits.
  628. // so that we can wait with the correct number of outstanding commits.
  629. cute::cp_async_fence();
  630. if constexpr (stage < kStages_dO) {
  631. if (Is_first_stage || m_block + stage < m_block_max) {
  632. load_dO_dPsum(m_block + stage, stage);
  633. }
  634. cute::cp_async_fence();
  635. }
  636. });
  637. int smem_pipe_read = 0, smem_pipe_read_do = 0, smem_pipe_write = kStages - 1, smem_pipe_write_do = 0;
  638. auto load_Q_next = [&] {
  639. // if (cute::thread0()) { printf("m_block = %d, m_block_max = %d, smem_pipe_write = %d\n", m_block, m_block_max, smem_pipe_write); }
  640. if (m_block + (kStages > 1 ? kStages - 1 : 1) < m_block_max) {
  641. load_Q_LSE(m_block + (kStages > 1 ? kStages - 1 : 1), kStages > 1 ? smem_pipe_write : 0);
  642. }
  643. cute::cp_async_fence();
  644. };
  645. auto load_dO_next = [&] {
  646. // int smem_pipe_write_do_cur = Q_dO_same_stages ? smem_pipe_write : smem_pipe_write_do;
  647. if (m_block + kStages_dO < m_block_max) {
  648. // load_dO_dPsum(m_block + kStages_dO, kStages_dO > 1 ? smem_pipe_write_do_cur : 0);
  649. load_dO_dPsum(m_block + kStages_dO, kStages_dO > 1 ? smem_pipe_write_do : 0);
  650. }
  651. cute::cp_async_fence();
  652. };
  653. clear(tdKrdK);
  654. clear(tdVrdV);
  // One iteration of the backward main loop for query block `m_block` against the K/V tile held in sK / sV.
  // Reads the Q/LSE stage `smem_pipe_read` and the dO/dPsum stage `smem_pipe_read_do_cur`, accumulates into
  // tdVrdV and tdKrdK, atomically adds this block's dQ tile into tdQgdQaccum(_, _, m_block), and finally
  // advances the four pipeline stage indices.
  //   m_block  query-block index. It shadows the enclosing `m_block`; the load_Q_next / load_dO_next hooks
  //            read the enclosing variable, and every call site passes that same variable as the argument.
  //   mask_fn  callable invoked as mask_fn(tSrS, m_block) on the score accumulator before exponentiation.
  655. auto bwd_step = [&](int m_block, auto mask_fn) {
  // Accumulator for the GEMM of the Q tile with the K tile.
  656. Tensor tSrS = partition_fragment_C(tiled_mma_SdP, select<!SdP_swapAB ? 0 : 1, !SdP_swapAB ? 1 : 0>(TileShape_MNK{}));
  657. clear(tSrS);
  // Wait for the outstanding Q/LSE copy before reading it from shared memory.
  // NOTE(review): cp_async_wait<N> presumably leaves at most N committed groups in flight - confirm.
  658. flash::cp_async_wait<(kStages > 1) ? 1 : 0>();
  659. __syncthreads();
  660. Tensor tSrQ = mma_partition_fragment_AB</*A=*/!SdP_swapAB>(thr_mma_SdP, sQ(_, _, _0{}));
  661. Tensor tSrK = mma_partition_fragment_AB</*A=*/SdP_swapAB>(thr_mma_SdP, sK);
  662. // if (cute::thread0()) { print(tiled_mma_SdP); print(tSrS); printf("\n"); print(tSrQ); printf("\n"); print(tSrK); printf("\n"); print(tSsQ); printf("\n"); print(tSsK); printf("\n"); }
  663. flash::gemm_sm80<false /*A_in_regs*/, false /*B_in_regs*/, SdP_swapAB>(
  664. tSrS, tSrQ, tSrK, tSsQ(_, _, _, kStages > 1 ? smem_pipe_read : 0), tSsK,
  665. tiled_mma_SdP, smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV, nullptr /*hook*/);
  // LSE values for this thread: either the full per-thread fragment read from shared memory, or (ShuffleLSE)
  // only kStatsPerThread values, with the rest obtained from other lanes by __shfl_sync further down.
  666. Tensor tLSErLSE = cute::conditional_return<!ShuffleLSE>(make_fragment_like(tSsLSE(_, _0{})), make_tensor<ElementAccum>(Int<kStatsPerThread>{}));
  667. if constexpr (!ShuffleLSE) {
  668. cute::copy(tSsLSE(_, kStages > 1 ? smem_pipe_read : 0), tLSErLSE);
  669. } else {
  670. #pragma unroll
  671. for (int i = 0; i < kStatsPerThread; ++i) {
  672. // It's ok to read OOB, since we made sure sLSE is large enough and we won't use the OOB values
  673. tLSErLSE(i) = tSsLSE((thread_idx % 32) / 4 + i * 8, kStages > 1 ? smem_pipe_read : 0);
  674. }
  675. }
  676. if constexpr (Has_softcap) { flash::apply_softcap(tSrS, params.softcap_val); }
  677. // Reshape tSrS from (4, MMA_N, MMA_M) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
  // `scores` aliases the storage of tSrS; writes through either name are visible through the other.
  678. Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol</*Transposed=*/SdP_swapAB>(tSrS.layout()));
  679. // dtanh needs to happen before masking, otherwise we get 1 - (-inf)^2 = NaN in the dtanh
  680. // if (cute::thread0()) { print_tensor(scores); }
  681. auto dtanh = [&] { if constexpr (Has_softcap) return flash::calculate_dtanh(scores); else return nullptr; }();
  682. mask_fn(tSrS, m_block);
  // Replace every score with exp2f(score * softmax_scale_log2 - lse) for its row.
  683. #pragma unroll
  684. for (int mi = 0; mi < size<0>(scores); ++mi) {
  685. float const lse_scaled = [&] {
  686. if constexpr (!ShuffleLSE) return tLSErLSE(mi);
  687. else return __shfl_sync(0xffffffff, tLSErLSE(mi / 8), (mi % 8) * 4 + (thread_idx % 4));
  688. }();
  689. #pragma unroll
  690. for (int ni = 0; ni < size<1>(scores); ++ni) {
  691. scores(mi, ni) = exp2f(scores(mi, ni) * params.softmax_scale_log2 - lse_scaled);
  692. }
  693. }
  // Accumulator for the GEMM of the dO tile with the V tile.
  694. Tensor tdPrdP = partition_fragment_C(tiled_mma_SdP, select<!SdP_swapAB ? 0 : 1, !SdP_swapAB ? 1 : 0>(TileShape_MNK{}));
  695. clear(tdPrdP);
  // When Q and dO share the same stage count, the Q read index is reused for dO.
  696. int smem_pipe_read_do_cur = Q_dO_same_stages ? smem_pipe_read : smem_pipe_read_do;
  // Wait for the outstanding dO/dPsum copy before reading it from shared memory.
  697. flash::cp_async_wait<(kStages_dO > 1) ? 1 : 0>();
  698. __syncthreads();
  // With kStages > 1 the next Q/LSE load is handed to this GEMM as its hook; otherwise no hook is passed.
  699. auto hook = cute::conditional_return<(kStages > 1)>(load_Q_next, nullptr);
  700. Tensor tdPrdO = mma_partition_fragment_AB</*A=*/!SdP_swapAB>(thr_mma_SdP, sdO(_, _, _0{}));
  701. Tensor tdPrV_cur = cute::conditional_return<V_in_regs>(tdPrV, mma_partition_fragment_AB</*A=*/SdP_swapAB>(thr_mma_SdP, sV));
  702. flash::gemm_sm80<false /*A_in_regs*/, V_in_regs, SdP_swapAB>(
  703. tdPrdP, tdPrdO, tdPrV_cur, tdPsdO(_, _, _, kStages_dO > 1 ? smem_pipe_read_do_cur : 0), tdPsV,
  704. tiled_mma_SdP, smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV, hook);
  // dPsum values for this thread, loaded with the same direct-or-shuffled scheme as LSE above.
  705. Tensor tLSErdPsum = cute::conditional_return<!ShuffledPsum>(make_fragment_like(tSsdPsum(_, _0{})), make_tensor<ElementAccum>(Int<kStatsPerThread>{}));
  706. if constexpr (!ShuffledPsum) {
  707. cute::copy(tSsdPsum(_, kStages_dO > 1 ? smem_pipe_read_do_cur : 0), tLSErdPsum);
  708. } else {
  709. #pragma unroll
  710. for (int i = 0; i < kStatsPerThread; ++i) {
  711. tLSErdPsum(i) = tSsdPsum((thread_idx % 32) / 4 + i * 8, kStages_dO > 1 ? smem_pipe_read_do_cur : 0);
  712. }
  713. }
  714. // Reshape tdPrdP from (4, MMA_N, MMA_M) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
  715. Tensor dS = make_tensor(tdPrdP.data(), scores.layout());
  // In place: dS = scores * (dP - dPsum), additionally multiplied by dtanh when Has_softcap.
  716. #pragma unroll
  717. for (int mi = 0; mi < size<0>(dS); ++mi) {
  718. float const dP_sum_cur = [&] {
  719. if constexpr (!ShuffledPsum) return tLSErdPsum(mi);
  720. else return __shfl_sync(0xffffffff, tLSErdPsum(mi / 8), (mi % 8) * 4 + (thread_idx % 4));
  721. }();
  722. #pragma unroll
  723. for (int ni = 0; ni < size<1>(dS); ++ni) {
  724. dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum_cur);
  725. if constexpr (Has_softcap) { dS(mi, ni) *= dtanh(mi, ni); }
  726. }
  727. }
  728. // if (cute::thread0()) { print_tensor(dS); }
  729. // Convert scores from fp32 to fp16/bf16
  730. Tensor rP = make_tensor_like<Element>(tSrS);
  731. flash::convert_type_out(tSrS, rP);
  // P is written to shared memory only when the dK/dV MMA does not take it from registers.
  732. if constexpr (!Mma_dKV_is_RS) {
  733. Tensor tPaP = r2s_thr_copy_PdS.retile_S(rP); // ((Atom,AtomNum), MMA_N, MMA_N)
  734. cute::copy(r2s_tiled_copy_PdS, tPaP, tPsP);
  735. }
  736. Tensor rdS = make_tensor_like<Element>(tdPrdP);
  737. flash::convert_type_out(tdPrdP, rdS);
  738. if constexpr (!Mma_dKV_is_RS) { __syncthreads(); } // Make sure P is written
  739. // For hdim 64, It's faster to write to smem_dS first before the dV gemm
  740. Tensor tdSadS = r2s_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N)
  741. cute::copy(r2s_tiled_copy_PdS, tdSadS, tdSsdS);
  // Accumulate into tdVrdV from P (registers if Mma_dKV_is_RS, else sPt) and the current dO stage of sdOt.
  742. Tensor tdVrdO = mma_partition_fragment_AB</*A=*/dKV_swapAB>(thr_mma_dKV, sdOt(_, _, _0{}));
  743. Tensor tdVsdO_cur = tdVsdOt(_, _, _, kStages_dO > 1 ? smem_pipe_read_do_cur : 0);
  744. if constexpr (Mma_dKV_is_RS) {
  745. Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs<TiledMmadKV>(tSrS.layout()));
  746. flash::gemm_rs_sm80(tdVrdV, tdVrP, tdVrdO, tdVsdO_cur, tiled_mma_dKV, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
  747. } else {
  748. Tensor tdVrP = mma_partition_fragment_AB</*A=*/!dKV_swapAB>(thr_mma_dKV, sPt);
  749. flash::gemm_sm80<false /*A_in_regs*/, false /*B_in_regs*/, /*SwapAB=*/dKV_swapAB>(
  750. tdVrdV, tdVrP, tdVrdO, tdVsPt, tdVsdO_cur,
  751. tiled_mma_dKV, smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt, nullptr);
  752. }
  753. // if (cute::thread0()) { print_tensor(tdVrdV); }
  754. __syncthreads(); // make sure sdS is written
  // Computes this block's dQ tile from sdS and sKt into a fresh accumulator and adds it element-wise with
  // atomicAdd into tdQgdQaccum(_, _, m_block). `hook` is forwarded unchanged to gemm_sm80.
  755. auto do_mma_dQ = [&] (auto hook) {
  756. Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select<!dQ_swapAB ? 0 : 2, !dQ_swapAB ? 2 : 0>(TileShape_MNK{}));
  757. clear(tdQrdQ);
  758. Tensor tdQrdS = mma_partition_fragment_AB</*A=*/!dQ_swapAB>(thr_mma_dQ, sdS);
  759. Tensor tdQrK = mma_partition_fragment_AB</*A=*/dQ_swapAB>(thr_mma_dQ, sKt);
  760. flash::gemm_sm80<false /*A_in_regs*/, false /*B_in_regs*/, /*SwapAB=*/dQ_swapAB>(
  761. tdQrdQ, tdQrdS, tdQrK, tdQsdS, tdQsKt, tiled_mma_dQ,
  762. // smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt, load_dO_next);
  763. smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt, hook);
  764. // if (cute::thread0()) { print_tensor(tdQrdQ); }
  765. // We can reuse r2s_thr_copy_dQaccum for this partitioning
  766. Tensor tdQrdQ_atomic = r2s_thr_copy_dQaccum.retile_S(tdQrdQ);
  767. Tensor tdQgdQaccum_atomic = tdQgdQaccum(_, _, m_block);
  768. static_assert(CUTE_STATIC_V(size(tdQrdQ_atomic)) == CUTE_STATIC_V(size(tdQgdQaccum_atomic)));
  769. #pragma unroll
  770. for (int i = 0; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); }
  771. };
  772. // If kStages == 1, we want to do Mma_dK first so we can start loading Q for the next iteration
  // kStages > 1: dQ runs first and carries the next dO load as its hook.
  773. if constexpr (kStages > 1) { do_mma_dQ(load_dO_next); }
  // Accumulate into tdKrdK from dS (registers if Mma_dKV_is_RS, else sdSt) and the current Q stage of sQt.
  774. Tensor tdKrQ = mma_partition_fragment_AB</*A=*/dKV_swapAB>(thr_mma_dKV, sQt(_, _, _0{}));
  775. Tensor tdKsQ_cur = tdKsQt(_, _, _, kStages > 1 ? smem_pipe_read : 0);
  776. if constexpr (Mma_dKV_is_RS) {
  777. Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs<TiledMmadKV>(tdPrdP.layout()));
  778. flash::gemm_rs_sm80(tdKrdK, tdKrdS, tdKrQ, tdKsQ_cur, tiled_mma_dKV, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
  779. } else {
  780. Tensor tdKrdS = mma_partition_fragment_AB</*A=*/!dKV_swapAB>(thr_mma_dKV, sdSt);
  781. flash::gemm_sm80<false /*A_in_regs*/, false /*B_in_regs*/, /*SwapAB=*/dKV_swapAB>(
  782. tdKrdK, tdKrdS, tdKrQ, tdKsdSt, tdKsQ_cur,
  783. tiled_mma_dKV, smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt, cute::conditional_return<(kStages > 1)>(nullptr, load_dO_next));
  784. }
  // kStages == 1: dQ runs last, after a barrier, and carries the next Q load as its hook.
  785. if constexpr (kStages == 1) {
  786. __syncthreads();
  787. do_mma_dQ(load_Q_next);
  788. }
  789. // if (cute::thread0()) { print_tensor(tdKrdK); }
  // Advance the four stage indices, each wrapping to 0 after its last stage.
  790. smem_pipe_read = smem_pipe_read < kStages - 1 ? smem_pipe_read + 1 : 0;
  791. smem_pipe_read_do = smem_pipe_read_do < kStages_dO - 1 ? smem_pipe_read_do + 1 : 0;
  792. smem_pipe_write = smem_pipe_write < kStages - 1 ? smem_pipe_write + 1 : 0;
  793. smem_pipe_write_do = smem_pipe_write_do < kStages_dO - 1 ? smem_pipe_write_do + 1 : 0;
  794. };
  795. // We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64
  796. // this helps quite a bit to not have to do causal masking for most of the iterations.
  797. if constexpr ((Is_causal || Is_local) && SeparateMaskingIterations) {
  798. auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply<true /*Seqlenk_mask*/, Is_causal, Is_local>(tSrS, m_block, n_block); };
  799. int const m_block_masking_max = ((n_block + 1) * kBlockN - 1 + seqlen_q - seqlen_k - params.window_size_right) / kBlockM + 1;
  800. CUTLASS_PRAGMA_NO_UNROLL
  801. for (; m_block < std::min(m_block_max, m_block_masking_max); ++m_block) {
  802. bwd_step(m_block, mask_fn);
  803. }
  804. }
  805. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  806. int const m_block_max_before_local_mask = !Is_local || !SeparateMaskingIterations
  807. ? m_block_max
  808. : std::min(m_block_max, (n_block * kBlockN + seqlen_q - seqlen_k + params.window_size_left) / kBlockM);
  809. auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply<true /*Seqlenk_mask*/, Is_causal && !SeparateMaskingIterations, Is_local && !SeparateMaskingIterations>(tSrS, m_block, n_block); };
  810. CUTLASS_PRAGMA_NO_UNROLL
  811. for (; m_block < m_block_max_before_local_mask; ++m_block) {
  812. bwd_step(m_block, mask_fn);
  813. }
  814. if constexpr (Is_local && SeparateMaskingIterations) {
  815. auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply<true /*Seqlenk_mask*/, false /*Causal_mask*/, Is_local>(tSrS, m_block, n_block); };
  816. CUTLASS_PRAGMA_NO_UNROLL
  817. for (; m_block < m_block_max; ++m_block) {
  818. bwd_step(m_block, mask_fn);
  819. }
  820. }
  821. // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(tdVrdV); }
  822. #pragma unroll
  823. for (int i = 0; i < size(tdKrdK); ++i) { tdKrdK(i) *= params.softmax_scale; }
  824. return true;
  825. }
  826. };
  827. } // namespace flash