mainloop_fwd_sm90_tma_gmma_ws.hpp

/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include "cutlass/pipeline/pipeline.hpp"

#include "cute/tensor.hpp"

#include "cutlass/gemm/collective/collective_builder.hpp"

#include "named_barrier.hpp"
#include "utils.h"

namespace flash {

using namespace cute;

template <int Stages, class ClusterShape_, class TileShape_MNK_, class Element_, class ElementAccum_, class ArchTag_,
          bool Is_causal_, bool Is_local_, bool Has_softcap_, bool Varlen_, bool V_colmajor_>
struct CollectiveMainloopFwd {

    static constexpr int kStages = Stages;
    using ClusterShape = ClusterShape_;
    using TileShape_MNK = TileShape_MNK_;
    using Element = Element_;
    using ElementAccum = ElementAccum_;
    using ArchTag = ArchTag_;
    static constexpr bool Is_FP8 = cute::is_same_v<Element, cutlass::float_e4m3_t> || cute::is_same_v<Element, cutlass::float_e5m2_t>;
    static constexpr bool Is_causal = Is_causal_;
    static constexpr bool Is_local = Is_local_;
    static constexpr bool Has_softcap = Has_softcap_;
    static constexpr bool Varlen = Varlen_;
    static constexpr bool V_colmajor = V_colmajor_;
    static constexpr bool Transpose_V = Is_FP8 && !V_colmajor;
    static constexpr int kHeadDim = get<2>(TileShape_MNK{});

    static_assert(ArchTag::kMinComputeCapability >= 90);
    static_assert(get<1>(ClusterShape{}) == 1 && get<2>(ClusterShape{}) == 1);

    static constexpr cute::GMMA::Major MmaMajorV = !Is_FP8 && !V_colmajor ? GMMA::Major::MN : GMMA::Major::K;
    static constexpr cute::GMMA::Major TmaMajorV = !V_colmajor ? GMMA::Major::MN : GMMA::Major::K;

    using AtomLayoutMNK = Layout<Shape<Int<get<0>(TileShape_MNK{}) / 64>, _1, _1>>;
    using TiledMma0 = decltype(cute::make_tiled_mma(
        cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
        AtomLayoutMNK{}));
    using TiledMma1 = decltype(cute::make_tiled_mma(
        cute::GMMA::rs_op_selector<Element, Element, ElementAccum, decltype(select<0, 2, 1>(TileShape_MNK{})),
                                   GMMA::Major::K, MmaMajorV>(),
        AtomLayoutMNK{}));

    static constexpr int NumMmaThreads = size(TiledMma0{});
    static constexpr int NumProducerThreads = !Transpose_V ? cutlass::NumThreadsPerWarp : cutlass::NumThreadsPerWarpGroup;

    using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));

    using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutK = decltype(tile_to_shape(
        SmemLayoutAtomK{},
        make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));

    using SmemLayoutAtomVt = decltype(cutlass::gemm::collective::detail::ss_smem_selector<TmaMajorV, Element,
        decltype(cute::get<2>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutVt = decltype(tile_to_shape(
        SmemLayoutAtomVt{},
        make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{}),
        std::conditional_t<TmaMajorV == GMMA::Major::K, cute::Step<_1, _2, _3>, cute::Step<_2, _1, _3>>{}));

    using SmemLayoutAtomVtMma = decltype(cutlass::gemm::collective::detail::ss_smem_selector<MmaMajorV, Element,
        decltype(cute::get<2>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutVtMma = decltype(tile_to_shape(
        SmemLayoutAtomVtMma{},
        make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{}),
        std::conditional_t<MmaMajorV == GMMA::Major::K, cute::Step<_1, _2, _3>, cute::Step<_2, _1, _3>>{}));

    // Use LDSM.T and STSM to transpose V in the case of FP8 and V being row-major.
    // For FP16/BF16 we don't do any transposing.
    static_assert(!Transpose_V || (kHeadDim % 32 == 0 && CUTE_STATIC_V(get<1>(TileShape_MNK{})) % 32 == 0));
    static constexpr bool kHeadDim_multiple_64 = kHeadDim % 64 == 0;
    // Either kHeadDim is a multiple of 64 (in which case we use a block size of 64 x 32 for the transpose),
    // or we need kBlockN to be a multiple of 64 (in which case we use a block size of 32 x 64 for the transpose).
    static_assert(!Transpose_V || (kHeadDim_multiple_64 || CUTE_STATIC_V(get<1>(TileShape_MNK{})) % 64 == 0));
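    // Illustration of the constraint above: for the common head dims 64, 128 and 256 the
    // kHeadDim_multiple_64 path is taken and V is transposed in 64 x 32 blocks; a hypothetical
    // head dim such as 96 would instead need kBlockN % 64 == 0 and use 32 x 64 blocks.
    // The LDSM/STSM shapes below are chosen to match these two cases.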
    using LDSM_thread_shape = std::conditional_t<kHeadDim_multiple_64, Shape<_32, _4, _1, _1>, Shape<_16, _4, _1, _2>>;
    using LDSM_thread_stride = std::conditional_t<kHeadDim_multiple_64, Stride<_4, _1, _0, _0>, Stride<_4, _1, _0, _64>>;
    using LDSM_value_shape = Shape<_2, _2, _1, _4>;
    using LDSM_value_stride = Stride<_1, _2, _16, _4>;
    using LDSM_divide_shape = std::conditional_t<kHeadDim_multiple_64, Shape<_64, _8>, Shape<_32, _8>>;
    using S2RTiledCopyVt = decltype(make_tiled_copy(
        Copy_Atom<SM75_U16x8_LDSM_T, Element>{}, Layout<LDSM_thread_shape, LDSM_thread_stride>{},
        Layout<LDSM_value_shape, LDSM_value_stride>{}));

    using STSM_thread_shape = std::conditional_t<kHeadDim_multiple_64, Shape<_8, _4, _4, _1>, Shape<_8, _4, _2, _2>>;
    using STSM_thread_stride = std::conditional_t<kHeadDim_multiple_64, Stride<_4, _1, _32, _0>, Stride<_4, _1, _32, _64>>;
    using STSM_value_shape = Shape<_1, _4, _2, _2>;
    using STSM_value_stride = Stride<_0, _1, _4, _8>;
    using STSM_divide_shape = Shape<_8, _16>;
    // These will not permute the columns of V (the kHeadDim dimension) but incur bank conflicts
    // so a little slower (e.g. 1150 TFLOPS for hdim 256 instead of 1200 TFLOPS).
    // Instead we will permute the cols of V, and un-permute the cols of O in the epilogue.
    // using STSM_value_shape = Shape<_2, _4, _1, _2>;
    // using STSM_value_stride = Stride<_4, _1, _0, _8>;
    // using STSM_divide_shape = Shape<_16, _16>;
    using R2STiledCopyV = decltype(make_tiled_copy(
        Copy_Atom<SM90_U32x4_STSM_N, Element>{}, Layout<STSM_thread_shape, STSM_thread_stride>{},
        Layout<STSM_value_shape, STSM_value_stride>{}));

    using GmemTiledCopyQ = cute::SM90_TMA_LOAD;
    using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{})));

    using ShapeQKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>;  // (seqlen, d, head, batch)
    using StrideQK = cute::Stride<int64_t, _1, int64_t, int64_t>;
    using StrideV = std::conditional_t<!V_colmajor, StrideQK, cute::Stride<_1, int64_t, int64_t, int64_t>>;

    using TMA_Q = decltype(make_tma_copy_A_sm90(
        GmemTiledCopyQ{},
        make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), ShapeQKV{}, StrideQK{}),
        SmemLayoutQ{},
        TileShape_MNK{},
        ClusterShape{}));

    using TMA_K = decltype(make_tma_copy_B_sm90(
        GmemTiledCopyKV{},
        make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), ShapeQKV{}, StrideQK{}),
        take<0, 2>(SmemLayoutK{}),
        TileShape_MNK{},
        ClusterShape{}));  // mcast along M mode for this N load, if any

    using TMA_V = decltype(make_tma_copy(
        GmemTiledCopyKV{},
        make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), ShapeQKV{}, select<1, 0, 2, 3>(StrideV{})),
        take<0, 2>(SmemLayoutVt{}),
        select<2, 1>(TileShape_MNK{}),
        size<0>(ClusterShape{})));  // mcast along M mode for this N load, if any

    // Set the bytes transferred in this TMA transaction (may involve multiple issues)
    static constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<Element> / 8);
    static constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<Element> / 8);
    static constexpr uint32_t TmaTransactionBytesV = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutVt{})) * cutlass::sizeof_bits_v<Element> / 8);
    static_assert(TmaTransactionBytesK == TmaTransactionBytesV);
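    // Worked example (illustrative only): with kBlockM = 128, kHeadDim = 128 and a 16-bit Element,
    // TmaTransactionBytesQ = 128 * 128 * 2 = 32768 bytes (32 KiB). K and V each transfer
    // kBlockN * kHeadDim elements per pipeline stage, which is why the byte counts above must match.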

    using MainloopPipelineK = typename cutlass::PipelineTmaAsync<kStages>;
    using MainloopPipelineV = std::conditional_t<!Transpose_V, typename cutlass::PipelineTmaAsync<kStages>, typename cutlass::PipelineAsync<kStages>>;
    using MainloopPipelineVt = typename cutlass::PipelineTmaAsync<kStages>;
    using PipelineState = cutlass::PipelineState<kStages>;

    struct TensorStorageNoTranspose : cute::aligned_struct<128> {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutVt>> smem_v;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
    };

    static constexpr size_t SmemAlignmentVt = cutlass::detail::alignment_for_swizzle(SmemLayoutVt{});
    static constexpr size_t SmemAlignmentV = cutlass::detail::alignment_for_swizzle(SmemLayoutVtMma{});
    static_assert(SmemAlignmentVt >= 128 and SmemAlignmentV >= 128, "Require at least 128B alignment");

    struct TensorStorageTransposeV : cute::aligned_struct<cute::max(SmemAlignmentVt, SmemAlignmentV)> {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutVtMma>, SmemAlignmentV> smem_v;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutVt>, SmemAlignmentVt> smem_vt;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
    };

    using TensorStorage = std::conditional_t<!Transpose_V, TensorStorageNoTranspose, TensorStorageTransposeV>;

    // These are tuned for speed. They don't affect correctness.
    static constexpr bool UseSchedulerBarrier = !Is_FP8 ? kHeadDim <= 128 : kHeadDim >= 128;
    static constexpr bool RescaleOBeforeGemm = kHeadDim > 128 && (!Is_FP8 || V_colmajor);

    // Host side kernel arguments
    struct Arguments {
        Element const* ptr_Q;
        ShapeQKV const shape_Q;
        StrideQK const stride_Q;
        Element const* ptr_K;
        ShapeQKV const shape_K;
        StrideQK const stride_K;
        Element const* ptr_V;
        StrideV const stride_V;
        float const softmax_scale;
        float const* ptr_q_scale = nullptr, *ptr_k_scale = nullptr, *ptr_v_scale = nullptr;
        int const window_size_left = -1, window_size_right = -1;
        float const softcap_val;
        int const* cu_seqlens_q = nullptr;
        int const* cu_seqlens_k = nullptr;
        int const* seqused_q = nullptr;
        int const* seqused_k = nullptr;
    };

    // Device side kernel params
    struct Params {
        ShapeQKV const shape_Q;
        ShapeQKV const shape_K;
        cutlass::FastDivmod qhead_per_khead_divmod;
        TMA_Q tma_load_Q;
        TMA_K tma_load_K;
        TMA_V tma_load_V;
        float const softmax_scale_log2;
        float const* ptr_q_scale = nullptr, *ptr_k_scale = nullptr, *ptr_v_scale = nullptr;
        float const softcap_val;
        int const window_size_left, window_size_right;
        int const* cu_seqlens_q = nullptr;
        int const* cu_seqlens_k = nullptr;
        int const* seqused_q = nullptr;
        int const* seqused_k = nullptr;
    };

    static Params
    to_underlying_arguments(Arguments const& args) {
        Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.shape_Q, args.stride_Q);
        TMA_Q tma_load_Q = make_tma_copy_A_sm90(
            GmemTiledCopyQ{},
            mQ,
            SmemLayoutQ{},
            TileShape_MNK{},
            ClusterShape{});  // no mcast for Q
        Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.shape_K, args.stride_K);
        TMA_K tma_load_K = make_tma_copy_B_sm90(
            GmemTiledCopyKV{},
            mK,
            take<0, 2>(SmemLayoutK{}),
            TileShape_MNK{},
            ClusterShape{});  // mcast along M mode for this N load, if any
        Tensor mVt = make_tensor(make_gmem_ptr(args.ptr_V), select<1, 0, 2, 3>(args.shape_K), select<1, 0, 2, 3>(args.stride_V));
        TMA_V tma_load_V = make_tma_copy(
            GmemTiledCopyKV{},
            mVt,
            take<0, 2>(SmemLayoutVt{}),
            select<2, 1>(TileShape_MNK{}),
            size<0>(ClusterShape{}));  // mcast along M mode for this N load, if any
        if constexpr (Varlen) {
            assert(args.cu_seqlens_q != nullptr && args.cu_seqlens_k != nullptr);
        }
        // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val.
        // Right after this, we multiply by log2(e) before applying exp2.
        // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val
        // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e)
        // (assigning it to params.softmax_scale_log2).
        // TODO: this currently doesn't work with FP8 scaling
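        // In symbols (an illustrative restatement of the comment above), for Has_softcap:
        //   params.softcap_val        = softmax_scale / softcap_val
        //   params.softmax_scale_log2 = softcap_val * log2(e)
        // so applying the softcap gives tanh(S * softmax_scale / softcap_val), and multiplying by
        // softmax_scale_log2 before exp2 yields exp(tanh(S * softmax_scale / softcap_val) * softcap_val).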
        return {args.shape_Q, args.shape_K,
                cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))),
                tma_load_Q, tma_load_K, tma_load_V,
                !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E),
                args.ptr_q_scale, args.ptr_k_scale, args.ptr_v_scale,
                !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val,
                args.window_size_left, args.window_size_right,
                args.cu_seqlens_q, args.cu_seqlens_k, args.seqused_q, args.seqused_k};
    }

    /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
    CUTLASS_DEVICE
    static void prefetch_tma_descriptors(Params const& params) {
        cute::prefetch_tma_descriptor(params.tma_load_Q.get_tma_descriptor());
        cute::prefetch_tma_descriptor(params.tma_load_K.get_tma_descriptor());
        cute::prefetch_tma_descriptor(params.tma_load_V.get_tma_descriptor());
    }

    CUTLASS_DEVICE
    int get_seqlen_q(Params const& params, int bidb) {
        if constexpr (!Varlen) {
            return get<0>(params.shape_Q);
        } else {
            return params.seqused_q ? params.seqused_q[bidb] : params.cu_seqlens_q[bidb + 1] - params.cu_seqlens_q[bidb];
        }
    }

    CUTLASS_DEVICE
    int get_seqlen_k(Params const& params, int bidb) {
        if constexpr (!Varlen) {
            return get<0>(params.shape_K);
        } else {
            return params.seqused_k ? params.seqused_k[bidb] : params.cu_seqlens_k[bidb + 1] - params.cu_seqlens_k[bidb];
        }
    }

    CUTLASS_DEVICE
    int get_n_block_max(Params const& params, int m_block, int bidb) {
        static constexpr int kBlockM = get<0>(TileShape_MNK{});
        static constexpr int kBlockN = get<1>(TileShape_MNK{});
        int const seqlen_k = get_seqlen_k(params, bidb);
        int n_block_max = cute::ceil_div(seqlen_k, kBlockN);
        if constexpr (Is_causal || Is_local) {
            int const seqlen_q = get_seqlen_q(params, bidb);
            n_block_max = std::min(n_block_max,
                                   cute::ceil_div((m_block + 1) * kBlockM + seqlen_k - seqlen_q + params.window_size_right, kBlockN));
        }
        return n_block_max;
    }

    CUTLASS_DEVICE
    int get_n_block_min(Params const& params, int m_block, int bidb) {
        if (!Is_local) {
            return 0;
        } else {
            static constexpr int kBlockM = get<0>(TileShape_MNK{});
            static constexpr int kBlockN = get<1>(TileShape_MNK{});
            int const seqlen_k = get_seqlen_k(params, bidb);
            int const seqlen_q = get_seqlen_q(params, bidb);
            return std::max(int(0), (m_block * kBlockM + seqlen_k - seqlen_q - params.window_size_left) / kBlockN);
        }
    }

    template <typename SchedulerPrefetch, typename SharedStorage>
    CUTLASS_DEVICE void
    load(Params const& params,
         MainloopPipelineK pipeline_k,
         MainloopPipelineV pipeline_v,
         PipelineState& smem_pipe_write,
         SharedStorage &shared_storage,
         SchedulerPrefetch const& scheduler_prefetch,
         cute::tuple<int32_t, int32_t, int32_t> block_coord,
         int work_idx
         ) {
        Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{});
        Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{});
        Tensor sVt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVt{});
        auto [m_block, bidh, bidb] = block_coord;
        int bidh_kv = params.qhead_per_khead_divmod.divide(bidh);
        // Prepare the TMA loads
        uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
        constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
        uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
        Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !Varlen ? bidb : 0);
        Tensor mK = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !Varlen ? bidb : 0);
        Tensor mVt = params.tma_load_V.get_tma_tensor(select<1, 0, 2, 3>(params.shape_K))(_, _, bidh_kv, !Varlen ? bidb : 0);
        Tensor gQ = local_tile(domain_offset(make_coord(!Varlen ? 0 : params.cu_seqlens_q[bidb], _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{}));  // (M, K)
        Tensor gK = local_tile(domain_offset(make_coord(!Varlen ? 0 : params.cu_seqlens_k[bidb], _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (N, K, _)
        Tensor gVt = local_tile(domain_offset(make_coord(_0{}, !Varlen ? 0 : params.cu_seqlens_k[bidb]), mVt), select<2, 1>(TileShape_MNK{}), make_coord(_0{}, _));  // (K, N, _)
        Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
        Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
        auto [tQgQ, tQsQ] = tma_partition(params.tma_load_Q, _0{}, Layout<_1>{},
                                          group_modes<0, 2>(sQ_x), group_modes<0, 2>(gQ_x));  // (TMA), (TMA)
        auto [tKgK, tKsK] = tma_partition(params.tma_load_K, block_rank_in_cluster, Layout<ClusterShape>{},
                                          group_modes<0, 2>(sK), group_modes<0, 2>(gK));  // (TMA, k), (TMA, PIPE)
        auto [tVgVt, tVsVt] = tma_partition(params.tma_load_V, block_rank_in_cluster, Layout<ClusterShape>{},
                                            group_modes<0, 2>(sVt), group_modes<0, 2>(gVt));  // (TMA, k), (TMA, PIPE)
        uint16_t mcast_mask_kv = 0;
        if constexpr (cute::is_same_v<GmemTiledCopyKV, SM90_TMA_LOAD_MULTICAST>) {
            auto block_layout = Layout<ClusterShape>{};  // (m,n) -> block_id
            for (int m = 0; m < size<0>(block_layout); ++m) {
                mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{}));
            }
        }
        int n_block_max = get_n_block_max(params, m_block, bidb);
        int n_block_min = get_n_block_min(params, m_block, bidb);
        int n_block = n_block_max - 1;
        int lane_predicate = cute::elect_one_sync();
        if (lane_predicate) {
            pipeline_k.producer_acquire(smem_pipe_write);
            if constexpr (size(ClusterShape{}) == 1) {
                copy(params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_LAST),
                     tKgK(_, n_block), tKsK(_, smem_pipe_write.index()));
            } else {
                copy(params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
                     tKgK(_, n_block), tKsK(_, smem_pipe_write.index()));
            }
        }
        // Wait for the MMA warpgroups to say that smem_q is ready
        cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
        if (lane_predicate) {
            shared_storage.pipelines.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
            copy(params.tma_load_Q.with(reinterpret_cast<typename cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.pipelines.barrier_Q), 0 /*mcast_mask*/, TMA::CacheHintSm90::EVICT_FIRST),
                 tQgQ, tQsQ);
        }
        // Wait for warp 1 to signal that smem_v is ready and V can be copied from gmem.
        // Need ClusterBarrier, not just NamedBarrier. Otherwise we might have CTA 0 finish the
        // TMA store on O first and issue the TMA multicast load on V before CTA 1 can finish its TMA store on O.
        shared_storage.pipelines.barrier_O.wait((work_idx + 1) % 2);
        if (lane_predicate) {
            // CUTLASS_PRAGMA_NO_UNROLL
            #pragma unroll 2
            for (; n_block > n_block_min; --n_block) {
                PipelineState smem_pipe_write_v = smem_pipe_write;  // copy the state, write_v is always 1 step behind
                ++smem_pipe_write;
                pipeline_k.producer_acquire(smem_pipe_write);
                if constexpr (size(ClusterShape{}) == 1) {
                    copy(params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_LAST),
                         tKgK(_, n_block - 1), tKsK(_, smem_pipe_write.index()));
                } else {
                    copy(params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
                         tKgK(_, n_block - 1), tKsK(_, smem_pipe_write.index()));
                }
                pipeline_v.producer_acquire(smem_pipe_write_v);
                if constexpr (size(ClusterShape{}) == 1) {
                    copy(params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv, TMA::CacheHintSm90::EVICT_LAST),
                         tVgVt(_, n_block), tVsVt(_, smem_pipe_write_v.index()));
                } else {
                    copy(params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv),
                         tVgVt(_, n_block), tVsVt(_, smem_pipe_write_v.index()));
                }
            }
        }
        scheduler_prefetch();
        if (lane_predicate) {
            pipeline_v.producer_acquire(smem_pipe_write);
            if constexpr (size(ClusterShape{}) == 1) {
                copy(params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_LAST),
                     tVgVt(_, n_block), tVsVt(_, smem_pipe_write.index()));
            } else {
                copy(params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
                     tVgVt(_, n_block), tVsVt(_, smem_pipe_write.index()));
            }
            ++smem_pipe_write;
        }
    }

    template <typename SchedulerPrefetch, typename SharedStorage>
    CUTLASS_DEVICE void
    load_fp8_transpose_V(
        Params const& params,
        MainloopPipelineK pipeline_k,
        MainloopPipelineV pipeline_v,
        MainloopPipelineVt pipeline_vt,
        PipelineState& smem_pipe_write,
        SharedStorage &shared_storage,
        SchedulerPrefetch const& scheduler_prefetch,
        cute::tuple<int32_t, int32_t, int32_t> block_coord,
        int work_idx
        ) {
        Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{});
        Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{});
        // as_position_independent_swizzle_tensor makes address calculation easier when we do LDSM & STSM to transpose.
        // But it requires smem_vt and smem_v to be aligned to e.g. 512 bytes.
        Tensor sVt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVt{}));
        Tensor sV = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtMma{}));
        auto [m_block, bidh, bidb] = block_coord;
        int bidh_kv = params.qhead_per_khead_divmod.divide(bidh);
        // Prepare the TMA loads
        uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
        constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
        uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
        Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !Varlen ? bidb : 0);
        Tensor mK = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !Varlen ? bidb : 0);
        Tensor mVt = params.tma_load_V.get_tma_tensor(select<1, 0, 2, 3>(params.shape_K))(_, _, bidh_kv, !Varlen ? bidb : 0);
        Tensor gQ = local_tile(domain_offset(make_coord(!Varlen ? 0 : params.cu_seqlens_q[bidb], _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{}));  // (M, K)
        Tensor gK = local_tile(domain_offset(make_coord(!Varlen ? 0 : params.cu_seqlens_k[bidb], _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (N, K, _)
        Tensor gVt = local_tile(domain_offset(make_coord(_0{}, !Varlen ? 0 : params.cu_seqlens_k[bidb]), mVt), select<2, 1>(TileShape_MNK{}), make_coord(_0{}, _));  // (K, N, _)
        auto block_tma_Q = params.tma_load_Q.get_slice(_0{});
        Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ));  // (TMA)
        Tensor tQsQ = group_modes<0, 3>(block_tma_Q.partition_D(sQ));  // (TMA)
        // tma_partition doesn't handle position_independent_swizzle_tensor correctly, so we need to do it manually
        auto block_tma_K = params.tma_load_K.get_slice(cluster_local_block_id.x);
        Tensor tKgK = group_modes<0, 3>(block_tma_K.partition_S(gK));  // (TMA, k)
        Tensor tKsK = group_modes<0, 3>(block_tma_K.partition_D(sK));  // (TMA, PIPE)
        auto block_tma_V = params.tma_load_V.get_slice(cluster_local_block_id.x);
        Tensor tVgVt = group_modes<0, 3>(block_tma_V.partition_S(gVt));  // (TMA, k)
        Tensor tVsVt = group_modes<0, 3>(block_tma_V.partition_D(sVt));  // (TMA, PIPE)
        uint16_t mcast_mask_kv = 0;
        if constexpr (cute::is_same_v<GmemTiledCopyKV, SM90_TMA_LOAD_MULTICAST>) {
            auto block_layout = Layout<ClusterShape>{};  // (m,n) -> block_id
            for (int m = 0; m < size<0>(block_layout); ++m) {
                mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{}));
            }
        }
        // Set up for transposing V
        S2RTiledCopyVt s2r_tiled_copy_vt;
        R2STiledCopyV r2s_tiled_copy_v;
        auto s2r_thr_copy_vt = s2r_tiled_copy_vt.get_thread_slice(threadIdx.x % NumProducerThreads);
        auto r2s_thr_copy_v = r2s_tiled_copy_v.get_thread_slice(threadIdx.x % NumProducerThreads);
        // flat_divide(sVt, LDSM_divide_shape{}): (64, 8, kHeadDim / 64, kBlockN / 8, kStages)
        Tensor tTranssVt_ = s2r_thr_copy_vt.partition_S(flat_divide(sVt, LDSM_divide_shape{}));  // ((16, 1), 1, 1, kHeadDim / 64, kBlockN / 32, kStages)
        // flat_divide(sV, STSM_divide_shape{}): (8, 16, kHeadDim / 8, (4, kBlockN / 64), kStages)
        Tensor tTranssV_ = r2s_thr_copy_v.partition_D(flat_divide(sV, STSM_divide_shape{}));  // ((16, 1), 1, 1, kHeadDim / 64, (2, kBlockN / 64), kStages)
        CUTE_STATIC_ASSERT_V(rank(tTranssVt_) == rank(tTranssV_));
        CUTE_STATIC_ASSERT_V(size<0>(tTranssVt_) == size<0>(tTranssV_));
        CUTE_STATIC_ASSERT_V(size<1>(tTranssVt_) == size<1>(tTranssV_));
        CUTE_STATIC_ASSERT_V(size<2>(tTranssVt_) == size<2>(tTranssV_));
        CUTE_STATIC_ASSERT_V(size<3>(tTranssVt_) == size<3>(tTranssV_));
        CUTE_STATIC_ASSERT_V(size<4>(tTranssVt_) == size<4>(tTranssV_));
        // Faster to have 2 LDSM.T, byte permute, STSM for better ILP
        static constexpr int Transpose_ILP = (size<2>(tTranssVt_) * size<3>(tTranssVt_)) % 2 == 0 ? 2 : 1;
        Tensor tTranssVt = logical_divide(group_modes<1, rank(tTranssVt_) - 1>(tTranssVt_), Shape<Underscore, Int<Transpose_ILP>>{});  // ((16, 1), (2, kHeadDim / 64 * kBlockN / 32 / 2), kStages)
        Tensor tTranssV = logical_divide(group_modes<1, rank(tTranssV_) - 1>(tTranssV_), Shape<Underscore, Int<Transpose_ILP>>{});  // ((16, 1), (2, kHeadDim / 64 * kBlockN / 32 / 2), kStages)
        auto transpose_V = [&](int stage) {
            #pragma unroll
            for (int i = 0; i < size<1, 1>(tTranssVt); ++i) {
                Tensor tTransrV = make_fragment_like(tTranssV(_, make_coord(_, _0{}), _0{}));
                static_assert(size<0>(tTransrV) == 16);
                Tensor tTransrV_64 = recast<uint2>(tTransrV);
                cute::copy(s2r_tiled_copy_vt, tTranssVt(_, make_coord(_, i), stage), tTransrV);
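                // The loop below is a byte-level fix-up after the LDSM.T copy above: roughly, LDSM.T
                // transposes at 16-bit granularity, so each 16-bit unit still holds a pair of fp8
                // values in source order. __byte_perm(upper, lower, 0x6420) gathers the even-indexed
                // bytes of the 8-byte pair and 0x7531 gathers the odd-indexed bytes, de-interleaving
                // the pairs so the transpose holds at 8-bit granularity before the STSM store.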
                #pragma unroll
                for (int j = 0; j < size(tTransrV_64); ++j) {
                    uint32_t upper = tTransrV_64[j].x;
                    uint32_t lower = tTransrV_64[j].y;
                    tTransrV_64[j].x = __byte_perm(upper, lower, 0x6420);
                    tTransrV_64[j].y = __byte_perm(upper, lower, 0x7531);
                }
                cute::copy(r2s_tiled_copy_v, tTransrV, tTranssV(_, make_coord(_, i), stage));
            }
        };
        int n_block_max = get_n_block_max(params, m_block, bidb);
        int n_block_min = get_n_block_min(params, m_block, bidb);
        int n_block = n_block_max - 1;
        int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
        int lane_predicate = cute::elect_one_sync();
        if (warp_idx_in_warpgroup == 0) {
            if (lane_predicate) {
                pipeline_vt.producer_acquire(smem_pipe_write);
                if constexpr (size(ClusterShape{}) == 1) {
                    copy(params.tma_load_V.with(*pipeline_vt.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_LAST),
                         tVgVt(_, n_block), tVsVt(_, smem_pipe_write.index()));
                } else {
                    copy(params.tma_load_V.with(*pipeline_vt.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
                         tVgVt(_, n_block), tVsVt(_, smem_pipe_write.index()));
                }
                pipeline_k.producer_acquire(smem_pipe_write);
                if constexpr (size(ClusterShape{}) == 1) {
                    copy(params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_LAST),
                         tKgK(_, n_block), tKsK(_, smem_pipe_write.index()));
                } else {
                    copy(params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
                         tKgK(_, n_block), tKsK(_, smem_pipe_write.index()));
                }
            }
            // Wait for the MMA warpgroups to say that smem_q is ready
            cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
            if (lane_predicate) {
                shared_storage.pipelines.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
                copy(params.tma_load_Q.with(reinterpret_cast<typename cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.pipelines.barrier_Q), 0 /*mcast_mask*/, TMA::CacheHintSm90::EVICT_FIRST),
                     tQgQ, tQsQ);
            }
        }
        --n_block;
        // Wait for warp 1 to signal that smem_v is ready and V can be copied from gmem.
        // Need ClusterBarrier, not just NamedBarrier. Otherwise we might have CTA 0 finish the
        // TMA store on O first and issue the TMA multicast load on V before CTA 1 can finish its TMA store on O.
        // if (blockIdx.x == 0 && threadIdx.x % 32 == 0) { printf("tidx = %d, Producer: before barrier_O.wait\n", threadIdx.x); }
        shared_storage.pipelines.barrier_O.wait((work_idx + 1) % 2);
        // CUTLASS_PRAGMA_NO_UNROLL
        #pragma unroll 1
        for (; n_block >= n_block_min; --n_block) {
            PipelineState smem_pipe_write_v = smem_pipe_write;  // copy the state, write_v is always 1 step behind
            ++smem_pipe_write;
            if (warp_idx_in_warpgroup == 0 && lane_predicate) {
                pipeline_vt.producer_acquire(smem_pipe_write);
                if constexpr (size(ClusterShape{}) == 1) {
                    copy(params.tma_load_V.with(*pipeline_vt.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_LAST),
                         tVgVt(_, n_block), tVsVt(_, smem_pipe_write.index()));
                } else {
                    copy(params.tma_load_V.with(*pipeline_vt.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
                         tVgVt(_, n_block), tVsVt(_, smem_pipe_write.index()));
                }
                pipeline_k.producer_acquire(smem_pipe_write);
                if constexpr (size(ClusterShape{}) == 1) {
                    copy(params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_LAST),
                         tKgK(_, n_block), tKsK(_, smem_pipe_write.index()));
                } else {
                    copy(params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
                         tKgK(_, n_block), tKsK(_, smem_pipe_write.index()));
                }
            }
            // Instead of maintaining smem_pipe_read_v as a separate variable, we can just use smem_pipe_write_v,
            // and exploit the invariance that smem_pipe_write_v.phase() == smem_pipe_read_v.phase() ^ 1.
            // This saves 1 or 2 registers.
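            // For example (assuming CUTLASS's usual convention that the producer state starts at
            // phase 1 via make_producer_start_state while the consumer starts at phase 0): on the
            // first pass through a stage the write state has phase 1 and the matching read wait
            // uses phase 0; on the next pass through that stage the phases flip to 0 and 1.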
            PipelineState smem_pipe_read_v{smem_pipe_write_v.index(), smem_pipe_write_v.phase() ^ 1, smem_pipe_write_v.count()};
            pipeline_vt.consumer_wait(smem_pipe_read_v);
            pipeline_v.producer_acquire(smem_pipe_write_v);
            transpose_V(smem_pipe_write_v.index());
            // SMEM fence to make sure V is transposed before math
            cutlass::arch::fence_view_async_shared();
            pipeline_v.producer_commit(smem_pipe_write_v);
            // PipelineTmaAsync::consumer_release assumes that the warpgroup is synchronized before calling
            cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, static_cast<int>(FwdNamedBarriers::ProducerWG) /*id*/);
            pipeline_vt.consumer_release(smem_pipe_read_v);
        }
        scheduler_prefetch();
        PipelineState smem_pipe_read_v{smem_pipe_write.index(), smem_pipe_write.phase() ^ 1, smem_pipe_write.count()};
        pipeline_vt.consumer_wait(smem_pipe_read_v);
        pipeline_v.producer_acquire(smem_pipe_write);
        transpose_V(smem_pipe_write.index());
        // SMEM fence to make sure V is transposed before math
        cutlass::arch::fence_view_async_shared();
        pipeline_v.producer_commit(smem_pipe_write);
        cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, static_cast<int>(FwdNamedBarriers::ProducerWG) /*id*/);
        pipeline_vt.consumer_release(smem_pipe_read_v);
        ++smem_pipe_write;
    }

    /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
    CUTLASS_DEVICE void
    load_tail(MainloopPipelineK pipeline_k, MainloopPipelineV pipeline_v, PipelineState& smem_pipe_write) {
        int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
        int lane_predicate = cute::elect_one_sync();
        // Issue the epilogue waits
        if (warp_idx_in_warpgroup == 0 && lane_predicate) {
            /* This helps avoid early exit of blocks in Cluster
             * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used
             * then would just be acquired since the phase was still inverted from make_producer_start_state
             */
            pipeline_k.producer_tail(smem_pipe_write);
            pipeline_v.producer_tail(smem_pipe_write);
        }
    }

    CUTLASS_DEVICE void
    load_tail(MainloopPipelineK pipeline_k, MainloopPipelineV pipeline_v, MainloopPipelineVt pipeline_vt,
              PipelineState& smem_pipe_write) {
        int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
        int lane_predicate = cute::elect_one_sync();
        // Issue the epilogue waits
        if (warp_idx_in_warpgroup == 0 && lane_predicate) {
            /* This helps avoid early exit of blocks in Cluster
             * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used
             * then would just be acquired since the phase was still inverted from make_producer_start_state
             */
            pipeline_k.producer_tail(smem_pipe_write);
            pipeline_v.producer_tail(smem_pipe_write);
            pipeline_vt.producer_tail(smem_pipe_write);
        }
    }

    CUTLASS_DEVICE void
    warp_scheduler_barrier_sync() {
        if constexpr (UseSchedulerBarrier) {
            cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + cutlass::canonical_warp_group_idx() /*id*/);
        }
    }

    CUTLASS_DEVICE void
    warp_scheduler_barrier_arrive() {
        if constexpr (!UseSchedulerBarrier) { return; }
        static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
        if constexpr (NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup) {
            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (3 - cutlass::canonical_warp_group_idx()) /*id*/);
        } else {
            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 2 ? cutlass::canonical_warp_group_idx() + 1 : cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/);
            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 1 ? cutlass::canonical_warp_group_idx() + 2 : cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/);
        }
    }

    CUTLASS_DEVICE void
    mma_init() {
        // Tell producer (warp 0) that smem_q is ready
        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
        if constexpr (!UseSchedulerBarrier) { return; }
        static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
        if (cutlass::canonical_warp_group_idx() > 1) {
            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 1 /*id*/);
        }
        if constexpr (NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup) {
            if (cutlass::canonical_warp_group_idx() > 2) {
                cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 2 /*id*/);
            }
        }
    }

    template <typename SharedStorage, typename FrgTensorO, typename Softmax>
    CUTLASS_DEVICE void
    mma(Params const& params,
        MainloopPipelineK pipeline_k,
        MainloopPipelineV pipeline_v,
        PipelineState& smem_pipe_read,
        FrgTensorO& tOrO,
        Softmax& softmax,
        int thread_idx,
        int work_idx,
        cute::tuple<int32_t, int32_t, int32_t> block_coord,
        SharedStorage& shared_storage
        ) {
        static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
        static constexpr int kBlockM = get<0>(TileShape_MNK{});
        static constexpr int kBlockN = get<1>(TileShape_MNK{});
        Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{});
        Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{});
        Tensor sVt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtMma{});
        static_assert(stride<0>(typename TiledMma0::ALayout{}) == 0 and
                      stride<0>(typename TiledMma0::BLayout{}) == 0 and
                      size<0>(typename TiledMma0::ALayout{}) == cutlass::NumThreadsPerWarpGroup and
                      size<0>(typename TiledMma0::BLayout{}) == cutlass::NumThreadsPerWarpGroup,
                      "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup");
        constexpr int MmaWarpGroups = size(TiledMma0{}) / cutlass::NumThreadsPerWarpGroup;
        Layout warp_group_thread_layout = make_layout(make_shape(Int<MmaWarpGroups>{}),
                                                      make_stride(Int<cutlass::NumThreadsPerWarpGroup>{}));
        int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0);
        TiledMma0 tiled_mma0;
        TiledMma1 tiled_mma1;
        auto wg_mma0 = tiled_mma0.get_slice(warp_group_thread_layout(warp_group_idx));
        auto thread_mma0 = tiled_mma0.get_thread_slice(thread_idx);
        auto wg_mma1 = tiled_mma1.get_slice(warp_group_thread_layout(warp_group_idx));
        // Allocate "fragments/descriptors"
        Tensor tSrQ = wg_mma0.partition_fragment_A(sQ);
        Tensor tSrK = wg_mma0.partition_fragment_B(sK);
        Tensor tOrV = wg_mma1.partition_fragment_B(sVt);
        auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
            auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
            pipeline.consumer_wait(smem_pipe_read, barrier_token);
        };
        // clear(tOrO);
        tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero;
        int m_block = get<0>(block_coord);
        int bidb = get<2>(block_coord);
        int const seqlen_q = get_seqlen_q(params, bidb);
        int const seqlen_k = get_seqlen_k(params, bidb);
        int n_block_max = get_n_block_max(params, m_block, bidb);
        int n_block_min = get_n_block_min(params, m_block, bidb);
        int n_block = n_block_max - 1;
        auto causal_local_mask_fn = [&](auto& tSrS, int const n_block, auto need_seqlenk_masking_type, auto is_causal_type, auto is_local_type) {
            constexpr bool Need_seqlenk_masking = decltype(need_seqlenk_masking_type)::value;
            constexpr bool Is_causal = decltype(is_causal_type)::value;
            constexpr bool Is_local = decltype(is_local_type)::value;
            Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
            Tensor tScS = thread_mma0.partition_C(cS);
            if constexpr (!Is_causal && !Is_local) {
                if constexpr (Need_seqlenk_masking) {  // Just masking based on col
                    #pragma unroll
                    for (int i = 0; i < size(tSrS); ++i) {
                        if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; }
                    }
                }
            } else {  // mask based on both row and col
                int causal_row_offset = 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM;
                if constexpr (Is_causal) {
                    #pragma unroll
                    for (int i = 0; i < size(tSrS); ++i) {
                        // using std::min is faster than doing col >= limit0 or col >= limit1
                        // Need to cast get<1>(tScS(i)) to (signed) int since by default it's unsigned, and the
                        // right hand side can be negative and might be converted to a very large unsigned integer.
                        int col_limit_right = !Need_seqlenk_masking
                            ? int(get<0>(tScS(i))) + causal_row_offset
                            : std::min(int(get<0>(tScS(i))) + causal_row_offset, seqlen_k - n_block * kBlockN);
                        if (int(get<1>(tScS(i))) >= col_limit_right) { tSrS(i) = -INFINITY; }
                    }
                } else {
                    int local_row_offset_right = causal_row_offset + params.window_size_right;
                    int local_row_offset_left = causal_row_offset - 1 - params.window_size_left;
                    #pragma unroll
                    for (int i = 0; i < size(tSrS); ++i) {
                        int col_limit_right = !Need_seqlenk_masking
                            ? int(get<0>(tScS(i))) + local_row_offset_right
                            : __viaddmin_s32(int(get<0>(tScS(i))), local_row_offset_right, seqlen_k - n_block * kBlockN);
                        int col_limit_left = int(get<0>(tScS(i))) + local_row_offset_left;
                        if (int(get<1>(tScS(i))) >= col_limit_right || int(get<1>(tScS(i))) < col_limit_left) {
                            tSrS(i) = -INFINITY;
                        }
                    }
                }
            }
        };
        typename cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.pipelines.barrier_Q.try_wait(work_idx % 2));
        if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.pipelines.barrier_Q.wait(work_idx % 2); }
        if constexpr (true) {
            Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
            // if (blockIdx.x == 0 && threadIdx.x == 128) { print(tSrS); }
            consumer_wait(pipeline_k, smem_pipe_read);
            warp_scheduler_barrier_sync();
            flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
            warp_scheduler_barrier_arrive();
            if (work_idx != 0) {
                int lane_predicate = cute::elect_one_sync();
                int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0);
                if (warp_idx_sync == NumMmaThreads / cutlass::NumThreadsPerWarp - 1 && lane_predicate) {
                    if constexpr (!Varlen) { tma_store_wait<0>(); }
                    #pragma unroll
                    for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
                        shared_storage.pipelines.barrier_O.arrive(cta_id, lane_predicate);
                    }
                }
            }
            warpgroup_wait<0>();
            pipeline_k.consumer_release(smem_pipe_read);
            // This needs to happen before masking since if we apply after masking, softcapping can turn
            // -inf to e.g. -50.0, which can affect the attention softmax.
            if constexpr (Has_softcap) { flash::apply_softcap(tSrS, params.softcap_val); }
            causal_local_mask_fn(tSrS, n_block, cute::bool_constant<true>{} /*need_seqlenk_masking*/, cute::bool_constant<Is_causal>{}, cute::bool_constant<Is_local>{});
            Tensor scores_scale = softmax.template max_get_scale</*Is_first=*/true, /*Check_inf=*/true>(tSrS);
            softmax.template online_softmax</*Is_first=*/true, /*Check_inf=*/true>(tSrS);
            if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); }
            Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<TiledMma1>(tSrS.layout()));
            if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); }
            // Each step does gemm0 for iter n_block - 1, gemm1 for iter n_block, and softmax for iter n_block - 1.
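            // Rough schedule of one fwd_step call (a sketch of the code below, not extra logic):
            //   1. issue gemm0: S_{n_block-1} = Q @ K_{n_block-1}        (async)
            //   2. issue gemm1: O += P_{n_block} @ V_{n_block}           (async, P from the previous step)
            //   3. wait on gemm0, release K, then mask and online-softmax S_{n_block-1}
            //   4. wait on gemm1, release V, convert S_{n_block-1} into the next P, rescale O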
            auto fwd_step = [&](int n_block, auto mask_fn, auto is_first_iter_type, auto check_inf_type) {
                static constexpr bool Is_first_iter = decltype(is_first_iter_type)::value;
                static constexpr bool Check_inf = decltype(check_inf_type)::value;
                PipelineState smem_pipe_read_v(smem_pipe_read.index(), smem_pipe_read.phase(), smem_pipe_read.count());
                ++smem_pipe_read;
                Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
                consumer_wait(pipeline_k, smem_pipe_read);
                warp_scheduler_barrier_sync();
                flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
                if constexpr (RescaleOBeforeGemm && !Is_first_iter) { softmax.rescale_o(tOrO, scores_scale); }
                consumer_wait(pipeline_v, smem_pipe_read_v);
                flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
                warp_scheduler_barrier_arrive();
                warpgroup_wait<1>();
                pipeline_k.consumer_release(smem_pipe_read);  // release K
                if constexpr (Has_softcap) { flash::apply_softcap(tSrS, params.softcap_val); }
                mask_fn(tSrS, n_block - 1);
                cute::copy(softmax.template max_get_scale</*Is_first=*/false, Check_inf>(tSrS), scores_scale);
                softmax.template online_softmax</*Is_first=*/false, Check_inf>(tSrS);
                warpgroup_wait<0>();
                pipeline_v.consumer_release(smem_pipe_read_v);  // release V
                if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); }
                cute::copy(make_tensor(convert_type<Element>(tSrS).data(), tOrP.layout()), tOrP);
                if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); }
                if constexpr (!RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); }
            };
            if constexpr (Is_causal || Is_local) {  // Separate iterations with causal or local masking
                auto mask_fn = [&](auto& tSrS, int n_block) { causal_local_mask_fn(tSrS, n_block, cute::bool_constant<false>{} /*need_seqlenk_masking*/, cute::bool_constant<Is_causal>{}, cute::bool_constant<Is_local>{}); };
                constexpr int n_masking_steps = cute::ceil_div(kBlockM, kBlockN) + 1;
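                // For instance, with kBlockM = 128 and kBlockN = 128 this gives
                // n_masking_steps = ceil_div(128, 128) + 1 = 2, i.e. the loop below runs one masked
                // fwd_step in addition to the first tile already handled above with seqlen-k masking.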
                #pragma unroll
                for (int masking_step = 0; masking_step < n_masking_steps - 1 && n_block > n_block_min; ++masking_step, --n_block) {
                    if (masking_step == 0) {
                        fwd_step(n_block, mask_fn, cute::bool_constant<true>{} /*is_first_iter*/, cute::bool_constant<true>{} /*check_inf*/);
                    } else {
                        fwd_step(n_block, mask_fn, cute::bool_constant<false>{} /*is_first_iter*/, cute::bool_constant<true>{} /*check_inf*/);
                    }
                }
            }
            static constexpr int n_local_left_steps = !Is_local ? 0 : cute::ceil_div(kBlockM, kBlockN) + 1;
            auto no_mask_fn = [](auto& tSrS, int n_block) { };
            #pragma unroll 1
            for (; n_block > n_block_min + n_local_left_steps; --n_block) {
                fwd_step(n_block, no_mask_fn, cute::bool_constant<false>{} /*is_first_iter*/, cute::bool_constant<false>{} /*check_inf*/);
            }
            // Separate masking iterations on the left for local attention
            if constexpr (Is_local) {
                auto local_mask_fn = [&](auto& tSrS, int n_block) { causal_local_mask_fn(tSrS, n_block, cute::bool_constant<false>{} /*need_seqlenk_masking*/, cute::bool_constant<false>{} /*is_causal*/, cute::bool_constant<Is_local>{}); };
                #pragma unroll 1
                for (; n_block > n_block_min; --n_block) {
                    fwd_step(n_block, local_mask_fn, cute::bool_constant<false>{} /*is_first_iter*/, cute::bool_constant<Is_local>{} /*check_inf*/);
                }
            }
            // Tell warp 0 that smem_q is ready
            cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
            if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); }
            consumer_wait(pipeline_v, smem_pipe_read);
            flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
            cute::copy(softmax.finalize(!Is_FP8 || params.ptr_v_scale == nullptr ? 1.f : *params.ptr_v_scale), scores_scale);
            warpgroup_wait<0>();
            pipeline_v.consumer_release(smem_pipe_read);  // release V, otherwise producers will hang
            softmax.rescale_o(tOrO, scores_scale);
            if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); }
            ++smem_pipe_read;
        } else {
            // WIP
            if (work_idx != 0) {
                int lane_predicate = cute::elect_one_sync();
                int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0);
                if (warp_idx_sync == NumMmaThreads / cutlass::NumThreadsPerWarp - 1 && lane_predicate) {
                    if constexpr (!Varlen) { tma_store_wait<0>(); }
                    #pragma unroll
                    for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
                        shared_storage.pipelines.barrier_O.arrive(cta_id, lane_predicate);
                    }
                }
            }
            #pragma unroll 1
            for (; n_block >= 0; --n_block) {
                Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
                consumer_wait(pipeline_k, smem_pipe_read);
                flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
                warpgroup_wait<0>();
                pipeline_k.consumer_release(smem_pipe_read);  // release K
                Tensor scores_scale = softmax.template max_get_scale</*Is_first=*/false>(tSrS);
                warp_scheduler_barrier_sync();
                softmax.template online_softmax</*Is_first=*/false>(tSrS);
                Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<TiledMma1>(tSrS.layout()));
                warp_scheduler_barrier_arrive();
                if constexpr (Is_FP8) { flash::permute_Aregs_fp8(tOrP); }
                softmax.rescale_o(tOrO, scores_scale);
                consumer_wait(pipeline_v, smem_pipe_read);
                flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
                warpgroup_wait<0>();
                pipeline_v.consumer_release(smem_pipe_read);  // release V
                ++smem_pipe_read;
            }
            // Tell warp 0 that smem_q is ready
            cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
            Tensor scores_scale = softmax.finalize();
            softmax.rescale_o(tOrO, scores_scale);
        }
    }

};

} // namespace flash