// mainloop_fwd_sm90_tma_gmma_ws.hpp
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include "cutlass/pipeline/pipeline.hpp"

#include "cute/tensor.hpp"

#include "cutlass/gemm/collective/collective_builder.hpp"

#include "named_barrier.hpp"
#include "utils.h"

namespace flash {

using namespace cute;

template <typename Ktraits, bool Is_causal, typename Seqlen_traits>
struct CollectiveMainloopFwd {

    using Element = typename Ktraits::Element;
    using TileShape_MNK = typename Ktraits::TileShape_MNK;
    using ClusterShape = typename Ktraits::ClusterShape_MNK;

    static constexpr int kStages = Ktraits::kStages;
    static constexpr int kHeadDim = Ktraits::kHeadDim;

    using GmemTiledCopyQ = cute::SM90_TMA_LOAD;
    using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{})));

    using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));

    using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutK =
        decltype(tile_to_shape(SmemLayoutAtomK{},
                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));

    using SmemLayoutAtomVFp8 = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<2>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutVFp8 =
        decltype(tile_to_shape(SmemLayoutAtomVFp8{},
                 make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{})));

    using SmemLayoutVFp16 = SmemLayoutK;
    // Note this is the transpose in terms of the view, not in terms of memory.
    using SmemLayoutVtFp16 =
        decltype(cute::composition(SmemLayoutVFp16{},
                 make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
                             make_stride(get<1>(TileShape_MNK{}), _1{}, Int<size(SmemLayoutVFp16{}(_, _, _0{}))>{}))));

    using SmemLayoutV = decltype(cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(SmemLayoutVFp8{}, SmemLayoutVFp16{}));
    using SmemLayoutVt = decltype(cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(SmemLayoutVFp8{}, SmemLayoutVtFp16{}));
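    // Illustrative sketch of the composition trick above (not used by the kernel; hypothetical
    // kBlockN = 128, kHeadDim = 64, kStages = 2, and no swizzle, unlike the real layouts):
    //   auto sV  = make_layout(make_shape(Int<128>{}, Int<64>{}, Int<2>{}));          // (N, K, PIPE)
    //   auto sVt = cute::composition(sV,
    //       make_layout(make_shape(Int<64>{}, Int<128>{}, Int<2>{}),
    //                   make_stride(Int<128>{}, _1{}, Int<128 * 64>{})));             // (K, N, PIPE) view
    //   // sVt(k, n, p) addresses the same element as sV(n, k, p): the bytes in smem never move,
    //   // only the coordinate order of the view changes, which is what GEMM-II needs for V^T.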
    // Dummy S layout for getting the shape for GEMM-II.
    using SmemLayoutAtomS = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutS =
        decltype(tile_to_shape(SmemLayoutAtomS{},
                 make_shape(shape<0>(TileShape_MNK{}), shape<1>(TileShape_MNK{}))));

    // using SmemLayoutAtomVt = cute::GMMA::Layout_MN_SW128_Atom<Element>;
    // using SmemLayoutVt =
    //     decltype(tile_to_shape(SmemLayoutAtomVt{},
    //              make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{}),
    //              Step<_2, _1, _3>{}));  // This gives correct results, without Step it's wrong

    // using SmemLayoutAtomVt = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::MN, Element,
    //     decltype(cute::get<2>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    // using SmemLayoutVt =
    //     decltype(tile_to_shape(SmemLayoutAtomVt{},
    //              make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{})));

    // using SmemLayoutAtomVTMA = cute::GMMA::Layout_K_SW128_Atom<Element>;
    // using SmemLayoutVTMA =
    //     decltype(tile_to_shape(SmemLayoutAtomVTMA{},
    //              make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));

    using TMA_Q = decltype(make_tma_copy(
        GmemTiledCopyQ{},
        make_tensor(
            make_gmem_ptr(static_cast<Element const*>(nullptr)),
            repeat_like(typename Seqlen_traits::StrideT{}, int32_t(0)),
            typename Seqlen_traits::StrideT{}
        ),
        SmemLayoutQ{},
        select<0, 2>(TileShape_MNK{}),
        _1{}));  // no mcast for Q

    using TMA_KV = decltype(make_tma_copy(
        GmemTiledCopyKV{},
        make_tensor(
            make_gmem_ptr(static_cast<Element const*>(nullptr)),
            repeat_like(typename Seqlen_traits::StrideT{}, int32_t(0)),
            typename Seqlen_traits::StrideT{}
        ),
        take<0, 2>(SmemLayoutK{}),
        select<1, 2>(TileShape_MNK{}),
        size<0>(ClusterShape{})));  // mcast along M mode for this N load, if any

    using TileShapeVFP8 = decltype(make_shape(cute::get<2>(TileShape_MNK{}), cute::get<1>(TileShape_MNK{})));
    using TileShapeVFP16 = decltype(make_shape(cute::get<1>(TileShape_MNK{}), cute::get<2>(TileShape_MNK{})));
    using TileShapeV = decltype(cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(TileShapeVFP8{}, TileShapeVFP16{}));
    using TMA_VFP8 = decltype(make_tma_copy(
        GmemTiledCopyKV{},
        make_tensor(
            make_gmem_ptr(static_cast<Element const*>(nullptr)),
            repeat_like(typename Seqlen_traits::StrideT{}, int32_t(0)),   // same gmem stride type as TMA_Q / TMA_KV
            typename Seqlen_traits::StrideT{}
        ),
        take<0, 2>(SmemLayoutV{}),
        TileShapeV{},
        size<0>(ClusterShape{})));  // mcast along M mode for this N load, if any

    using TMA_V = decltype(cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(TMA_VFP8{}, TMA_KV{}));
    static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});

    using MainloopPipeline = typename Ktraits::MainloopPipeline;
    using PipelineParams = typename MainloopPipeline::Params;
    using PipelineState = typename MainloopPipeline::PipelineState;

    // Set the bytes transferred in this TMA transaction (may involve multiple issues)
    static constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<Element> / 8);
    static constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<Element> / 8);
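    // Worked example (hypothetical tile sizes, not taken from any particular Ktraits): with
    // kBlockM = 128, kHeadDim = 128 and a 16-bit Element,
    //   TmaTransactionBytesQ = 128 * 128 * 16 / 8 = 32768 bytes,
    // which is the byte count the transaction barrier for Q must expect before the consumer may
    // proceed. TmaTransactionBytesK is the same computation over one (kBlockN, kHeadDim) pipeline
    // stage of K.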
    static constexpr bool UseSchedulerBarrier = kHeadDim <= 128;

    // Host side kernel arguments
    struct Arguments {
        Element const* ptr_Q;
        typename Seqlen_traits::LayoutT layout_Q;
        Element const* ptr_K;
        typename Seqlen_traits::LayoutT layout_K;
        Element const* ptr_V;
        typename Seqlen_traits::LayoutT layout_V;
        float const softmax_scale_log2;
    };

    // Device side kernel params
    struct Params {
        typename Seqlen_traits::LayoutT layout_Q;
        typename Seqlen_traits::LayoutT layout_K;
        typename Seqlen_traits::LayoutT layout_V;
        cutlass::FastDivmod qhead_per_khead_divmod;
        TMA_Q tma_load_Q;
        TMA_KV tma_load_K;
        TMA_V tma_load_V;
        float const softmax_scale_log2;
    };

    static Params
    to_underlying_arguments(Arguments const& args) {
        Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.layout_Q);
        TMA_Q tma_load_Q = make_tma_copy(
            GmemTiledCopyQ{},
            mQ,
            SmemLayoutQ{},
            select<0, 2>(TileShape_MNK{}),
            _1{});  // no mcast for Q
        Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.layout_K);
        TMA_KV tma_load_K = make_tma_copy(
            GmemTiledCopyKV{},
            mK,
            SmemLayoutK{}(_, _, _0{}),
            select<1, 2>(TileShape_MNK{}),
            size<0>(ClusterShape{}));  // mcast along M mode for this N load, if any
        // V has the same gmem shape as K. For fp8, the first two shape modes are swapped for the
        // TMA view of V, while the strides are taken from layout_V unchanged.
        auto gmemLayoutVFp16 = args.layout_V.shape();
        auto gmemLayoutVFp8 = select<1, 0, 2, 3>(gmemLayoutVFp16);
        auto gmemLayoutV = cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(gmemLayoutVFp8, gmemLayoutVFp16);
        Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), gmemLayoutV, args.layout_V.stride());
        TMA_V tma_load_V = make_tma_copy(
            GmemTiledCopyKV{},
            mV,
            SmemLayoutV{}(_, _, _0{}),
            cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(select<2, 1>(TileShape_MNK{}), select<1, 2>(TileShape_MNK{})),
            size<0>(ClusterShape{}));  // mcast along M mode for this N load, if any
        return {args.layout_Q, args.layout_K, args.layout_V,
                cutlass::FastDivmod(cute::ceil_div(get<2>(args.layout_Q.shape()), get<2>(args.layout_K.shape()))),
                tma_load_Q, tma_load_K, tma_load_V,
                args.softmax_scale_log2};
    }
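    // Example of the grouped-query-attention mapping set up above (hypothetical head counts):
    // with 32 query heads and 8 KV heads, qhead_per_khead_divmod divides by ceil(32 / 8) = 4, so
    // query heads 0..3 read KV head 0, query heads 4..7 read KV head 1, and so on
    // (see bidh_kv = qhead_per_khead_divmod.divide(bidh) in load()).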
    /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
    CUTLASS_DEVICE
    static void prefetch_tma_descriptors(Params const& mainloop_params) {
        cute::prefetch_tma_descriptor(mainloop_params.tma_load_Q.get_tma_descriptor());
        cute::prefetch_tma_descriptor(mainloop_params.tma_load_K.get_tma_descriptor());
        cute::prefetch_tma_descriptor(mainloop_params.tma_load_V.get_tma_descriptor());
    }

    CUTLASS_DEVICE
    int get_n_block_max(
            Params const& mainloop_params, int m_block,
            const Seqlen_traits& seqlen_traits_q,
            const Seqlen_traits& seqlen_traits_k
        ) {
        static constexpr int kBlockM = get<0>(TileShape_MNK{});
        static constexpr int kBlockN = get<1>(TileShape_MNK{});
        int const seqlen_q = seqlen_traits_q.actual_seq_len;
        int const seqlen_k = seqlen_traits_k.actual_seq_len;
        int n_block_max = cute::ceil_div(seqlen_k, kBlockN);
        if constexpr (Is_causal) {
            n_block_max = std::min(n_block_max,
                cute::ceil_div((m_block + 1) * kBlockM + seqlen_k - seqlen_q, kBlockN));
        }
        return n_block_max;
    }
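    // Worked example of the causal bound (hypothetical sizes): kBlockM = kBlockN = 128,
    // seqlen_q = seqlen_k = 1024, m_block = 3. The last query row of this block is
    // (3 + 1) * 128 - 1 = 511, so only keys 0..511 can be attended to, and
    // n_block_max = min(ceil_div(1024, 128), ceil_div((3 + 1) * 128 + 1024 - 1024, 128)) = min(8, 4) = 4.
    // The seqlen_k - seqlen_q term aligns the causal diagonal to the end of the two sequences.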
    template <typename Scheduler, typename SharedStorage>
    CUTLASS_DEVICE void
    load(Params const& mainloop_params,
         MainloopPipeline pipeline_k,
         MainloopPipeline pipeline_v,
         PipelineState& smem_pipe_write_k,
         PipelineState& smem_pipe_write_v,
         SharedStorage &shared_storage,
         Scheduler& scheduler,
         typename Scheduler::Params const& scheduler_params,
         typename Scheduler::WorkTileInfo& work_tile_info,
         cute::tuple<int32_t, int32_t, int32_t> block_coord,
         int work_idx,
         const Seqlen_traits& seqlen_traits_q,
         const Seqlen_traits& seqlen_traits_k
         ) {
        Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
        Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
        Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});

        Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape());
        Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape());
        // Same shape handling as in to_underlying_arguments(): for fp8 the TMA tensor for V uses
        // the V shape with its first two modes swapped.
        auto gmemLayoutVFp16 = mainloop_params.layout_V.shape();
        auto gmemLayoutVFp8 = select<1, 0, 2, 3>(gmemLayoutVFp16);
        auto gmemLayoutV = cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(gmemLayoutVFp8, gmemLayoutVFp16);
        Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(gmemLayoutV);
        auto [m_block, bidh, bidb] = block_coord;
        int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);

        // Prepare the TMA loads
        uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
        constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
        uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
        Tensor gQ = local_tile(mQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{}));  // (M, K)
        Tensor gK = local_tile(mK(_, _, bidh_kv, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (N, K, _)
        Tensor gV = local_tile(mV(_, _, bidh_kv, bidb), TileShapeV{}, cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(make_coord(_0{}, _), make_coord(_, _0{})));  // (N, K, _)
        #if 0
        if (threadIdx.x == 0 && blockIdx.x == 0) {
            print("\n");
            print(gV);
            print("\n");
            print(gK);
            print("\n");
            print("\n");
            print(sV);
            print("\n");
            print(sK);
            print("\n");
            print(gmemLayoutVFp8);
            print("\n");
            print(gmemLayoutVFp16);
        }
        #endif
        // Tensor gQ = seqlen_traits_q.get_local_tile_tensor(
        //     mQ, select<0, 2>(TileShape_MNK{}), bidh, bidb)(_, _, m_block);  // (M, K)
        // Tensor gK = seqlen_traits_k.get_local_tile_tensor(
        //     mK, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb);  // (N, K, _)
        // Tensor gV = seqlen_traits_k.get_local_tile_tensor(
        //     mV, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb);  // (N, K, _)

        Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
        Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
        auto [tQgQ, tQsQ] = tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{},
                                          group_modes<0, 2>(sQ_x), group_modes<0, 2>(gQ_x));  // (TMA), (TMA)
        auto [tKgK, tKsK] = tma_partition(mainloop_params.tma_load_K, block_rank_in_cluster, Layout<ClusterShape>{},
                                          group_modes<0, 2>(sK), group_modes<0, 2>(gK));  // (TMA, k), (TMA, PIPE)
        auto [tVgV, tVsV] = tma_partition(mainloop_params.tma_load_V, block_rank_in_cluster, Layout<ClusterShape>{},
                                          group_modes<0, 2>(sV), group_modes<0, 2>(gV));  // (TMA, k), (TMA, PIPE)

        uint16_t mcast_mask_kv = 0;
        if constexpr (cute::is_same_v<GmemTiledCopyKV, SM90_TMA_LOAD_MULTICAST>) {
            auto block_layout = Layout<ClusterShape>{};  // (m,n) -> block_id
            for (int m = 0; m < size<0>(block_layout); ++m) {
                mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{}));
            }
        }
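        // Example (hypothetical 2x1x1 cluster): the two CTAs sharing this N tile of K/V each set the
        // bits of both cluster ranks, block_layout(0, y, 0) = 0 and block_layout(1, y, 0) = 1, so
        // mcast_mask_kv = 0b11 and a single TMA multicast load fills the K/V smem of both CTAs.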
        int n_block_max = get_n_block_max(mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
        int n_block = n_block_max - 1;

        int lane_predicate = cute::elect_one_sync();
        if (lane_predicate) {
            pipeline_k.producer_acquire(smem_pipe_write_k);
            copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv),
                 tKgK(_, n_block), tKsK(_, smem_pipe_write_k.index()));
            ++smem_pipe_write_k;
        }

        // Wait for the MMA warpgroups to say that smem_q is ready
        cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);

        if (lane_predicate) {
            shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
            copy(mainloop_params.tma_load_Q.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ);
        }

        // Wait for warp 1 to signal that smem_v is ready and V can be copied from gmem.
        // This needs a ClusterBarrier, not just a NamedBarrier: otherwise CTA 0 might finish its
        // TMA store on O and issue the TMA multicast load on V before CTA 1 has finished its TMA store on O.
        shared_storage.barrier_O.wait((work_idx + 1) % 2);
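        // Note the alternating parity: work_idx 0 waits on phase (0 + 1) % 2 = 1, work_idx 1 on
        // phase 0, and so on, matching the barrier_O.arrive() calls issued in mma() once the previous
        // tile's O store has drained. This is what orders "O store for work item i" before
        // "multicast V load for work item i + 1" across all CTAs of the cluster.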
        if (lane_predicate) {
            // CUTLASS_PRAGMA_NO_UNROLL
            #pragma unroll 2
            for (; n_block > 0; --n_block) {
                pipeline_k.producer_acquire(smem_pipe_write_k);
                copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv),
                     tKgK(_, n_block - 1), tKsK(_, smem_pipe_write_k.index()));
                ++smem_pipe_write_k;
                pipeline_v.producer_acquire(smem_pipe_write_v);
                copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv),
                     tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
                ++smem_pipe_write_v;
            }
        }
        scheduler.prefetch_next_work(scheduler_params, work_tile_info);
        if (lane_predicate) {
            pipeline_v.producer_acquire(smem_pipe_write_v);
            copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv),
                 tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
            ++smem_pipe_write_v;
        }
        scheduler.broadcast_next_work(work_tile_info);
    }

    /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
    CUTLASS_DEVICE void
    load_tail(MainloopPipeline pipeline_k, MainloopPipeline pipeline_v,
              PipelineState& smem_pipe_write_k, PipelineState& smem_pipe_write_v) {
        int lane_predicate = cute::elect_one_sync();
        // Issue the epilogue waits
        if (lane_predicate) {
            /* This helps avoid early exit of blocks in the Cluster.
             * Waits for all stages to either be released (all Consumer UNLOCKs), or, if a stage was
             * never used, it is simply acquired since its phase was still inverted from
             * make_producer_start_state.
             */
            pipeline_k.producer_tail(smem_pipe_write_k);
            pipeline_v.producer_tail(smem_pipe_write_v);
        }
    }

    CUTLASS_DEVICE void
    warp_scheduler_barrier_sync() {
        if constexpr (UseSchedulerBarrier) {
            cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + cutlass::canonical_warp_group_idx() /*id*/);
        }
    }

    CUTLASS_DEVICE void
    warp_scheduler_barrier_arrive() {
        if constexpr (!UseSchedulerBarrier) { return; }
        static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
        if constexpr (NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup) {
            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (3 - cutlass::canonical_warp_group_idx()) /*id*/);
        } else {
            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 2 ? cutlass::canonical_warp_group_idx() + 1 : cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/);
            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 1 ? cutlass::canonical_warp_group_idx() + 2 : cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/);
        }
    }
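    // Sketch of the round-robin order enforced above for the two-MMA-warpgroup case (consumer
    // warpgroups have canonical indices 1 and 2; index 0 is the producer). Warpgroup 1 syncs on the
    // barrier at offset 1 (WarpSchedulerWG1) and, once its WGMMA has been issued, arrives on the
    // barrier at offset 3 - 1 = 2 (WarpSchedulerWG2, assuming the WarpSchedulerWG* ids are
    // consecutive in named_barrier.hpp); warpgroup 2 does the mirror image. The two warpgroups thus
    // take turns issuing WGMMAs, so tensor-core work of one can overlap softmax work of the other.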
    CUTLASS_DEVICE void
    mma_init() {
        // Tell producer (warp 0) that smem_q is ready
        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
        if constexpr (!UseSchedulerBarrier) { return; }
        static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
        if (cutlass::canonical_warp_group_idx() > 1) {
            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 1 /*id*/);
        }
        if constexpr (NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup) {
            if (cutlass::canonical_warp_group_idx() > 2) {
                cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 2 /*id*/);
            }
        }
    }

    template <typename SharedStorage, typename FrgTensorO, typename Softmax>
    CUTLASS_DEVICE void
    mma(Params const& mainloop_params,
        MainloopPipeline pipeline_k,
        MainloopPipeline pipeline_v,
        PipelineState& smem_pipe_read_k,
        PipelineState& smem_pipe_read_v,
        FrgTensorO& tOrO,
        Softmax& softmax,
        int n_block_count,
        int thread_idx,
        int work_idx,
        int m_block,
        SharedStorage& shared_storage,
        const Seqlen_traits& seqlen_traits_q,
        const Seqlen_traits& seqlen_traits_k
        ) {
        static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
        static constexpr int kBlockM = get<0>(TileShape_MNK{});
        static constexpr int kBlockN = get<1>(TileShape_MNK{});

        Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
        Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
        Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVt{});

        typename Ktraits::TiledMma0 tiled_mma0;
        typename Ktraits::TiledMma1 tiled_mma1;
        auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx);
        auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx);

        // Allocate "fragments/descriptors" for the first matmul.
        Tensor tSrQ = threadMma0.partition_fragment_A(sQ);
        Tensor tSrK = threadMma0.partition_fragment_B(sK);
        // Allocate "fragments/descriptors" for the second matmul.
        // Note: S becomes P.
        Tensor tOrV = threadMma1.partition_fragment_B(sVt);
        // Dummy sS just to get the right shape for GEMM-II.
        Tensor sS = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutS{});
        Tensor tOrS = threadMma1.partition_fragment_A(sS);

        Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
        ReorgCFp8toAFp8 reg2reg;
        auto tOrPLayout = ReshapeTStoTP()(tSrS, tOrS);
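        // Rough sketch of what the two helpers above are used for here (see utils.h for the actual
        // definitions): ReshapeTStoTP computes tOrPLayout, a layout that lets the GEMM-I accumulator
        // fragment tSrS be re-viewed as the register-resident A operand ("P") of GEMM-II, and for
        // fp8 reg2reg (ReorgCFp8toAFp8) additionally reorders values across registers, since the fp8
        // WGMMA A-fragment ordering differs from the fp32 accumulator ordering.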
        auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
            auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
            pipeline.consumer_wait(smem_pipe_read, barrier_token);
        };

        tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero;

        int const seqlen_q = seqlen_traits_q.actual_seq_len;
        int const seqlen_k = seqlen_traits_k.actual_seq_len;
        int n_block = n_block_count - 1;

        cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.barrier_Q.try_wait(work_idx % 2));
        if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(work_idx % 2); }

        consumer_wait(pipeline_k, smem_pipe_read_k);
        warp_scheduler_barrier_sync();
        flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
        warp_scheduler_barrier_arrive();

        if (work_idx != 0) {
            int lane_predicate = cute::elect_one_sync();
            if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) {
                tma_store_wait<0>();
                #pragma unroll
                for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
                    shared_storage.barrier_O.arrive(cta_id, lane_predicate);
                }
            }
        }

        warpgroup_wait<0>();
        pipeline_k.consumer_release(smem_pipe_read_k);
        ++smem_pipe_read_k;

        auto col_limit_causal = [&](int row, int n_block) {
            return row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM;
        };
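        // Worked example (hypothetical sizes kBlockM = kBlockN = 128, seqlen_q = seqlen_k = 1024,
        // m_block = 3, n_block = 3): the query at local row 0 is global row 3 * 128 + 0 = 384, so
        // within K block 3 (global columns 384..511) it may attend to local columns
        // < 0 + 1 + 1024 - 3 * 128 - 1024 + 3 * 128 = 1, i.e. only column 384, which is exactly the
        // causal diagonal; entries at or beyond the limit are set to -INFINITY below.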
        {
            Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
            Tensor tScS = threadMma0.partition_C(cS);
            #pragma unroll
            for (int i = 0; i < size(tSrS); ++i) {
                if constexpr (!Is_causal) {  // Just masking based on col
                    if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; }
                } else {  // mask based on both row and col
                    // using std::min is faster than doing col >= limit0 or col >= limit1
                    // Need to cast get<1>(tScS(i)) to (signed) int since by default it's unsigned, and the
                    // right hand side can be negative and might be converted to a very large unsigned integer.
                    if (int(get<1>(tScS(i))) >= std::min(seqlen_k - n_block * kBlockN,
                                                         col_limit_causal(int(get<0>(tScS(i))), n_block))) {
                        tSrS(i) = -INFINITY;
                    }
                }
            }
        }

        softmax.template online_softmax</*Is_first=*/true>(tSrS, mainloop_params.softmax_scale_log2);
        auto tSrSPrec = convert_type<Element>(tSrS);
        if constexpr (is_same_v<Element, cutlass::float_e4m3_t>) {
            reg2reg(tSrSPrec);
        }
        Tensor tOrP = make_tensor(tSrSPrec.data(), tOrPLayout);
        Tensor scores_scale = make_fragment_like(softmax.row_max);
        clear(scores_scale);

        constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1;
        // Only go through these if Is_causal, since n_masking_steps = 1 when !Is_causal
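        // Example (hypothetical tile sizes): with kBlockM = kBlockN = 128, n_masking_steps =
        // ceil_div(128, 128) + 1 = 2: the right-most K block (masked above, before this loop) plus
        // at most one more K block can intersect the causal diagonal of this M block; every earlier
        // K block is entirely unmasked and is handled by the simpler loop further down.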
        #pragma unroll
        for (int masking_step = 0; masking_step < n_masking_steps - 1 && n_block > 0; ++masking_step, --n_block) {
            Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
            consumer_wait(pipeline_k, smem_pipe_read_k);
            warp_scheduler_barrier_sync();
            flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
            if (masking_step > 0) { softmax.rescale_o(tOrO, scores_scale); }
            consumer_wait(pipeline_v, smem_pipe_read_v);
            flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
            warp_scheduler_barrier_arrive();
            warpgroup_wait<1>();
            pipeline_k.consumer_release(smem_pipe_read_k);  // release K
            Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
            Tensor tScS = threadMma0.partition_C(cS);
            #pragma unroll
            for (int i = 0; i < size(tSrS); ++i) {
                if (int(get<1>(tScS(i))) >= col_limit_causal(int(get<0>(tScS(i))), n_block - 1)) {
                    tSrS(i) = -INFINITY;
                }
            }
            cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/true>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
            softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/true>(tSrS, mainloop_params.softmax_scale_log2);
            warpgroup_wait<0>();
            pipeline_v.consumer_release(smem_pipe_read_v);  // release V
            ++smem_pipe_read_k;
            ++smem_pipe_read_v;
            auto tSrSPrec = convert_type<Element>(tSrS);
            if constexpr (is_same_v<Element, cutlass::float_e4m3_t>) {
                reg2reg(tSrSPrec);
            }
            cute::copy(make_tensor(tSrSPrec.data(), tOrPLayout), tOrP);
        }

        #pragma unroll 1
        for (; n_block > 0; --n_block) {
            Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
            consumer_wait(pipeline_k, smem_pipe_read_k);
            warp_scheduler_barrier_sync();
            flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
            softmax.rescale_o(tOrO, scores_scale);
            consumer_wait(pipeline_v, smem_pipe_read_v);
            flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
            warp_scheduler_barrier_arrive();
            warpgroup_wait<1>();
            pipeline_k.consumer_release(smem_pipe_read_k);  // release K
            // auto scores_scale = softmax.template max</*Is_first=*/false>(tSrS);
            cute::copy(softmax.template max</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
            softmax.template online_softmax</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2);
            warpgroup_wait<0>();
            pipeline_v.consumer_release(smem_pipe_read_v);  // release V
            ++smem_pipe_read_k;
            ++smem_pipe_read_v;
            // softmax.rescale_o(tOrO, scores_scale);
            auto tSrSPrec = convert_type<Element>(tSrS);
            if constexpr (is_same_v<Element, cutlass::float_e4m3_t>) {
                reg2reg(tSrSPrec);
            }
            cute::copy(make_tensor(tSrSPrec.data(), tOrPLayout), tOrP);
        }

        // Tell warp 0 that smem_q is ready
        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
        softmax.rescale_o(tOrO, scores_scale);
        consumer_wait(pipeline_v, smem_pipe_read_v);
        flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
        cute::copy(softmax.template finalize</*Check_inf=*/Is_causal>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
        warpgroup_wait<0>();
        pipeline_v.consumer_release(smem_pipe_read_v);  // release V, otherwise producers will hang
        ++smem_pipe_read_v;
        softmax.rescale_o(tOrO, scores_scale);
        return;
    }

};

}  // namespace flash