mainloop_fwd_sm90_tma_gmma_ws.hpp

  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include <cutlass/cutlass.h>
  6. #include <cutlass/array.h>
  7. #include <cutlass/numeric_types.h>
  8. #include <cutlass/numeric_conversion.h>
  9. #include "cutlass/pipeline/pipeline.hpp"
  10. #include "cute/tensor.hpp"
  11. #include "cutlass/gemm/collective/builders/sm90_common.inl"
  12. #include "named_barrier.hpp"
  13. #include "seqlen.h"
  14. #include "mask.h"
  15. #include "pack_gqa.h"
  16. #include "paged_kv.h"
  17. #include "rotary.h"
  18. #include "utils.h"
  19. #include "sm90_pipeline_no_cluster.hpp"
  20. namespace flash {
  21. using namespace cute;
  22. template <int Stages, class ClusterShape_, class TileShape_MNK_, class Element_, class ElementAccum_, class ArchTag_,
  23. bool Is_causal_, bool Is_local_, bool Has_softcap_, bool Varlen_, bool PagedKV_, bool AppendKV_,
  24. bool Mma1_is_RS, bool IntraWGOverlap, bool PackGQA_, bool Split_, bool V_colmajor_>
  25. struct CollectiveMainloopFwdSm90 {
  26. static constexpr int kStages = Stages;
  27. using ClusterShape = ClusterShape_;
  28. using TileShape_MNK = TileShape_MNK_;
  29. using Element = Element_;
  30. using ElementAccum = ElementAccum_;
  31. using ArchTag = ArchTag_;
  32. static constexpr bool Is_FP8 = cute::is_same_v<Element, cutlass::float_e4m3_t> || cute::is_same_v<Element, cutlass::float_e5m2_t>;
  33. static constexpr bool Is_causal = Is_causal_;
  34. static constexpr bool Is_local = Is_local_;
  35. static constexpr bool Has_softcap = Has_softcap_;
  36. static constexpr bool Varlen = Varlen_;
  37. static constexpr bool PagedKV = PagedKV_;
  38. static constexpr bool AppendKV = AppendKV_;
  39. static constexpr bool PackGQA = PackGQA_;
  40. static constexpr bool Split = Split_;
  41. static constexpr bool V_colmajor = V_colmajor_;
  42. static constexpr bool Transpose_V = Is_FP8 && !V_colmajor;
  43. static constexpr bool Use_TMA_Q = !PackGQA;
  44. static constexpr bool Use_TMA_KV = !PagedKV;
  45. static_assert(Use_TMA_KV || CUTE_STATIC_V(size(ClusterShape{})) == 1, "If not using TMA for KV, ClusterShape must be 1");
  46. static_assert(Use_TMA_KV || !V_colmajor, "If not using TMA for KV, V_colmajor is not supported");
  47. using SeqlenInfo_t = flash::SeqlenInfoQKNewK<Varlen, AppendKV>;
  48. static_assert(ArchTag::kMinComputeCapability >= 90);
  49. static constexpr cute::GMMA::Major MmaMajorV = !Is_FP8 && !V_colmajor ? GMMA::Major::MN : GMMA::Major::K;
  50. static constexpr cute::GMMA::Major TmaMajorV = !V_colmajor ? GMMA::Major::MN : GMMA::Major::K;
  51. static constexpr int kBlockM = get<0>(TileShape_MNK{});
  52. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  53. static constexpr int kHeadDim = get<2>(TileShape_MNK{});
  54. // Register bandwidth is actually a bottleneck so we don't want Q to be in registers.
  55. // Leaving this option here for reference.
  56. static constexpr bool Mma0_is_RS = false;
  57. // We can have Mma1 (P @ V) with P in smem instead of rmem to reduce register pressure at the cost of more smem.
  58. static_assert(!(!Mma1_is_RS && !IntraWGOverlap), "Mma1 must be RS if IntraWGOverlap is disabled");
  59. static_assert(!(!Mma1_is_RS && Is_FP8), "Mma1 must be RS if FP8");
  60. static_assert(!(!Mma1_is_RS && Transpose_V), "Mma1 must be RS if Transpose_V");
  61. using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
  62. using TiledMma0 = decltype(cute::make_tiled_mma(
  63. std::conditional_t<
  64. !Mma0_is_RS,
  65. decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>()),
  66. decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK>())
  67. >{},
  68. AtomLayoutMNK{}));
  69. using TiledMma1 = decltype(cute::make_tiled_mma(
  70. std::conditional_t<
  71. !Mma1_is_RS,
  72. decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum,
  73. decltype(select<0, 2, 1>(TileShape_MNK{})), GMMA::Major::K, MmaMajorV>()),
  74. decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum,
  75. decltype(select<0, 2, 1>(TileShape_MNK{})), GMMA::Major::K, MmaMajorV>())
  76. >{},
  77. AtomLayoutMNK{}));
  78. static constexpr int NumMmaThreads = size(TiledMma0{});
  79. static constexpr int NumProducerThreads = !Transpose_V && Use_TMA_KV && Use_TMA_Q ? cutlass::NumThreadsPerWarp : cutlass::NumThreadsPerWarpGroup;
  80. static_assert(NumMmaThreads % cutlass::NumThreadsPerWarpGroup == 0);
  81. static constexpr int NumMmaWarpGroups = NumMmaThreads / cutlass::NumThreadsPerWarpGroup;
  82. static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3);
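// For illustration (values not fixed by this header): with kBlockM = 128, AtomLayoutMNK is (2, 1, 1),
// so TiledMma0 tiles one 128-thread GMMA warpgroup twice along M, giving NumMmaThreads = 256 and
// NumMmaWarpGroups = 2; kBlockM = 64 would give 128 threads and a single MMA warpgroup.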
  83. using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  84. decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  85. using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
  86. using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  87. decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  88. using SmemLayoutK = decltype(tile_to_shape(
  89. SmemLayoutAtomK{},
  90. make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
  91. using SmemLayoutAtomVt = decltype(cutlass::gemm::collective::detail::ss_smem_selector<TmaMajorV, Element,
  92. decltype(cute::get<2>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
  93. using SmemLayoutVt = decltype(tile_to_shape(
  94. SmemLayoutAtomVt{},
  95. make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{}),
  96. std::conditional_t<TmaMajorV == GMMA::Major::K, cute::Step<_1, _2, _3>, cute::Step<_2, _1, _3>>{}));
  97. using SmemLayoutAtomVtMma = decltype(cutlass::gemm::collective::detail::ss_smem_selector<MmaMajorV, Element,
  98. decltype(cute::get<2>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
  99. using SmemLayoutVtMma = decltype(tile_to_shape(
  100. SmemLayoutAtomVtMma{},
  101. make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{}),
  102. std::conditional_t<MmaMajorV == GMMA::Major::K, cute::Step<_1, _2, _3>, cute::Step<_2, _1, _3>>{}));
  103. // Only used if we're using cp.async to load V
  104. using SmemLayoutAtomVCpAsync = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  105. decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  106. using SmemLayoutVCpAsync = decltype(tile_to_shape(
  107. SmemLayoutAtomVCpAsync{},
  108. make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
  109. using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  110. decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
  111. using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{})));
  112. using SmemCopyAtomP = Copy_Atom<cute::SM90_U32x4_STSM_N, Element>;
  113. // Use LDSM.T and STSM to transpose V in the case of FP8 and V being row-major.
  114. // For FP16/BF16 we don't do any transposing.
  115. static_assert(!Transpose_V || (kHeadDim % 32 == 0 && kBlockN % 32 == 0));
  116. static constexpr bool kHeadDim_multiple_64 = kHeadDim % 64 == 0;
  117. // Either kHeadDim is a multiple of 64 (in which case we use a block size of 64 x 32 for the transpose),
  118. // or we need kBlockN to be a multiple of 64 (in which case we use a block size of 32 x 64 for the transpose).
  119. static_assert(!Transpose_V || (kHeadDim_multiple_64 || kBlockN % 64 == 0));
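// For example: hdim 64, 128, 192, or 256 takes the kHeadDim % 64 == 0 path (64 x 32 transpose tile),
// whereas a hdim such as 96 is only a multiple of 32, so it would additionally need kBlockN % 64 == 0
// and use the 32 x 64 transpose tile instead.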
  120. using LDSM_thread_shape = std::conditional_t<kHeadDim_multiple_64, Shape<_32, _4, _1, _1>, Shape<_16, _4, _1, _2>>;
  121. using LDSM_thread_stride = std::conditional_t<kHeadDim_multiple_64, Stride<_4, _1, _0, _0>, Stride<_4, _1, _0, _64>>;
  122. using LDSM_value_shape = Shape<_2, _2, _1, _4>;
  123. using LDSM_value_stride = Stride<_1, _2, _16, _4>;
  124. using LDSM_divide_shape = std::conditional_t<kHeadDim_multiple_64, Shape<_64, _8>, Shape<_32, _8>>;
  125. using S2RTiledCopyVt = decltype(make_tiled_copy(
  126. Copy_Atom<SM75_U16x8_LDSM_T, Element>{}, Layout<LDSM_thread_shape, LDSM_thread_stride>{},
  127. Layout<LDSM_value_shape, LDSM_value_stride>{}));
  128. using STSM_thread_shape = std::conditional_t<kHeadDim_multiple_64, Shape<_8, _4, _4, _1>, Shape<_8, _4, _2, _2>>;
  129. using STSM_thread_stride = std::conditional_t<kHeadDim_multiple_64, Stride<_4, _1, _32, _0>, Stride<_4, _1, _32, _64>>;
  130. using STSM_value_shape = Shape<_1, _4, _2, _2>;
  131. using STSM_value_stride = Stride<_0, _1, _4, _8>;
  132. using STSM_divide_shape = Shape<_8, _16>;
  133. // These will not permute the columns of V (the kHeadDim dimension) but incur bank conflicts
  134. // so a little slower (e.g. 1150 TFLOPS for hdim 256 instead of 1200 TFLOPS).
  135. // Instead we will permute the cols of V, and un-permute the cols of O in the epilogue.
  136. // using STSM_value_shape = Shape<_2, _4, _1, _2>;
  137. // using STSM_value_stride = Stride<_4, _1, _0, _8>;
  138. // using STSM_divide_shape = Shape<_16, _16>;
  139. using R2STiledCopyV = decltype(make_tiled_copy(
  140. Copy_Atom<SM90_U32x4_STSM_N, Element>{}, Layout<STSM_thread_shape, STSM_thread_stride>{},
  141. Layout<STSM_value_shape, STSM_value_stride>{}));
  142. using GmemTiledCopyQ = cute::SM90_TMA_LOAD;
  143. using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{})));
  144. // We use CpAsync for K and V if PagedKV and AppendKV, since TMA doesn't work there
  145. static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
  146. static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
  147. // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each
  148. // thread to have 4 loads in the M direction and 2 vectorized loads in the K direction.
  149. // We want each thread to have at least 2 loads in the K direction since in the case of non-interleaved
  150. // rotary (combining elements at indices 0 and rotary_dim/2, 1 and rotary_dim/2+1, etc), each thread will
  151. // load twice from the same row.
  152. static constexpr int kBytePerHalfRow = kHeadDim / 2 * sizeof(Element);
  153. static constexpr int kBlockKGmem = (kBytePerHalfRow % 128 == 0 ? 128 : (kBytePerHalfRow % 64 == 0 ? 64 : 32)) / sizeof(Element);
  154. static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;
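// Worked example of the numbers above, assuming 2-byte elements (fp16/bf16) and hdim = 128:
// kGmemElemsPerLoad = 16 / 2 = 8, kBytePerHalfRow = 64 * 2 = 128, kBlockKGmem = 128 / 2 = 64,
// kGmemThreadsPerRow = 64 / 8 = 8. So each 64-element (128 B) "row" is covered by 8 threads doing
// 16 B vectorized loads, and each thread issues kHeadDim / kBlockKGmem = 2 loads in the K direction.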
  155. static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow");
  156. // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKV where
  157. // these threads share the same page table entry and share the work of computing pointers to paged K and paged V.
  158. static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, "kGmemThreadsPerRow must divide NumThreadsPerWarp");
  159. using GmemLayoutAtom = Layout<Shape <Int<NumMmaThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
  160. Stride<Int<kGmemThreadsPerRow>, _1>>;
  161. // If AppendKV, we'll be loading Q for rotary, and we assume divisibility to avoid predication
  162. static_assert(!AppendKV || kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0, "kBlockM must be a multiple of NumMmaThreads / kGmemThreadsPerRow");
  163. using GmemTiledCopyAppendKV = decltype(
  164. make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
  165. GmemLayoutAtom{},
  166. Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per store
  167. using ShapeQKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen, d, head, batch)
  168. using StrideQK = cute::Stride<int64_t, _1, int64_t, int64_t>;
  169. using StrideV = std::conditional_t<!V_colmajor, StrideQK, cute::Stride<_1, int64_t, int64_t, int64_t>>;
  170. // ((qhead_per_khead, seqlen_q), d, nheads_kv, batch)
  171. using ShapeQPacked = std::conditional_t<!PackGQA, ShapeQKV, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t>>;
  172. using StrideQPacked = std::conditional_t<!PackGQA, StrideQK, cute::Stride<cute::Stride<int64_t, int64_t>, _1, int64_t, int64_t>>;
  173. using ShapePageTable = cute::Shape<int32_t, int32_t>; // (batch, max_num_pages_per_seq)
  174. using StridePageTable = cute::Stride<int64_t, _1>;
  175. using ShapeRotary = cute::Shape<int32_t, int32_t>; // (seqlen_ro, rotary_dim // 2)
  176. using StrideRotary = cute::Stride<int64_t, _1>;
  177. using StrideDescale = cute::Stride<int64_t, int64_t>;
  178. using TMA_Q = decltype(make_tma_copy_A_sm90(
  179. GmemTiledCopyQ{},
  180. make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), ShapeQKV{}, StrideQK{}),
  181. SmemLayoutQ{},
  182. TileShape_MNK{},
  183. ClusterShape{}));
  184. using TMA_K = decltype(make_tma_copy_B_sm90(
  185. GmemTiledCopyKV{},
  186. make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), ShapeQKV{}, StrideQK{}),
  187. take<0, 2>(SmemLayoutK{}),
  188. TileShape_MNK{},
  189. ClusterShape{})); // mcast along M mode for this N load, if any
  190. using TMA_V = decltype(make_tma_copy(
  191. GmemTiledCopyKV{},
  192. make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), ShapeQKV{}, select<1, 0, 2, 3>(StrideV{})),
  193. take<0, 2>(SmemLayoutVt{}),
  194. select<2, 1>(TileShape_MNK{}),
  195. size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
  196. // Set the bytes transferred in this TMA transaction (may involve multiple issues)
  197. static constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<Element> / 8);
  198. static constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<Element> / 8);
  199. static constexpr uint32_t TmaTransactionBytesV = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutVt{})) * cutlass::sizeof_bits_v<Element> / 8);
  200. static_assert(TmaTransactionBytesK == TmaTransactionBytesV);
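// For a sense of scale (illustration only, assuming 2-byte elements and a 128 x 128 x 128 tile):
// TmaTransactionBytesQ = 128 * 128 * 2 = 32 KiB for the whole Q tile, and
// TmaTransactionBytesK = TmaTransactionBytesV = 128 * 128 * 2 = 32 KiB per pipeline stage.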
  201. using PipelineTmaAsync = std::conditional_t<CUTE_STATIC_V(size(ClusterShape{})) == 1, typename cutlass::PipelineTmaAsyncNoCluster<kStages>, typename cutlass::PipelineTmaAsync<kStages>>;
  202. using MainloopPipelineK = std::conditional_t<Use_TMA_KV, PipelineTmaAsync, typename cutlass::PipelineAsync<kStages>>;
  203. using MainloopPipelineV = std::conditional_t<!Transpose_V && Use_TMA_KV, PipelineTmaAsync, typename cutlass::PipelineAsync<kStages>>;
  204. using MainloopPipelineVt = std::conditional_t<Use_TMA_KV, PipelineTmaAsync, typename cutlass::PipelineAsync<kStages>>;
  205. // We always use TMA for K_new and V_new
  206. using MainloopPipelineKVNew = PipelineTmaAsync;
  207. using PipelineState = cutlass::PipelineState<kStages>;
  208. // If PackGQA, we use cp.async (instead of TMA) to load Q, so we want smem_q to be aligned
  210. // and for sQ to be a position_independent_swizzle_tensor.
  210. // If !Use_TMA_KV, we use cp.async (instead of TMA) to load K & V, so we want smem_k and smem_v to be aligned.
  211. static constexpr size_t SmemAlignmentQ = Use_TMA_Q && !AppendKV && !Mma0_is_RS ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutQ{});
  212. static constexpr size_t SmemAlignmentK = Use_TMA_KV && !AppendKV ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutK{});
  213. static constexpr size_t SmemAlignmentVtNoTranspose = cutlass::detail::alignment_for_swizzle(SmemLayoutVt{});
  214. static_assert(SmemAlignmentQ >= 128 and SmemAlignmentK >= 128 && SmemAlignmentVtNoTranspose >= 128, "Require at least 128B alignment");
  215. static constexpr size_t SmemAlignmentP = cutlass::detail::alignment_for_swizzle(SmemLayoutP{});
  216. static_assert(SmemAlignmentP >= 128, "Require at least 128B alignment");
  217. using SmemP_t = std::conditional_t<Mma1_is_RS, cute::array<Element, 0>, cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>, SmemAlignmentP>>;
  218. // Sometimes even with SmemP_t = cute::array<Element, 0>, putting it in the TensorStorage struct causes
  219. // smem size to go from 227KB to 228KB and we get "invalid argument".
  220. struct TensorStorageWithoutPNoTranspose : cute::aligned_struct<cute::max(SmemAlignmentQ, SmemAlignmentK, SmemAlignmentVtNoTranspose)> {
  221. cute::array_aligned<Element, cute::cosize_v<SmemLayoutVt>, SmemAlignmentVtNoTranspose> smem_v;
  222. cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>, SmemAlignmentQ> smem_q;
  223. cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>, SmemAlignmentK> smem_k;
  224. };
  225. struct TensorStorageWithPNoTranspose : cute::aligned_struct<cute::max(SmemAlignmentQ, SmemAlignmentK, SmemAlignmentVtNoTranspose, SmemAlignmentP)> {
  226. cute::array_aligned<Element, cute::cosize_v<SmemLayoutVt>, SmemAlignmentVtNoTranspose> smem_v;
  227. cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>, SmemAlignmentQ> smem_q;
  228. cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>, SmemAlignmentK> smem_k;
  229. SmemP_t smem_p;
  230. };
  231. using TensorStorageNoTranspose = std::conditional_t<Mma1_is_RS, TensorStorageWithoutPNoTranspose, TensorStorageWithPNoTranspose>;
  232. static constexpr size_t SmemAlignmentVt = cutlass::detail::alignment_for_swizzle(SmemLayoutVt{});
  233. static constexpr size_t SmemAlignmentV = cutlass::detail::alignment_for_swizzle(SmemLayoutVtMma{});
  234. static_assert(SmemAlignmentVt >= 128 and SmemAlignmentV >= 128, "Require at least 128B alignment");
  235. struct TensorStorageTransposeV : cute::aligned_struct<cute::max(SmemAlignmentQ, SmemAlignmentK, SmemAlignmentV)> {
  236. cute::array_aligned<Element, cute::cosize_v<SmemLayoutVtMma>, SmemAlignmentV> smem_v;
  237. cute::array_aligned<Element, cute::cosize_v<SmemLayoutVt>, SmemAlignmentVt> smem_vt;
  238. cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>, SmemAlignmentQ> smem_q;
  239. cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>, SmemAlignmentK> smem_k;
  240. };
  241. using TensorStorage = std::conditional_t<!Transpose_V, TensorStorageNoTranspose, TensorStorageTransposeV>;
  242. // These are tuned for speed. They don't affect correctness.
  243. static constexpr bool UseSchedulerBarrier = IntraWGOverlap
  244. ? (NumMmaWarpGroups >= 2) && (!Is_FP8 ? kHeadDim <= 128 : kHeadDim >= 128)
  245. : NumMmaWarpGroups == 2;
  246. static constexpr bool RescaleOBeforeGemm = kHeadDim > 128 && (!Is_FP8 || V_colmajor);
  247. // Host side kernel arguments
  248. struct Arguments {
  249. Element const* const ptr_Q;
  250. ShapeQKV const shape_Q;
  251. StrideQK const stride_Q;
  252. Element* const ptr_K; // Not Element const* since we might append to KV cache in-place
  253. ShapeQKV const shape_K;
  254. StrideQK const stride_K;
  255. Element* const ptr_V;
  256. StrideV const stride_V;
  257. Element const* const ptr_K_new;
  258. ShapeQKV const shape_K_new;
  259. StrideQK const stride_K_new;
  260. Element const* const ptr_V_new;
  261. StrideV const stride_V_new;
  262. Element const* const ptr_rotary_cos;
  263. ShapeRotary const shape_rotary;
  264. StrideRotary const stride_rotary_cos;
  265. Element const* const ptr_rotary_sin;
  266. StrideRotary const stride_rotary_sin;
  267. bool const is_rotary_interleaved;
  268. int const* const ptr_pagetable;
  269. ShapePageTable const shape_pagetable;
  270. StridePageTable const stride_pagetable;
  271. float const softmax_scale;
  272. float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale;
  273. StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale;
  274. int const window_size_left = -1, window_size_right = -1, sink_token_length = 0;
  275. float const softcap_val;
  276. int const num_splits;
  277. int const* const kv_batch_idx = nullptr;
  278. int const* const cu_seqlens_q = nullptr;
  279. int const* const cu_seqlens_k = nullptr;
  280. int const* const cu_seqlens_k_new = nullptr;
  281. int const* const seqused_q = nullptr;
  282. int const* const seqused_k = nullptr;
  283. int const* const leftpad_k = nullptr;
  284. };
  285. // Device side kernel params
  286. struct Params {
  287. Element const* const ptr_Q;
  288. ShapeQKV const shape_Q;
  289. StrideQK const stride_Q;
  290. ShapeQPacked const shape_Q_packed;
  291. StrideQPacked const stride_Q_packed;
  292. Element* const ptr_K;
  293. ShapeQKV const shape_K;
  294. StrideQK const stride_K;
  295. Element* const ptr_V;
  296. StrideV const stride_V;
  297. Element const* const ptr_K_new;
  298. ShapeQKV const shape_K_new;
  299. StrideQK const stride_K_new;
  300. Element const* const ptr_V_new;
  301. StrideV const stride_V_new;
  302. Element const* const ptr_rotary_cos;
  303. ShapeRotary const shape_rotary;
  304. StrideRotary const stride_rotary_cos;
  305. Element const* const ptr_rotary_sin;
  306. StrideRotary const stride_rotary_sin;
  307. bool const is_rotary_interleaved;
  308. int const* const ptr_pagetable;
  309. ShapePageTable const shape_pagetable;
  310. StridePageTable const stride_pagetable;
  311. cutlass::FastDivmod page_size_divmod;
  312. cutlass::FastDivmod qhead_per_khead_divmod;
  313. TMA_Q tma_load_Q;
  314. TMA_K tma_load_K;
  315. TMA_V tma_load_V;
  316. TMA_K tma_load_K_new;
  317. TMA_V tma_load_V_new;
  318. float const softmax_scale_log2;
  319. float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale;
  320. StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale;
  321. float const softcap_val;
  322. int const window_size_left, window_size_right, sink_token_length;
  323. int const num_splits;
  324. int const* const kv_batch_idx = nullptr;
  325. int const* const cu_seqlens_q = nullptr;
  326. int const* const cu_seqlens_k = nullptr;
  327. int const* const cu_seqlens_k_new = nullptr;
  328. int const* const seqused_q = nullptr;
  329. int const* const seqused_k = nullptr;
  330. int const* const leftpad_k = nullptr;
  331. };
  332. static Params
  333. to_underlying_arguments(Arguments const& args) {
  334. Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.shape_Q, args.stride_Q);
  335. TMA_Q tma_load_Q = make_tma_copy_A_sm90(
  336. GmemTiledCopyQ{},
  337. mQ,
  338. SmemLayoutQ{},
  339. TileShape_MNK{},
  340. ClusterShape{}); // no mcast for Q
  341. Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.shape_K, args.stride_K);
  342. TMA_K tma_load_K = make_tma_copy_B_sm90(
  343. GmemTiledCopyKV{},
  344. mK,
  345. take<0, 2>(SmemLayoutK{}),
  346. TileShape_MNK{},
  347. ClusterShape{}); // mcast along M mode for this N load, if any
  348. Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), select<1, 0, 2, 3>(args.shape_K), select<1, 0, 2, 3>(args.stride_V));
  349. TMA_V tma_load_V = make_tma_copy(
  350. GmemTiledCopyKV{},
  351. mV,
  352. take<0, 2>(SmemLayoutVt{}),
  353. select<2, 1>(TileShape_MNK{}),
  354. size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
  355. Tensor mKnew = make_tensor(make_gmem_ptr(args.ptr_K_new), args.shape_K_new, args.stride_K_new);
  356. TMA_K tma_load_K_new = make_tma_copy_B_sm90(
  357. GmemTiledCopyKV{},
  358. cute::conditional_return<AppendKV>(mKnew, mK),
  359. take<0, 2>(SmemLayoutK{}),
  360. TileShape_MNK{},
  361. ClusterShape{}); // mcast along M mode for this N load, if any
  362. Tensor mVnew = make_tensor(make_gmem_ptr(args.ptr_V_new), select<1, 0, 2, 3>(args.shape_K_new), select<1, 0, 2, 3>(args.stride_V_new));
  363. TMA_V tma_load_V_new = make_tma_copy(
  364. GmemTiledCopyKV{},
  365. cute::conditional_return<AppendKV>(mVnew, mV),
  366. take<0, 2>(SmemLayoutVt{}),
  367. select<2, 1>(TileShape_MNK{}),
  368. size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
  369. // If PackGQA, reshape Q to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size)
  370. int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K));
  371. auto const shape_Q_packed = cute::conditional_return<!PackGQA>(
  372. args.shape_Q,
  373. make_shape(make_shape(qhead_per_khead, get<0>(args.shape_Q)), get<1>(args.shape_Q), get<2>(args.shape_K), get<3>(args.shape_Q))
  374. );
  375. auto const stride_Q_packed = cute::conditional_return<!PackGQA>(
  376. args.stride_Q,
  377. make_stride(make_stride(get<2>(args.stride_Q), get<0>(args.stride_Q)), get<1>(args.stride_Q), get<2>(args.stride_Q) * qhead_per_khead, get<3>(args.stride_Q))
  378. );
  379. if (get<1>(args.shape_rotary) > 0) {
  380. assert(args.ptr_rotary_cos != nullptr && args.ptr_rotary_sin != nullptr);
  381. }
  382. assert(args.num_splits >= 1);
  383. // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val.
  384. // Right after this, we multiply by log2(e) before applying exp2.
  385. // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val
  386. // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e)
  387. // (assigning it to params.softmax_scale_log2).
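// Spelled out, the identity being used (with s a raw score) is
//   exp(softcap_val * tanh(s * softmax_scale / softcap_val))
//     = exp2((softcap_val * log2(e)) * tanh(s * (softmax_scale / softcap_val))),
// so one multiply by params.softcap_val before the tanh plus one multiply by params.softmax_scale_log2
// before the exp2 reproduces scale-then-softcap exactly.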
  388. return {args.ptr_Q, args.shape_Q, args.stride_Q, shape_Q_packed, stride_Q_packed,
  389. args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.stride_V,
  390. args.ptr_K_new, args.shape_K_new, args.stride_K_new, args.ptr_V_new, args.stride_V_new,
  391. args.ptr_rotary_cos, args.shape_rotary, args.stride_rotary_cos,
  392. args.ptr_rotary_sin, args.stride_rotary_sin, args.is_rotary_interleaved,
  393. args.ptr_pagetable, args.shape_pagetable, args.stride_pagetable,
  394. cutlass::FastDivmod(int(get<0>(args.shape_K))),
  395. cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))),
  396. tma_load_Q, tma_load_K, tma_load_V, tma_load_K_new, tma_load_V_new,
  397. !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E),
  398. args.ptr_q_descale, args.ptr_k_descale, args.ptr_v_descale,
  399. args.stride_q_descale, args.stride_k_descale, args.stride_v_descale,
  400. !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val,
  401. args.window_size_left, args.window_size_right, args.sink_token_length,
  402. !Split ? 1 : args.num_splits,
  403. args.kv_batch_idx,
  404. args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new,
  405. args.seqused_q, args.seqused_k, args.leftpad_k};
  406. }
  407. /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
  408. CUTLASS_DEVICE
  409. static void prefetch_tma_descriptors(Params const& params) {
  410. if constexpr (Use_TMA_Q) {
  411. cute::prefetch_tma_descriptor(params.tma_load_Q.get_tma_descriptor());
  412. }
  413. if constexpr (Use_TMA_KV) {
  414. cute::prefetch_tma_descriptor(params.tma_load_K.get_tma_descriptor());
  415. cute::prefetch_tma_descriptor(params.tma_load_V.get_tma_descriptor());
  416. }
  417. if constexpr (AppendKV) {
  418. cute::prefetch_tma_descriptor(params.tma_load_K_new.get_tma_descriptor());
  419. cute::prefetch_tma_descriptor(params.tma_load_V_new.get_tma_descriptor());
  420. }
  421. }
  422. CUTLASS_DEVICE
  423. cute::tuple<int, int> get_n_block_min_max(Params const& params, SeqlenInfo_t const& seqlen_info,
  424. int m_block, int bidb, int split_idx=0, int num_splits=1) {
  425. static constexpr int kBlockM = get<0>(TileShape_MNK{});
  426. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  427. int const seqlen_k = seqlen_info.seqlen_k;
  428. int const seqlen_q = seqlen_info.seqlen_q;
  429. int n_block_max = cute::ceil_div(seqlen_k, kBlockN);
  430. if constexpr (Is_causal || Is_local) {
  431. int m_idx_max = (m_block + 1) * kBlockM;
  432. // TODO: check off-by-1 error
  433. if (PackGQA) { m_idx_max = params.qhead_per_khead_divmod.divide(m_idx_max - 1) + 1 ; }
  434. n_block_max = std::min(n_block_max,
  435. cute::ceil_div(m_idx_max + seqlen_k - seqlen_q + params.window_size_right, kBlockN));
  436. }
  437. int n_block_min = 0;
  438. if constexpr (Is_local) {
  439. int m_idx_min = m_block * kBlockM;
  440. if (PackGQA) { m_idx_min = params.qhead_per_khead_divmod.divide(m_idx_min); }
  441. n_block_min = std::max(int(0), (m_idx_min + seqlen_k - seqlen_q - params.window_size_left) / kBlockN);
  442. }
  443. // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }
  444. if constexpr (Split) {
  445. int num_n_blocks_per_split = n_block_max <= n_block_min ? 0 : cute::ceil_div(n_block_max - n_block_min, num_splits);
  446. n_block_min = n_block_min + split_idx * num_n_blocks_per_split;
  447. n_block_max = std::min(n_block_min + num_n_blocks_per_split, n_block_max);
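// Example (illustration only): with n_block_min = 0, n_block_max = 10 and num_splits = 4,
// num_n_blocks_per_split = ceil(10 / 4) = 3 and the splits cover blocks [0,3), [3,6), [6,9), [9,10);
// a split whose range starts at or past the original n_block_max simply ends up empty.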
  448. }
  449. // if (threadIdx.x == 128) { printf("After split, inside, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }
  450. return {n_block_min, n_block_max};
  451. }
  452. template <typename SchedulerPrefetch, typename SharedStorage>
  453. CUTLASS_DEVICE void
  454. load(Params const& params,
  455. MainloopPipelineK pipeline_k,
  456. MainloopPipelineV pipeline_v,
  457. MainloopPipelineVt pipeline_vt,
  458. PipelineState& smem_pipe_write,
  459. SharedStorage &shared_storage,
  460. SchedulerPrefetch const& scheduler_prefetch,
  461. SeqlenInfo_t const& seqlen_info,
  462. cute::tuple<int32_t, int32_t, int32_t, int32_t> block_coord,
  463. int &work_idx
  464. ) {
  465. auto [m_block, bidh, bidb, split_idx] = block_coord;
  466. auto [n_block_min, n_block_max] = get_n_block_min_max(params, seqlen_info, m_block, bidb, split_idx, params.num_splits);
  467. // It's possible to have n_block_max <= n_block_min. Loading K can cause illegal memory access.
  468. if constexpr (Is_causal || Is_local || Varlen || Split) {
  469. if (n_block_max <= n_block_min) {
  470. scheduler_prefetch();
  471. return;
  472. }
  473. }
  474. Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{});
  475. Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{});
  476. Tensor sK_pi = as_position_independent_swizzle_tensor(sK);
  477. // as_position_independent_swizzle_tensor makes address calculation easier when we do LDSM & STSM to transpose.
  478. // But it requires smem_vt and smem_v to be aligned to e.g. 512 bytes.
  479. Tensor sVt = [&] {
  480. if constexpr (!Transpose_V) {
  481. return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVt{});
  482. } else {
  483. return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVt{}));
  484. }
  485. }();
  486. // Only used if Transpose_V
  487. Tensor sV = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtMma{}));
  488. // Only used if we're using cp.async to load V
  489. Tensor sVcpasync = [&] {
  490. if constexpr (!Transpose_V) {
  491. return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVCpAsync{}));
  492. } else {
  493. return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVCpAsync{}));
  494. }
  495. }();
  496. int const thread_idx = threadIdx.x % NumProducerThreads;
  497. int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh;
  498. int const bidb_kv = params.kv_batch_idx == nullptr ? bidb : params.kv_batch_idx[bidb];
  499. // Prepare the TMA loads
  500. uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
  501. constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
  502. uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
  503. bool const is_varlen_q = Varlen && params.cu_seqlens_q;
  504. bool const is_varlen_k = Varlen && params.cu_seqlens_k;
  505. Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0);
  506. Tensor mK_TMA = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0);
  507. Tensor mVt_TMA = params.tma_load_V.get_tma_tensor(select<1, 0, 2, 3>(params.shape_K))(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0);
  508. Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K)
  509. // if (cute::thread0()) { printf("Varlen = %d, params.leftpad_k = %p, leftpad_k = %d\n", Varlen, params.leftpad_k, leftpad_k); }
  510. Tensor gK_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mK_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
  511. Tensor gVt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k), mVt_TMA), select<2, 1>(TileShape_MNK{}), make_coord(_0{}, _)); // (K, N, _)
  512. auto block_tma_Q = params.tma_load_Q.get_slice(_0{});
  513. Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ)); // (TMA)
  514. Tensor tQsQ = group_modes<0, 3>(block_tma_Q.partition_D(sQ)); // (TMA)
  515. // tma_partition doesn't handle position_independent_swizzle_tensor correctly, so we need to do it manually
  516. auto block_tma_K = params.tma_load_K.get_slice(cluster_local_block_id.x);
  517. Tensor tKgK_TMA = group_modes<0, 3>(block_tma_K.partition_S(gK_TMA)); // (TMA, k)
  518. Tensor tKsK_TMA = group_modes<0, 3>(block_tma_K.partition_D(sK)); // (TMA, PIPE)
  519. auto block_tma_V = params.tma_load_V.get_slice(cluster_local_block_id.x);
  520. Tensor tVgVt_TMA = group_modes<0, 3>(block_tma_V.partition_S(gVt_TMA)); // (TMA, k)
  521. Tensor tVsVt_TMA = group_modes<0, 3>(block_tma_V.partition_D(sVt)); // (TMA, PIPE)
  522. using PagedKVManager_t = PagedKVManager<get<1>(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumProducerThreads, Element, Transpose_V || !IntraWGOverlap /*KV_Same_Iter*/>;
  523. PagedKVManager_t paged_kv_manager(
  524. params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable,
  525. params.ptr_K, params.shape_K, params.stride_K,
  526. params.ptr_V, params.stride_V,
  527. params.page_size_divmod, bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k
  528. );
  529. // Set up for transposing V, only used if Transpose_V
  530. S2RTiledCopyVt s2r_tiled_copy_vt;
  531. R2STiledCopyV r2s_tiled_copy_v;
  532. auto s2r_thr_copy_vt = s2r_tiled_copy_vt.get_thread_slice(thread_idx);
  533. auto r2s_thr_copy_v = r2s_tiled_copy_v.get_thread_slice(thread_idx);
  534. // flat_divide(sVt, LDSM_divide_shape{}): (64, 8, kHeadDim / 64, kBlockN / 8, kStages)
  535. Tensor tTranssVt_ = s2r_thr_copy_vt.partition_S(flat_divide(sVt, LDSM_divide_shape{})); // ((16, 1), 1, 1, kHeadDim / 64, kBlockN / 32, kStages)
  536. // flat_divide(sV, STSM_divide_shape{}): (8, 16, kHeadDim / 8, (4, kBlockN / 64), kStages)
  537. Tensor tTranssV_ = r2s_thr_copy_v.partition_D(flat_divide(sV, STSM_divide_shape{})); // ((16, 1), 1, 1, kHeadDim / 64, (2, kBlockN / 64), kStages)
  538. CUTE_STATIC_ASSERT_V(rank(tTranssVt_) == rank(tTranssV_));
  539. CUTE_STATIC_ASSERT_V(size<0>(tTranssVt_) == size<0>(tTranssV_));
  540. CUTE_STATIC_ASSERT_V(size<1>(tTranssVt_) == size<1>(tTranssV_));
  541. CUTE_STATIC_ASSERT_V(size<2>(tTranssVt_) == size<2>(tTranssV_));
  542. CUTE_STATIC_ASSERT_V(size<3>(tTranssVt_) == size<3>(tTranssV_));
  543. CUTE_STATIC_ASSERT_V(size<4>(tTranssVt_) == size<4>(tTranssV_));
  544. // Faster to have 2 LDSM.T, byte permute, STSM for better ILP
  545. static constexpr int Transpose_ILP = (size<2>(tTranssVt_) * size<3>(tTranssVt_)) % 2 == 0 ? 2 : 1;
  546. Tensor tTranssVt = logical_divide(group_modes<1, rank(tTranssVt_) - 1>(tTranssVt_), Shape<Underscore, Int<Transpose_ILP>>{}); // ((16, 1), (2, kHeadDim / 64 * kBlockN / 32 / 2), kStages)
  547. Tensor tTranssV = logical_divide(group_modes<1, rank(tTranssV_) - 1>(tTranssV_), Shape<Underscore, Int<Transpose_ILP>>{}); // ((16, 1), (2, kHeadDim / 64 * kBlockN / 32 / 2), kStages)
  548. auto transpose_V = [&](int stage) {
  549. if constexpr (Transpose_V) {
  550. #pragma unroll
  551. for (int i = 0; i < size<1, 1>(tTranssVt); ++i) {
  552. Tensor tTransrV = make_fragment_like(tTranssV(_, make_coord(_, _0{}), _0{}));
  553. static_assert(size<0>(tTransrV) == 16);
  554. Tensor tTransrV_64 = recast<uint2>(tTransrV);
  555. cute::copy(s2r_tiled_copy_vt, tTranssVt(_, make_coord(_, i), stage), tTransrV);
  556. #pragma unroll
  557. for (int j = 0; j < size(tTransrV_64); ++j) {
  558. uint32_t upper = tTransrV_64[j].x;
  559. uint32_t lower = tTransrV_64[j].y;
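// Explanatory note (FP8 path only, since Transpose_V implies Is_FP8): LDSM.T transposes at 16-bit
// granularity, so each 16-bit unit still carries two 8-bit values in source order; gathering the
// even bytes (selector 0x6420) and the odd bytes (selector 0x7531) of the two words below finishes
// the transpose at byte granularity.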
  560. tTransrV_64[j].x = __byte_perm(upper, lower, 0x6420);
  561. tTransrV_64[j].y = __byte_perm(upper, lower, 0x7531);
  562. }
  563. cute::copy(r2s_tiled_copy_v, tTransrV, tTranssV(_, make_coord(_, i), stage));
  564. }
  565. }
  566. };
  567. uint16_t mcast_mask_kv = 0;
  568. if constexpr (cute::is_same_v<GmemTiledCopyKV, SM90_TMA_LOAD_MULTICAST>) {
  569. auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id
  570. for (int m = 0; m < size<0>(block_layout); ++m) {
  571. mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{}));
  572. }
  573. }
  574. auto load_K = [&] (int const n_block, auto const& smem_pipe_write, auto need_seqlenk_masking_type) {
  575. pipeline_k.producer_acquire(smem_pipe_write);
  576. if constexpr (!PagedKV) {
  577. copy(params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_LAST),
  578. tKgK_TMA(_, n_block), tKsK_TMA(_, smem_pipe_write.index()));
  579. } else {
  580. constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value;
  581. paged_kv_manager.template load_K<Seqlenk_mask>(n_block, sK_pi(_, _, smem_pipe_write.index()));
  582. pipeline_k.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive);
  583. }
  584. };
  585. auto load_V = [&] (int const n_block, auto const& smem_pipe_write, auto need_seqlenk_masking_type) {
  586. auto pipeline_v_load = cute::conditional_return<!Transpose_V>(pipeline_v, pipeline_vt);
  587. pipeline_v_load.producer_acquire(smem_pipe_write);
  588. if constexpr (!PagedKV) {
  589. copy(params.tma_load_V.with(*pipeline_v_load.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_LAST),
  590. tVgVt_TMA(_, n_block), tVsVt_TMA(_, smem_pipe_write.index()));
  591. } else {
  592. constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value;
  593. paged_kv_manager.template load_V<Seqlenk_mask>(n_block, sVcpasync(_, _, smem_pipe_write.index()));
  594. pipeline_v_load.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive);
  595. }
  596. };
  597. auto copy_Vt_to_V = [&] (auto const& smem_pipe_write) {
  598. // Instead of maintaining smem_pipe_read as a separate variable, we can just use smem_pipe_write,
  599. // and exploit the invariance that smem_pipe_write.phase() == smem_pipe_read.phase() ^ 1.
  600. // This saves 1 or 2 registers.
  601. PipelineState smem_pipe_read{smem_pipe_write.index(), smem_pipe_write.phase() ^ 1, smem_pipe_write.count()};
  602. pipeline_vt.consumer_wait(smem_pipe_read);
  603. pipeline_v.producer_acquire(smem_pipe_write);
  604. transpose_V(smem_pipe_write.index());
  605. // SMEM fence to make sure V is transposed before math
  606. cutlass::arch::fence_view_async_shared();
  607. pipeline_v.producer_commit(smem_pipe_write);
  608. // Very important: PipelineTmaAsync::consumer_release assumes that the warpgroup is synchronized
  609. // before calling. Without this we get race conditions.
  610. cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, static_cast<uint32_t>(FwdNamedBarriers::ProducerWG) /*id*/);
  611. pipeline_vt.consumer_release(smem_pipe_read);
  612. };
  613. int n_block = n_block_max - 1;
  614. int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
  615. // If this is true, we're guaranteed that only the first warp will execute this function
  616. static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp;
  617. bool should_load_KV = !Use_TMA_KV || ((SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync());
  618. if (should_load_KV) {
  619. if constexpr (PagedKV) {
  620. paged_kv_manager.template load_page_table<true /*Seqlenk_mask*/, true /*First_iter*/>(n_block);
  621. }
  622. if constexpr (Transpose_V) { load_V(n_block, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); }
  623. // if (thread_idx == 0) { printf("Producer: main load, before load_K, index = %d\n", smem_pipe_write.index());}
  624. load_K(n_block, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/);
  625. // if (thread_idx == 0) { printf("Producer: main load, after load K, index = %d\n", smem_pipe_write.index());}
  626. }
  627. if constexpr (Use_TMA_Q) {
  628. // Wait for the MMA warpgroups to signal that smem_q is ready
  629. if (SingleProducerWarp || warp_idx_in_warpgroup == 0) {
  630. cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(FwdNamedBarriers::QueryEmpty) /*id*/);
  631. }
  632. if ((SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync()) {
  633. shared_storage.pipelines.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
  634. copy(params.tma_load_Q.with(reinterpret_cast<typename cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.pipelines.barrier_Q), 0 /*mcast_mask*/, !Split ? TMA::CacheHintSm90::EVICT_FIRST : TMA::CacheHintSm90::EVICT_LAST),
  635. tQgQ, tQsQ);
  636. }
  637. } else { // Load Q with cp.async
  638. cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast<uint32_t>(FwdNamedBarriers::QueryEmpty) /*id*/);
  639. Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q + seqlen_info.offset_q * get<0>(params.stride_Q)), params.shape_Q_packed, params.stride_Q_packed)(_, _, bidh, !is_varlen_q ? bidb : 0);
  640. Tensor sQ_pi = cute::as_position_independent_swizzle_tensor(sQ);
  641. using PackGQAt = flash::PackGQAManager<get<0>(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumProducerThreads, Element>;
  642. PackGQAt::load_Q(mQ, sQ_pi, params.qhead_per_khead_divmod, thread_idx, seqlen_info.seqlen_q, m_block);
  643. auto &barrier_Q = shared_storage.pipelines.barrier_Q;
  644. cutlass::arch::cpasync_barrier_arrive(reinterpret_cast<uint64_t*>(&barrier_Q));
  645. barrier_Q.arrive();
  646. }
  647. // Wait for the MMA WGs to signal that smem_v is ready and V can be copied from gmem.
  648. // Need ClusterBarrier, not just NamedBarrier. Otherwise CTA 0 might finish its TMA store on O
  649. // first and issue the TMA multicast load on V before CTA 1 has finished its TMA store on O.
  650. // if (thread_idx == 0) { printf("Producer: main load, before barrier_O, work_idx = %d\n", work_idx);}
  651. shared_storage.pipelines.barrier_O.wait((work_idx + 1) % 2);
  652. // if (thread_idx == 0) { printf("Producer: main load, after barrier_O\n");}
  653. if constexpr (!Transpose_V && !IntraWGOverlap) {
  654. if (should_load_KV) { load_V(n_block, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); }
  655. }
  656. int n_block_prev = n_block;
  657. --n_block;
  658. #pragma unroll (!Transpose_V && Use_TMA_KV ? 2 : 1)
  659. for (; n_block >= n_block_min; --n_block) {
  660. PipelineState smem_pipe_write_v = smem_pipe_write; // copy the state, write_v is always 1 step behind
  661. ++smem_pipe_write;
  662. if (should_load_KV) {
  663. if constexpr (PagedKV) {
  664. paged_kv_manager.template load_page_table<false /*Seqlenk_mask*/>(n_block);
  665. }
  666. if constexpr (Transpose_V) { load_V(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); }
  667. load_K(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/);
  668. if constexpr (!Transpose_V) {
  669. if constexpr (IntraWGOverlap) {
  670. load_V(n_block_prev, smem_pipe_write_v, cute::true_type{} /*Seqlenk_mask*/);
  671. } else {
  672. load_V(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/);
  673. }
  674. }
  675. }
  676. n_block_prev = n_block;
  677. if constexpr (Transpose_V) { copy_Vt_to_V(smem_pipe_write_v); }
  678. }
  679. // if constexpr (Is_local) {
  680. // Disable sink token code for now
  681. if constexpr (false && Is_local) {
  682. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  683. int n_block_sink_max = cute::ceil_div(params.sink_token_length, kBlockN);
  684. #pragma unroll 1
  685. for (n_block = std::min(n_block, n_block_sink_max - 1); n_block >= 0; --n_block) {
  686. PipelineState smem_pipe_write_v = smem_pipe_write; // copy the state, write_v is always 1 step behind
  687. ++smem_pipe_write;
  688. if (should_load_KV) {
  689. if constexpr (PagedKV) {
  690. paged_kv_manager.template load_page_table<false /*Seqlenk_mask*/>(n_block);
  691. }
  692. if constexpr (Transpose_V) { load_V(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); }
  693. load_K(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/);
  694. if constexpr (!Transpose_V) {
  695. if constexpr (IntraWGOverlap) {
  696. load_V(n_block_prev, smem_pipe_write_v, cute::true_type{} /*Seqlenk_mask*/);
  697. } else {
  698. load_V(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/);
  699. }
  700. }
  701. }
  702. n_block_prev = n_block;
  703. if constexpr (Transpose_V) { copy_Vt_to_V(smem_pipe_write_v); }
  704. }
  705. }
  706. scheduler_prefetch();
  707. if constexpr (!Transpose_V && IntraWGOverlap) {
  708. if (should_load_KV) { load_V(n_block_prev, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); }
  709. }
  710. if constexpr (Transpose_V) { copy_Vt_to_V(smem_pipe_write); }
  711. ++smem_pipe_write;
  712. // At the end, all threads have the correct smem_pipe_write.
  713. ++work_idx;
  714. }
  715. template <typename SharedStorage>
  716. CUTLASS_DEVICE void
  717. load_tail(MainloopPipelineK pipeline_k, MainloopPipelineV pipeline_v, MainloopPipelineVt pipeline_vt,
  718. PipelineState& smem_pipe_write, SharedStorage &shared_storage, int const work_idx) {
  719. // If we don't wait for barrier_O here, when using Cluster, CTA0 might exit early and CTA1 will
  720. // try to arrive on barrier_O of CTA0, causing "unspecified launch failure".
  721. shared_storage.pipelines.barrier_O.wait((work_idx + 1) % 2);
  722. int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
  723. // Issue the epilogue waits
  724. // TODO: check if this should be called by 1 thread or more
  725. if (warp_idx_in_warpgroup == 0 && cute::elect_one_sync()) {
  726. /* This helps avoid early exit of blocks in Cluster
  727. * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used
  728. * then would just be acquired since the phase was still inverted from make_producer_start_state
  729. */
  730. pipeline_k.producer_tail(smem_pipe_write);
  731. pipeline_v.producer_tail(smem_pipe_write);
  732. if constexpr (Transpose_V) { pipeline_vt.producer_tail(smem_pipe_write); }
  733. }
  734. }
  735. CUTLASS_DEVICE void
  736. warp_scheduler_barrier_sync() {
  737. if constexpr (UseSchedulerBarrier) {
  738. cutlass::arch::NamedBarrier::sync(2 * cutlass::NumThreadsPerWarpGroup, static_cast<uint32_t>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + flash::canonical_warp_group_idx_nosync() /*id*/);
  739. }
  740. }
  741. CUTLASS_DEVICE void
  742. warp_scheduler_barrier_arrive() {
  743. if constexpr (UseSchedulerBarrier) {
  744. static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3);
  745. int const cur_WG = flash::canonical_warp_group_idx_nosync() - 1;
  746. int const next_WG = NumMmaWarpGroups == 2
  747. ? 1 - cur_WG
  748. : (cur_WG < NumMmaWarpGroups - 1 ? cur_WG + 1 : 0);
  749. cutlass::arch::NamedBarrier::arrive(2 * cutlass::NumThreadsPerWarpGroup, static_cast<uint32_t>(FwdNamedBarriers::WarpSchedulerWG1) + next_WG /*id*/);
  750. }
  751. }
  752. CUTLASS_DEVICE void
  753. mma_init() {
  754. // Tell producers that smem_q is ready
  755. cutlass::arch::NamedBarrier::arrive(NumMmaThreads + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast<uint32_t>(FwdNamedBarriers::QueryEmpty) /*id*/);
  756. if constexpr (UseSchedulerBarrier) {
  757. // We have NamedBarrier for up to 3 WGs
  758. static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3);
  759. // WG1 needs the very first signal to start
  760. if (flash::canonical_warp_group_idx_nosync() == 1) {
  761. cutlass::arch::NamedBarrier::arrive(2 * cutlass::NumThreadsPerWarpGroup, static_cast<uint32_t>(FwdNamedBarriers::WarpSchedulerWG1) /*id*/);
  762. }
  763. }
  764. }
  765. template <typename SharedStorage, typename FrgTensorO, typename Softmax>
  766. CUTLASS_DEVICE bool
  767. mma(Params const& params,
  768. MainloopPipelineK pipeline_k,
  769. MainloopPipelineV pipeline_v,
  770. PipelineState& smem_pipe_read,
  771. FrgTensorO& tOrO,
  772. Softmax& softmax,
  773. int const thread_idx,
  774. int &work_idx,
  775. SeqlenInfo_t const& seqlen_info,
  776. cute::tuple<int32_t, int32_t, int32_t, int32_t> block_coord,
  777. SharedStorage& shared_storage
  778. ) {
  779. static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
  780. static constexpr int kBlockM = get<0>(TileShape_MNK{});
  781. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  782. // can't use auto [m_block, ...] = block_coord since structured binding cannot be captured in lambda
  783. int const m_block = get<0>(block_coord);
  784. int const bidh = get<1>(block_coord);
  785. int const bidb = get<2>(block_coord);
  786. int const split_idx = get<3>(block_coord);
  787. int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh;
  788. auto [n_block_min, n_block_max] = get_n_block_min_max(params, seqlen_info, m_block, bidb, split_idx, params.num_splits);
  789. // It's possible to have n_block_max <= n_block_min. We don't want to load Q or change any barrier
  790. if constexpr (Is_causal || Is_local || Varlen || Split) {
  791. if (n_block_max <= n_block_min) { return false; }
  792. }
  793. Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{});
  794. Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{});
  795. Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtMma{});
  796. Tensor sP = [&] {
  797. if constexpr (Mma1_is_RS) {
  798. // We might not have smem_p if Mma1_is_RS, just use smem_q as a placeholder since we don't use it
                return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutP{});
            } else {
                return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutP{});
            }
        }();

        if constexpr (!Mma0_is_RS) {
            static_assert(stride<0>(typename TiledMma0::ALayout{}) == 0 and
                          stride<0>(typename TiledMma0::BLayout{}) == 0 and
                          size<0>(typename TiledMma0::ALayout{}) == cutlass::NumThreadsPerWarpGroup and
                          size<0>(typename TiledMma0::BLayout{}) == cutlass::NumThreadsPerWarpGroup,
                          "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup");
        }
        constexpr int MmaWarpGroups = size(TiledMma0{}) / cutlass::NumThreadsPerWarpGroup;
        Layout warp_group_thread_layout = make_layout(make_shape(Int<MmaWarpGroups>{}),
                                                      make_stride(Int<cutlass::NumThreadsPerWarpGroup>{}));
        int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0);
        TiledMma0 tiled_mma0;
        TiledMma1 tiled_mma1;
        auto wg_mma0 = tiled_mma0.get_slice(warp_group_thread_layout(warp_group_idx));
        auto wg_mma1 = tiled_mma1.get_slice(warp_group_thread_layout(warp_group_idx));

        auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtomP{}, tiled_mma0);
        auto smem_thr_copy_P = smem_tiled_copy_P.get_thread_slice(thread_idx);

        // Allocate "fragments/descriptors"
        Tensor tSrQ = wg_mma0.partition_fragment_A(sQ);
        Tensor tSrK = wg_mma0.partition_fragment_B(sK);
        Tensor tOrV = wg_mma1.partition_fragment_B(sV);
        Tensor tOsP = wg_mma1.partition_fragment_A(sP);
        Tensor tPsP = smem_thr_copy_P.partition_D(cute::as_position_independent_swizzle_tensor(sP));

        auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
            auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
            pipeline.consumer_wait(smem_pipe_read, barrier_token);
        };
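        // consumer_try_wait returns a token recording whether the stage is already full, so the
        // subsequent consumer_wait only has to block on the mbarrier when the data isn't ready yet.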
        // Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter
        clear(tOrO);
        // tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero;

        int const seqlen_q = seqlen_info.seqlen_q;
        int const seqlen_k = seqlen_info.seqlen_k;
        int n_block = n_block_max - 1;

        flash::Mask<kBlockM, kBlockN, PackGQA, TiledMma0> mask(
            thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, params.sink_token_length,
            params.qhead_per_khead_divmod
        );

        float softcap_val = params.softcap_val;
        if constexpr (Has_softcap && Is_FP8) {
            float const q_descale = params.ptr_q_descale == nullptr ? 1.0f : params.ptr_q_descale[bidb * get<0>(params.stride_q_descale) + bidh_kv * get<1>(params.stride_q_descale)];
            float const k_descale = params.ptr_k_descale == nullptr ? 1.0f : params.ptr_k_descale[bidb * get<0>(params.stride_k_descale) + bidh_kv * get<1>(params.stride_k_descale)];
            softcap_val *= q_descale * k_descale;
        }
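        // With FP8 inputs the S accumulator has not yet been multiplied by the Q/K dequantization
        // scales, so folding q_descale * k_descale into softcap_val lets the softcap act on
        // (effectively) descaled scores without an extra pass over the accumulator.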
        // Softcapping needs to happen before masking since if we apply after masking, softcapping can turn
        // -inf to e.g. -50.0, which can affect the attention softmax.
        auto scoremod_premask_fn = [&](auto& tSrS) {
            if constexpr (Has_softcap) { flash::apply_softcap(tSrS, softcap_val); }
        };

        auto &barrier_Q = shared_storage.pipelines.barrier_Q;
        if constexpr (!AppendKV) {
            barrier_Q.wait(work_idx % 2);
        } else {
            if (get<1>(params.shape_rotary) > 0) {  // Apply rotary to Q
                int const offset_rotary = seqlen_info.seqlen_k_og + seqlen_info.leftpad_k;
                using Rotary_t = Rotary<kBlockM, kHeadDim, NumMmaThreads, Element, !(Is_causal || Is_local) /*FixedPosition*/>;
                Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos,
                                params.ptr_rotary_sin, params.stride_rotary_sin,
                                params.is_rotary_interleaved, thread_idx, seqlen_q, offset_rotary);
                Tensor sQ_pi = cute::as_position_independent_swizzle_tensor(sQ);
                int const qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor;
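                // Note: the cos/sin loads below are issued from gmem before waiting on barrier_Q,
                // so they can overlap with the TMA load of Q that the producer may still be completing.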
                if (params.is_rotary_interleaved) {
                    auto [tRrCos, tRrSin] = cute::conditional_return<!PackGQA>(
                        rotary.template load_cos_sin<true /*kInterleaved*/>(m_block),
                        rotary.template load_cos_sin_packgqa<true /*kInterleaved*/>(m_block, params.qhead_per_khead_divmod)
                    );
                    barrier_Q.wait(work_idx % 2);
                    rotary.apply_Q_interleaved(sQ_pi, tRrCos, tRrSin, m_block, qhead_per_khead);
                } else {
                    auto [tRrCosCont, tRrSinCont] = cute::conditional_return<!PackGQA>(
                        rotary.template load_cos_sin<false /*kInterleaved*/>(m_block),
                        rotary.template load_cos_sin_packgqa<false /*kInterleaved*/>(m_block, params.qhead_per_khead_divmod)
                    );
                    barrier_Q.wait(work_idx % 2);
                    rotary.apply_Q_contiguous(sQ_pi, tRrCosCont, tRrSinCont, m_block, qhead_per_khead);
                }
                // SMEM fence to make sure the rotated Q is visible to GMMA
                cutlass::arch::fence_view_async_shared();
                cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<uint32_t>(FwdNamedBarriers::QueryRotated) /*id*/);
            } else {
                barrier_Q.wait(work_idx % 2);
            }
        }
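        // When Mma0 sources its A operand from registers, Q must be copied from smem into registers
        // explicitly (via LDSM) before the first GEMM; with an smem-sourced A operand the GMMA
        // descriptors read sQ directly and no copy is needed.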
        if constexpr (Mma0_is_RS) {
            using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
            auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtomQ{}, tiled_mma0);
            auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(thread_idx);
            Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
            Tensor tSsQ_copy_view = smem_thr_copy_Q.partition_S(cute::as_position_independent_swizzle_tensor(sQ));
            cute::copy(smem_tiled_copy_Q, tSsQ_copy_view, tSrQ_copy_view);
        }

        // TODO: check the case where n_block_max <= n_block_min but there are sink tokens
        if constexpr (IntraWGOverlap) {
            Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
            consumer_wait(pipeline_k, smem_pipe_read);
            flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
            warpgroup_wait<0>();
            pipeline_k.consumer_release(smem_pipe_read);
            scoremod_premask_fn(tSrS);
            mask.template apply<true /*Seqlenk_mask*/, Is_causal, Is_local>(tSrS, m_block, n_block);

            Tensor scores_scale = softmax.template max_get_scale</*Is_first=*/true, /*Check_inf=*/true>(tSrS);
            softmax.template online_softmax</*Is_first=*/true, /*Check_inf=*/true>(tSrS);

            if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); }
            Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs<TiledMma1>(tSrS.layout()));
            Tensor tOrP = make_tensor_like<Element>(tOrP_acc);
            convert_type_out(tOrP_acc, tOrP);
            if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); }
            if constexpr (!Mma1_is_RS) {
                cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP);
                cutlass::arch::fence_view_async_shared();
                __syncwarp();  // Only need syncwarp since each warp is using its own P values for Mma1
            }
            --n_block;

            // Each step does gemm0 for iter n_block, gemm1 for iter n_block + 1, and softmax for iter n_block.
            auto fwd_step = [&](int const n_block, auto mask_fn, auto check_inf_type) {
                static constexpr bool Check_inf = decltype(check_inf_type)::value;
                PipelineState smem_pipe_read_v(smem_pipe_read.index(), smem_pipe_read.phase(), smem_pipe_read.count());
                ++smem_pipe_read;
                Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
                if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_k, smem_pipe_read); }
                warp_scheduler_barrier_sync();
                flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
                if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); }
                if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_v, smem_pipe_read_v); }
                flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, cute::conditional_return<Mma1_is_RS>(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
                warp_scheduler_barrier_arrive();
                warpgroup_wait<1>();
                pipeline_k.consumer_release(smem_pipe_read);  // release K
                scoremod_premask_fn(tSrS);
                mask_fn(tSrS, n_block);
                cute::copy(softmax.template max_get_scale</*Is_first=*/false, Check_inf>(tSrS), scores_scale);
                softmax.template online_softmax</*Is_first=*/false, Check_inf>(tSrS);
                warpgroup_wait<0>();
                pipeline_v.consumer_release(smem_pipe_read_v);  // release V
                if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); }
                convert_type_out(make_tensor(tSrS.data(), tOrP.layout()), tOrP);
                if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); }
                if constexpr (!Mma1_is_RS) { cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP); }
                if constexpr (!RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); }
                if constexpr (!Mma1_is_RS) {
                    cutlass::arch::fence_view_async_shared();
                    __syncwarp();
                }
            };
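            // Pipelining notes: each fwd_step issues the Q@K^T GEMM for this n_block and the P@V GEMM
            // for the previous n_block back-to-back without waiting (wg_wait=-1). warpgroup_wait<1>()
            // then blocks until only one GMMA group (the P@V GEMM) is still in flight, so K can be
            // released and this n_block's softmax can run while that GEMM finishes.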
            if constexpr (Is_causal || Is_local) {  // Separate iterations with causal or local masking
                auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply<false /*Seqlenk_mask*/, Is_causal, Is_local>(tSrS, m_block, n_block); };
                int const m_idx_min = !PackGQA ? m_block * kBlockM : params.qhead_per_khead_divmod.divide(m_block * kBlockM);
                int const n_block_min_causal_local_mask =
                    std::max(n_block_min, (m_idx_min + seqlen_k - seqlen_q + params.window_size_right) / kBlockN);
                #pragma unroll 1
                for (; n_block >= n_block_min_causal_local_mask; --n_block) {
                    fwd_step(n_block, mask_fn, cute::true_type{} /*check_inf*/);
                }
            }

            int const m_idx_max = !PackGQA ? (m_block + 1) * kBlockM : params.qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1;
            int const n_block_min_before_local_mask = !Is_local
                ? n_block_min
                : std::max(n_block_min,
                           cute::ceil_div(m_idx_max + seqlen_k - seqlen_q - params.window_size_left, kBlockN));
            auto no_mask_fn = [](auto& tSrS, int n_block) { };
            #pragma unroll 1
            for (; n_block >= n_block_min_before_local_mask; --n_block) {
                fwd_step(n_block, no_mask_fn, cute::false_type{} /*check_inf*/);
            }

            // Separate masking iterations on the left for local attention
            if constexpr (Is_local) {
                auto local_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply<false /*Seqlenk_mask*/, false /*Causal_mask*/, Is_local>(tSrS, m_block, n_block); };
                #pragma unroll 1
                for (; n_block >= n_block_min; --n_block) {
                    fwd_step(n_block, local_mask_fn, cute::bool_constant<Is_local>{} /*check_inf*/);
                }
                // Disable sink token code for now
                // int n_block_sink_max = cute::ceil_div(params.sink_token_length, kBlockN);
                // #pragma unroll 1
                // for (n_block = std::min(n_block, n_block_sink_max - 1); n_block >= 0; --n_block) {
                //     fwd_step(n_block, local_mask_fn, cute::bool_constant<Is_local>{} /*check_inf*/);
                // }
            }

            // Tell producers that smem_q is ready
            cutlass::arch::NamedBarrier::arrive(NumMmaThreads + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast<uint32_t>(FwdNamedBarriers::QueryEmpty) /*id*/);
            if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); }
            consumer_wait(pipeline_v, smem_pipe_read);
            flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, cute::conditional_return<Mma1_is_RS>(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO);
            float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)];
            cute::copy(softmax.finalize(v_descale), scores_scale);
            warpgroup_wait<0>();
            pipeline_v.consumer_release(smem_pipe_read);  // release V, otherwise producers will hang
            softmax.rescale_o(tOrO, scores_scale);
            if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); }
            ++smem_pipe_read;

        } else {  // No intra-WG overlap

            warp_scheduler_barrier_sync();
            auto fwd_step = [&](int const n_block, auto mask_fn, auto is_first_iter_type, auto check_inf_type) {
                static constexpr bool Is_first_iter = decltype(is_first_iter_type)::value;
                static constexpr bool Check_inf = decltype(check_inf_type)::value;
                Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
                consumer_wait(pipeline_k, smem_pipe_read);
                flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
                warp_scheduler_barrier_arrive();
                warpgroup_wait<0>();
                pipeline_k.consumer_release(smem_pipe_read);  // release K
                scoremod_premask_fn(tSrS);
                mask_fn(tSrS, n_block);
                Tensor scores_scale = softmax.template max_get_scale</*Is_first=*/Is_first_iter, Check_inf>(tSrS);
                softmax.template online_softmax</*Is_first=*/Is_first_iter, Check_inf>(tSrS);
                if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); }
                Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs<TiledMma1>(tSrS.layout()));
                Tensor tOrP = make_tensor_like<Element>(tOrP_acc);
                convert_type_out(tOrP_acc, tOrP);
                if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); }
                if constexpr (!Is_first_iter) { softmax.rescale_o(tOrO, scores_scale); }
                consumer_wait(pipeline_v, smem_pipe_read);
                warp_scheduler_barrier_sync();
                flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
                pipeline_v.consumer_release(smem_pipe_read);  // release V
                ++smem_pipe_read;
            };
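            // In this non-overlapped path each iteration is fully serialized: the Q@K^T GEMM is waited
            // on with warpgroup_wait<0>() before the softmax, and the P@V GEMM is issued with wg_wait=0
            // so it completes before the next iteration. Simpler scheduling, at the cost of the
            // GEMM/softmax overlap that the IntraWGOverlap path gets.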
            auto first_iter_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply<true /*Seqlenk_mask*/, Is_causal, Is_local>(tSrS, m_block, n_block); };
            fwd_step(n_block, first_iter_mask_fn, cute::true_type{} /*is_first_iter*/, cute::true_type{} /*check_inf*/);
            --n_block;
            if constexpr (Is_causal || Is_local) {  // Separate iterations with causal or local masking
                auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply<false /*Seqlenk_mask*/, Is_causal, Is_local>(tSrS, m_block, n_block); };
                int const m_idx_min = !PackGQA ? m_block * kBlockM : params.qhead_per_khead_divmod.divide(m_block * kBlockM);
                int const n_block_min_causal_local_mask =
                    std::max(n_block_min, (m_idx_min + seqlen_k - seqlen_q + params.window_size_right) / kBlockN);
                #pragma unroll 1
                for (; n_block >= n_block_min_causal_local_mask; --n_block) {
                    fwd_step(n_block, mask_fn, cute::false_type{} /*is_first_iter*/, cute::true_type{} /*check_inf*/);
                }
            }
            int const m_idx_max = !PackGQA ? (m_block + 1) * kBlockM : params.qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1;
            int const n_block_min_before_local_mask = !Is_local
                ? n_block_min
                : std::max(n_block_min,
                           cute::ceil_div(m_idx_max + seqlen_k - seqlen_q - params.window_size_left, kBlockN));
            auto no_mask_fn = [](auto& tSrS, int n_block) { };
            #pragma unroll 1
            for (; n_block >= n_block_min_before_local_mask; --n_block) {
                fwd_step(n_block, no_mask_fn, cute::false_type{} /*is_first_iter*/, cute::false_type{} /*check_inf*/);
            }
            // Separate masking iterations on the left for local attention
            if constexpr (Is_local) {
                auto local_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply<false /*Seqlenk_mask*/, false /*Causal_mask*/, Is_local>(tSrS, m_block, n_block); };
                #pragma unroll 1
                for (; n_block >= n_block_min; --n_block) {
                    fwd_step(n_block, local_mask_fn, cute::false_type{} /*is_first_iter*/, cute::bool_constant<Is_local>{} /*check_inf*/);
                }
                // Disable sink token code for now
                // int n_block_sink_max = cute::ceil_div(params.sink_token_length, kBlockN);
                // #pragma unroll 1
                // for (n_block = std::min(n_block, n_block_sink_max - 1); n_block >= 0; --n_block) {
                //     fwd_step(n_block, local_mask_fn, cute::false_type{} /*is_first_iter*/, cute::bool_constant<Is_local>{} /*check_inf*/);
                // }
            }
            warp_scheduler_barrier_arrive();
            // Tell producers that smem_q is ready
            cutlass::arch::NamedBarrier::arrive(NumMmaThreads + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast<uint32_t>(FwdNamedBarriers::QueryEmpty) /*id*/);
            float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)];
            Tensor scores_scale = softmax.finalize(v_descale);
            softmax.rescale_o(tOrO, scores_scale);
            if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); }
        }
        ++work_idx;
        return true;
    }
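    // Maps this tile's n_block range onto blocks of the *new* (appended) K/V. Sketch of the index
    // math with illustrative numbers (not from the source): kBlockN = 128, seqlen_k_og = 300,
    // n_block range [0, 4) covers K rows [0, 512); the new tokens start at row 300, so
    // idx_k_new_min = max(0*128 - 300, 0) = 0 and idx_k_new_max = min(4*128 - 300, seqlen_k_new),
    // which is then rounded up to block granularity for n_block_new_max.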
    CUTLASS_DEVICE
    cute::tuple<int, int> get_n_block_k_new_min_max(Params const& params, SeqlenInfo_t const& seqlen_info,
                                                    int m_block, int bidb, int split_idx=0, int num_splits=1) {
        static constexpr int kBlockN = get<1>(TileShape_MNK{});
        auto [n_block_min, n_block_max] = get_n_block_min_max(params, seqlen_info, m_block, bidb, split_idx, num_splits);
        int const idx_k_new_min = std::max(n_block_min * kBlockN - seqlen_info.seqlen_k_og, 0);
        int const idx_k_new_max = std::min(n_block_max * kBlockN - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new);
        int const n_block_new_min = idx_k_new_min / kBlockN;
        int const n_block_new_max = idx_k_new_max > idx_k_new_min ? cute::ceil_div(idx_k_new_max, kBlockN) : n_block_new_min;
        // if (threadIdx.x == 128 && m_block == 0) { printf("bidb = %d, seqlen_k_new = %d, seqlen_k_og = %d, n_block_min = %d, n_block_max = %d, idx_k_new_min = %d, idx_k_new_max = %d, n_block_new_min = %d, n_block_new_max = %d\n", bidb, seqlen_k_new, seqlen_k_og, n_block_min, n_block_max, idx_k_new_min, idx_k_new_max, n_block_new_min, n_block_new_max);}
        return {n_block_new_min, n_block_new_max};
    }

    template <typename SharedStorage>
    CUTLASS_DEVICE bool
    load_kv_new(Params const& params,
                MainloopPipelineKVNew pipeline_k_new,
                MainloopPipelineKVNew pipeline_v_new,
                PipelineState& smem_pipe_write,
                SharedStorage &shared_storage,
                SeqlenInfo_t const& seqlen_info,
                cute::tuple<int32_t, int32_t, int32_t, int32_t> block_coord,
                int const work_idx
                ) {
        auto [m_block, bidh, bidb, split_idx] = block_coord;
        auto [n_block_new_min, n_block_new_max] = get_n_block_k_new_min_max(params, seqlen_info, m_block, bidb, split_idx, params.num_splits);
        if (n_block_new_max <= n_block_new_min) { return false; }

        Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{});
        Tensor sVt = [&] {
            if constexpr (!Transpose_V) {
                return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVt{});
            } else {
                return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVt{});
            }
        }();

        // int const thread_idx = threadIdx.x % NumProducerThreads;
        int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh;

        // Prepare the TMA loads
        uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
        constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
        uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
        bool const is_varlen_k_new = Varlen && params.cu_seqlens_k_new;
        Tensor mKnew_TMA = params.tma_load_K_new.get_tma_tensor(params.shape_K_new)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0);
        Tensor mVnewt_TMA = params.tma_load_V_new.get_tma_tensor(select<1, 0, 2, 3>(params.shape_K_new))(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0);

        Tensor gKnew_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k_new, _0{}), mKnew_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (N, K, _)
        Tensor gVnewt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k_new), mVnewt_TMA), select<2, 1>(TileShape_MNK{}), make_coord(_0{}, _));  // (K, N, _)

        auto block_tma_K_new = params.tma_load_K_new.get_slice(cluster_local_block_id.x);
        Tensor tKgKnew_TMA = group_modes<0, 3>(block_tma_K_new.partition_S(gKnew_TMA));  // (TMA, k)
        Tensor tKsK_TMA = group_modes<0, 3>(block_tma_K_new.partition_D(sK));  // (TMA, PIPE)
        auto block_tma_V_new = params.tma_load_V_new.get_slice(cluster_local_block_id.x);
        Tensor tVgVnewt_TMA = group_modes<0, 3>(block_tma_V_new.partition_S(gVnewt_TMA));  // (TMA, k)
        Tensor tVsVt_TMA = group_modes<0, 3>(block_tma_V_new.partition_D(sVt));  // (TMA, PIPE)

        uint16_t mcast_mask_kv = 0;
        if constexpr (cute::is_same_v<GmemTiledCopyKV, SM90_TMA_LOAD_MULTICAST>) {
            auto block_layout = Layout<ClusterShape>{};  // (m,n) -> block_id
            for (int m = 0; m < size<0>(block_layout); ++m) {
                mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{}));
            }
        }
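        // With TMA multicast, one CTA's load can be broadcast to several CTAs in the same cluster;
        // mcast_mask_kv sets one bit per destination CTA rank (here, the CTAs along the first cluster
        // dimension) so the hardware knows which blocks' smem should receive the same K/V tile.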
        auto load_K_new = [&] (int const n_block, auto const& smem_pipe_write) {
            pipeline_k_new.producer_acquire(smem_pipe_write);
            copy(params.tma_load_K_new.with(*pipeline_k_new.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_FIRST),
                 tKgKnew_TMA(_, n_block), tKsK_TMA(_, smem_pipe_write.index()));
        };

        auto load_V_new = [&] (int const n_block, auto const& smem_pipe_write) {
            pipeline_v_new.producer_acquire(smem_pipe_write);
            copy(params.tma_load_V_new.with(*pipeline_v_new.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_FIRST),
                 tVgVnewt_TMA(_, n_block), tVsVt_TMA(_, smem_pipe_write.index()));
        };
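        // Standard TMA producer pattern: producer_acquire waits until the stage's smem buffer is free,
        // then the copy is issued against that stage's mbarrier so consumers are signalled on completion.
        // EVICT_FIRST is an L2 cache hint marking the loaded lines as preferred eviction candidates
        // (streaming data).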
        int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
        // If this is true, we're guaranteed that only the first warp will execute this function
        static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp;
        bool should_load_KV = (SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync();

        int n_block = n_block_new_max - 1;
        // Need to wait for barrier_O even before load_K_new since the pipelines for AppendKV
        // and the main attention are not the same. We want to make sure the consumers
        // have finished reading all smem_k and smem_v for the previous iteration.
        shared_storage.pipelines.barrier_O.wait((work_idx + 1) % 2);
        if (should_load_KV) { load_K_new(n_block, smem_pipe_write); }
        // if (thread_idx == 0) { printf("Producer: Done loading K, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); }
        if (should_load_KV) { load_V_new(n_block, smem_pipe_write); }
        // if (thread_idx == 0) { printf("Producer: Done loading V, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); }
        ++smem_pipe_write;
        --n_block;
        // if (thread_idx == 0) { printf("Producer: before for loop\n"); }
        #pragma unroll 1
        for (; n_block >= n_block_new_min; --n_block) {
            if (should_load_KV) {
                load_K_new(n_block, smem_pipe_write);
                // if (thread_idx == 0) { printf("Producer: Done loading K, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); }
                load_V_new(n_block, smem_pipe_write);
                // if (thread_idx == 0) { printf("Producer: Done loading V, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); }
            }
            ++smem_pipe_write;
        }
        // if (thread_idx == 0) { printf("Producer: after for loop\n"); }
        // At the end, all threads have the correct smem_pipe_write.
        return true;
    }

    template <typename SharedStorage>
    CUTLASS_DEVICE bool
    store_kv_new(Params const& params,
                 MainloopPipelineKVNew pipeline_k_new,
                 MainloopPipelineKVNew pipeline_v_new,
                 PipelineState& smem_pipe_read,
                 int const thread_idx,
                 SharedStorage &shared_storage,
                 SeqlenInfo_t const& seqlen_info,
                 cute::tuple<int32_t, int32_t, int32_t, int32_t> block_coord
                 ) {
        auto [m_block, bidh, bidb, split_idx] = block_coord;
        auto [n_block_new_min, n_block_new_max] = get_n_block_k_new_min_max(params, seqlen_info, m_block, bidb, split_idx, params.num_splits);
        if (n_block_new_max <= n_block_new_min) { return false; }

        // as_position_independent_swizzle_tensor makes address calculation easier
        Tensor sK = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}));
        // We want to use SmemLayoutVCpAsync to have shape (kBlockN, kHeadDim) instead of (kHeadDim, kBlockN)
        Tensor sV = [&] {
            if constexpr (!Transpose_V) {
                return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVCpAsync{}));
            } else {
                return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVCpAsync{}));
            }
        }();

        int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh;
        int const bidb_kv = params.kv_batch_idx == nullptr ? bidb : params.kv_batch_idx[bidb];

        bool const is_varlen_k = Varlen && params.cu_seqlens_k;
        Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0);
        Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), params.shape_K, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0);

        int const offset_k = seqlen_info.offset_k + seqlen_info.seqlen_k_og;
        Tensor gK = local_tile(domain_offset(make_coord(offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (N, K, _)
        Tensor gV = local_tile(domain_offset(make_coord(offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (N, K, _)

        static constexpr int kBlockN = get<1>(TileShape_MNK{});
        static constexpr int kHeadDim = get<2>(TileShape_MNK{});
        int const offset_rotary = seqlen_info.seqlen_k_og + seqlen_info.leftpad_k;
        int const seqlen_k_new = seqlen_info.seqlen_k_new;
        using Rotary_t = Rotary<kBlockN, kHeadDim, NumMmaThreads, Element>;
        Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos,
                        params.ptr_rotary_sin, params.stride_rotary_sin,
                        params.is_rotary_interleaved, thread_idx, seqlen_k_new, offset_rotary);

        using PagedKVManager_t = PagedKVManager<get<1>(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>;
        PagedKVManager_t paged_kv_manager(
            params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable,
            params.ptr_K, params.shape_K, params.stride_K,
            params.ptr_V, params.stride_V,
            params.page_size_divmod, bidb_kv, bidh_kv, thread_idx, seqlen_k_new, offset_k
            // passing offset_k instead of leftpad_k will move the PageTable pointer to the right position
        );

        if constexpr (UseSchedulerBarrier) {
            // WG1 already got the very first signal from mma_init(), but we'll be using the same NamedBarrier.
            // So we'll need to "cancel it out" here and then re-signal it at the end.
            if (flash::canonical_warp_group_idx_nosync() == 1) {
                cutlass::arch::NamedBarrier::sync(2 * cutlass::NumThreadsPerWarpGroup, static_cast<uint32_t>(FwdNamedBarriers::WarpSchedulerWG1) /*id*/);
            }
        }

        static_assert(std::is_same_v<GmemLayoutAtom, typename Rotary_t::LayoutAtom>);
        static_assert(!PagedKV || std::is_same_v<GmemLayoutAtom, typename PagedKVManager_t::GmemLayoutAtomKVCpAsync>);
        GmemTiledCopyAppendKV gmem_tiled_copy_kv;
        auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_thread_slice(thread_idx);
        Tensor tKsK = gmem_thr_copy_kv.partition_S(sK);  // ((Atom,AtomNum),ATOM_M,ATOM_N)
        Tensor tKgK = gmem_thr_copy_kv.partition_D(gK);
        Tensor tVsV = gmem_thr_copy_kv.partition_S(sV);  // ((Atom,AtomNum),ATOM_M,ATOM_N)
        Tensor tVgV = gmem_thr_copy_kv.partition_D(gV);
        Tensor cK = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{}));  // (BLK_N,BLK_K) -> (blk_n,blk_k)
        Tensor tKcK = gmem_thr_copy_kv.partition_D(cK);
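        // tKpK predicates only along the head dimension: a column is valid if its global index is
        // < headdim (get<1>(params.shape_K)); bounds along the sequence dimension are handled
        // separately via the n_limit argument to flash::copy / the paged-KV manager.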
        Tensor tKpK = make_tensor<bool>(make_shape(size<2>(tKsK)));
        #pragma unroll
        for (int k = 0; k < size(tKpK); ++k) { tKpK(k) = get<1>(tKcK(_0{}, _0{}, k)) < get<1>(params.shape_K); }

        auto store_K = [&] (int const n_block, auto const& smem_pipe_read) {
            int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN);
            if (get<1>(params.shape_rotary) <= 0) {
                pipeline_k_new.consumer_wait(smem_pipe_read);
                Tensor tKsK_cur = tKsK(_, _, _, smem_pipe_read.index());
                if constexpr (!PagedKV) {
                    Tensor tKgK_cur = tKgK(_, _, _, n_block);
                    // Clear_OOB_K must be false since we don't want to write zeros to gmem
                    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
                        gmem_tiled_copy_kv, tKsK_cur, tKgK_cur, tKcK, tKpK, std::min(seqlen_k_new - n_block * kBlockN, kBlockN)
                    );
                } else {
                    paged_kv_manager.store_K(n_block, tKsK_cur);
                }
            } else {
                Tensor gK_cur = gK(_, _, n_block);
                auto tPrKPtr = cute::conditional_return<PagedKV>(paged_kv_manager.compute_K_ptr(), nullptr);
                if (params.is_rotary_interleaved) {
                    auto [tRrCos, tRrSin] = rotary.template load_cos_sin<true /*kInterleaved*/>(n_block);
                    pipeline_k_new.consumer_wait(smem_pipe_read);
                    rotary.template apply_K_interleaved<PagedKV>(sK(_, _, smem_pipe_read.index()), gK_cur, tKpK, tRrCos, tRrSin, tPrKPtr, n_block);
                } else {
                    auto [tRrCosCont, tRrSinCont] = rotary.template load_cos_sin<false /*kInterleaved*/>(n_block);
                    pipeline_k_new.consumer_wait(smem_pipe_read);
                    rotary.template apply_K_contiguous<PagedKV>(sK(_, _, smem_pipe_read.index()), gK_cur, tKpK, tRrCosCont, tRrSinCont, tPrKPtr, n_block, get<1>(params.shape_K));
                }
            }
            // Without this sync we get a race condition when seqlen_k is large
            cutlass::arch::fence_view_async_shared();
            // Very important: PipelineTmaAsync::consumer_release assumes that the warpgroup is synchronized
            // before calling.
            cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, static_cast<uint32_t>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + flash::canonical_warp_group_idx_nosync() /*id*/);
            pipeline_k_new.consumer_release(smem_pipe_read);
            // if (thread_idx == 0) { print_tensor(tKpK); printf("\n"); printf("seqlen_limit = %d\n", seqlen_k_new - n_block * kBlockN);}
        };

        auto store_V = [&] (int const n_block, auto const& smem_pipe_read) {
            pipeline_v_new.consumer_wait(smem_pipe_read);
            int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN);
            Tensor tVsV_cur = tVsV(_, _, _, smem_pipe_read.index());
            if constexpr (!PagedKV) {
                Tensor tVgV_cur = tVgV(_, _, _, n_block);
                // Clear_OOB_K must be false since we don't want to write zeros to gmem
                flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
                    gmem_tiled_copy_kv, tVsV_cur, tVgV_cur, tKcK, tKpK, n_limit);
            } else {
                paged_kv_manager.store_V(n_block, tVsV_cur);
            }
            cutlass::arch::fence_view_async_shared();
            cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, static_cast<uint32_t>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + flash::canonical_warp_group_idx_nosync() /*id*/);
            pipeline_v_new.consumer_release(smem_pipe_read);
        };

        #pragma unroll 1
        for (int n_block = n_block_new_max - 1; n_block >= n_block_new_min; --n_block) {
            if constexpr (PagedKV) { paged_kv_manager.template load_page_table<true /*Seqlenk_mask*/>(n_block); }
            store_K(n_block, smem_pipe_read);
            // if (thread_idx == 0) { printf("Done storing K, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); }
            store_V(n_block, smem_pipe_read);
            // if (thread_idx == 0) { printf("Done storing V, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); }
            ++smem_pipe_read;
        }
        // if (thread_idx == 0) { printf("After for loop\n"); }

        // Re-signaling the NamedBarrier that we "canceled out"
        if constexpr (UseSchedulerBarrier) {
            if (flash::canonical_warp_group_idx_nosync() == 1) {
                cutlass::arch::NamedBarrier::arrive(2 * cutlass::NumThreadsPerWarpGroup, static_cast<uint32_t>(FwdNamedBarriers::WarpSchedulerWG1) /*id*/);
            }
        }
        return true;
    }

};

} // namespace flash