mainloop_fwd_sm90_tma_gmma_ws.hpp

/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include "cutlass/pipeline/pipeline.hpp"

#include "cute/tensor.hpp"

#include "cutlass/gemm/collective/collective_builder.hpp"

#include "named_barrier.hpp"
#include "utils.h"

namespace flash {

using namespace cute;

// 4 warps
struct SmemTransposeFp8_64x64 {

    using Element = cutlass::float_e4m3_t;

    using ldsm_thread_shape = Shape<_4, _1, _8, _4>;
    using ldsm_value_shape = Shape<_2, _8, _2, _1>;
    using ldsm_value_stride = Stride<_2, _4, _1, _0>;
    using TiledCopyLDSM = decltype(make_tiled_copy(
        Copy_Atom<SM75_U16x8_LDSM_T, Element>{}, Layout<ldsm_thread_shape>{},
        Layout<ldsm_value_shape, ldsm_value_stride>{}));
    TiledCopyLDSM tiled_copy_ldsm;

    using stsm_thread_shape = Shape<_4, _1, _8, _4>;
    // using stsm_thread_stride = Stride<_1, _0, _4, _32>;
#ifndef NO_FP8_COLUMN_PERMUTE
    using stsm_value_shape = Shape<_4, _4, _1, _2>;
    using stsm_value_stride = Stride<_1, _8, _0, _4>;
#else
    using stsm_value_shape = Shape<_4, _4, _2, _1>;
    using stsm_value_stride = Stride<_1, _8, _4, _0>;
#endif

    using TiledCopySTSM =
        decltype(make_tiled_copy(Copy_Atom<SM90_U32x4_STSM_N, Element>{},
                                 Layout<stsm_thread_shape>{},
                                 Layout<stsm_value_shape, stsm_value_stride>{}));
    TiledCopySTSM tiled_copy_stsm;

    template <class SmemTensor, class SmemTensorOut>
    CUTLASS_DEVICE void operator()(SmemTensor &&s_in, SmemTensorOut &&s_out) {
        using namespace cute;

        auto tid = threadIdx.x;
        auto thr_copy_ldsm = tiled_copy_ldsm.get_thread_slice(tid);
        auto thr_copy_stsm = tiled_copy_stsm.get_thread_slice(tid);

        auto tXsX = thr_copy_ldsm.partition_S(s_in);
        auto tXrX = make_tensor<Element>(shape(tXsX));
        auto tXsX_out = thr_copy_stsm.partition_D(s_out);

        cute::copy(tiled_copy_ldsm, tXsX, tXrX);

        auto data = tXrX.data();
        // size(tXrX) == 32
        CUTLASS_PRAGMA_UNROLL
        for (int n = 0; n < size(tXrX); n += 8) {
            uint32_t *data_32bit = reinterpret_cast<uint32_t *>(&data[n]);
            auto upper = data_32bit[0];
            auto lower = data_32bit[1];
            data_32bit[0] = __byte_perm(upper, lower, 0x6420);
            data_32bit[1] = __byte_perm(upper, lower, 0x7531);
        }

        cute::copy(tiled_copy_stsm, tXrX, tXsX_out);
    }
};
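
// A minimal worked example of the byte shuffle in SmemTransposeFp8_64x64::operator() above
// (hypothetical register contents, not taken from the kernel):
//   upper = 0x33221100, lower = 0x77665544   (byte k of each word written as 0xkk)
//   __byte_perm(upper, lower, 0x6420) == 0x66442200   (even-indexed bytes of the pair)
//   __byte_perm(upper, lower, 0x7531) == 0x77553311   (odd-indexed bytes of the pair)
// i.e. each pair of 32-bit words holding eight fp8 values is de-interleaved into even/odd
// bytes, matching the fp8 column permutation selected by the STSM value layout
// (see NO_FP8_COLUMN_PERMUTE above).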
template <typename Ktraits, bool Is_causal, typename Seqlen_traits>
struct CollectiveMainloopFwd {

    using Element = typename Ktraits::Element;
    using TileShape_MNK = typename Ktraits::TileShape_MNK;
    using ClusterShape = typename Ktraits::ClusterShape_MNK;

    static constexpr int kStages = Ktraits::kStages;
    static constexpr int kHeadDim = Ktraits::kHeadDim;

    using GmemTiledCopyQ = cute::SM90_TMA_LOAD;
    using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{})));

    using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
    using SmemLayoutK = typename Ktraits::SmemLayoutK;
    using SmemLayoutV = typename Ktraits::SmemLayoutV;
    using SmemLayoutVt = typename Ktraits::SmemLayoutVt;

    using TMA_Q = decltype(make_tma_copy(
        GmemTiledCopyQ{},
        make_tensor(
            make_gmem_ptr(static_cast<Element const*>(nullptr)),
            repeat_like(typename Seqlen_traits::StrideT{}, int32_t(0)),
            typename Seqlen_traits::StrideT{}
        ),
        SmemLayoutQ{},
        select<0, 2>(TileShape_MNK{}),
        _1{}));  // no mcast for Q

    using TMA_K = decltype(make_tma_copy(
        GmemTiledCopyKV{},
        make_tensor(
            make_gmem_ptr(static_cast<Element const*>(nullptr)),
            repeat_like(typename Seqlen_traits::StrideT{}, int32_t(0)),
            typename Seqlen_traits::StrideT{}
        ),
        take<0, 2>(SmemLayoutK{}),
        select<1, 2>(TileShape_MNK{}),
        size<0>(ClusterShape{})));  // mcast along M mode for this N load, if any

    // TMA_V may differ from TMA_K for fp8 kernel (e.g. swizzling mode)
    using TMA_V = decltype(make_tma_copy(
        GmemTiledCopyKV{},
        make_tensor(
            make_gmem_ptr(static_cast<Element const*>(nullptr)),
            repeat_like(typename Seqlen_traits::StrideT{}, int32_t(0)),
            typename Seqlen_traits::StrideT{}
        ),
        take<0, 2>(SmemLayoutV{}),
        select<1, 2>(TileShape_MNK{}),
        size<0>(ClusterShape{})));  // mcast along M mode for this N load, if any

    static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
    using MainloopPipeline = typename Ktraits::MainloopPipeline;
    using MainloopPipelineNoTMA = typename Ktraits::MainloopPipelineNoTMA;
    using PipelineParams = typename MainloopPipeline::Params;
    using PipelineState = typename MainloopPipeline::PipelineState;

    // Set the bytes transferred in this TMA transaction (may involve multiple issues)
    static constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<Element> / 8);
    static constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<Element> / 8);
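
    // Worked example of the transaction-byte arithmetic (hypothetical tile sizes, not read from
    // Ktraits): a 128 x 128 Q tile of 16-bit elements gives TmaTransactionBytesQ
    // = 128 * 128 * 16 / 8 = 32768 bytes, while one 128 x 128 fp8 K stage gives
    // 128 * 128 * 8 / 8 = 16384 bytes.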
    // static constexpr bool UseSchedulerBarrier = kHeadDim <= 128;
    static constexpr bool UseSchedulerBarrier =
        cutlass::sizeof_bits_v<Element> == 8 ? kHeadDim >= 128
                                             : kHeadDim <= 128;
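    // Concretely: the warp-scheduler barrier is enabled for 8-bit element types when kHeadDim >= 128,
    // and for 16-bit element types when kHeadDim <= 128.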

    // Host side kernel arguments
    struct Arguments {
        Element const* ptr_Q;
        typename Seqlen_traits::LayoutT layout_Q;
        Element const* ptr_K;
        typename Seqlen_traits::LayoutT layout_K;
        Element const* ptr_V;
        typename Seqlen_traits::LayoutT layout_V;
        float const softmax_scale_log2;
        float const* descale_q_ptr;
        float const* descale_k_ptr;
        float const* descale_v_ptr;
    };

    // Device side kernel params
    struct Params {
        typename Seqlen_traits::LayoutT layout_Q;
        typename Seqlen_traits::LayoutT layout_K;
        typename Seqlen_traits::LayoutT layout_V;
        cutlass::FastDivmod qhead_per_khead_divmod;
        TMA_Q tma_load_Q;
        TMA_K tma_load_K;
        TMA_V tma_load_V;
        float const softmax_scale_log2;
        float const* descale_q_ptr;
        float const* descale_k_ptr;
        float const* descale_v_ptr;
    };

    static Params
    to_underlying_arguments(Arguments const& args) {
        Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.layout_Q);
        TMA_Q tma_load_Q = make_tma_copy(
            GmemTiledCopyQ{},
            mQ,
            SmemLayoutQ{},
            select<0, 2>(TileShape_MNK{}),
            _1{});  // no mcast for Q
        Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.layout_K);
        TMA_K tma_load_K = make_tma_copy(
            GmemTiledCopyKV{},
            mK,
            SmemLayoutK{}(_, _, _0{}),
            select<1, 2>(TileShape_MNK{}),
            size<0>(ClusterShape{}));  // mcast along M mode for this N load, if any
        Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.layout_V);
        TMA_V tma_load_V = make_tma_copy(
            GmemTiledCopyKV{},
            mV,
            SmemLayoutV{}(_, _, _0{}),
            select<1, 2>(TileShape_MNK{}),
            size<0>(ClusterShape{}));  // mcast along M mode for this N load, if any
        return {args.layout_Q, args.layout_K, args.layout_V,
                cutlass::FastDivmod(cute::ceil_div(get<2>(args.layout_Q.shape()), get<2>(args.layout_K.shape()))),
                tma_load_Q, tma_load_K, tma_load_V,
                args.softmax_scale_log2,
                args.descale_q_ptr, args.descale_k_ptr, args.descale_v_ptr};
    }
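
    // A minimal host-side usage sketch (hypothetical variable names; the actual launch code lives
    // in the kernel driver, not in this header):
    //
    //   Arguments args{q_ptr, layout_Q, k_ptr, layout_K, v_ptr, layout_V,
    //                  softmax_scale_log2, descale_q, descale_k, descale_v};
    //   Params params = CollectiveMainloopFwd<Ktraits, Is_causal, Seqlen_traits>::to_underlying_arguments(args);
    //   // params carries the prebuilt TMA descriptors and the qheads-per-khead FastDivmod to the kernel.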

    /// Issue TMA descriptor prefetch -- ideally from a single thread for best performance
    CUTLASS_DEVICE
    static void prefetch_tma_descriptors(Params const& mainloop_params) {
        cute::prefetch_tma_descriptor(mainloop_params.tma_load_Q.get_tma_descriptor());
        cute::prefetch_tma_descriptor(mainloop_params.tma_load_K.get_tma_descriptor());
        cute::prefetch_tma_descriptor(mainloop_params.tma_load_V.get_tma_descriptor());
    }

    CUTLASS_DEVICE
    int get_n_block_max(
          Params const& mainloop_params, int m_block,
          const Seqlen_traits& seqlen_traits_q,
          const Seqlen_traits& seqlen_traits_k
        ) {
        static constexpr int kBlockM = get<0>(TileShape_MNK{});
        static constexpr int kBlockN = get<1>(TileShape_MNK{});
        int const seqlen_q = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_q.actual_seq_len : shape<0>(mainloop_params.layout_Q);
        int const seqlen_k = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_k.actual_seq_len : shape<0>(mainloop_params.layout_K);
        int n_block_max = cute::ceil_div(seqlen_k, kBlockN);
        if constexpr (Is_causal) {
            n_block_max = std::min(n_block_max,
                cute::ceil_div((m_block + 1) * kBlockM + seqlen_k - seqlen_q, kBlockN));
        }
        return n_block_max;
    }
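
    // Worked example of the causal bound above (hypothetical sizes, not read from the traits):
    // with seqlen_q == seqlen_k == 512, kBlockM == kBlockN == 64 and m_block == 2, the causal
    // limit is ceil_div((2 + 1) * 64 + 512 - 512, 64) == 3, so only K/V blocks 0..2 are visited
    // instead of all ceil_div(512, 64) == 8.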

    template <typename Scheduler, typename SharedStorage>
    CUTLASS_DEVICE void
    load(Params const& mainloop_params,
         MainloopPipeline pipeline_k,
         MainloopPipeline pipeline_v,
         PipelineState& smem_pipe_write_k,
         PipelineState& smem_pipe_write_v,
         SharedStorage &shared_storage,
         Scheduler& scheduler,
         typename Scheduler::Params const& scheduler_params,
         typename Scheduler::WorkTileInfo& work_tile_info,
         cute::tuple<int32_t, int32_t, int32_t> block_coord,
         int work_idx,
         const Seqlen_traits& seqlen_traits_q,
         const Seqlen_traits& seqlen_traits_k
         ) {
        Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
        Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
        Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});

        Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape());
        Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape());
        Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape());

        auto [m_block, bidh, bidb] = block_coord;
        int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);

        // Prepare the TMA loads
        uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
        constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
        uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
        Tensor gQ = seqlen_traits_q.get_local_tile_tensor(
            mQ, select<0, 2>(TileShape_MNK{}), bidh, bidb)(_, _, m_block);  // (M, K)
        Tensor gK = seqlen_traits_k.get_local_tile_tensor(
            mK, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb);  // (N, K, _)
        Tensor gV = seqlen_traits_k.get_local_tile_tensor(
            mV, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb);  // (N, K, _)

        Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
        Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
        auto [tQgQ, tQsQ] = tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{},
                                          group_modes<0, 2>(sQ_x), group_modes<0, 2>(gQ_x));  // (TMA), (TMA)
        auto [tKgK, tKsK] = tma_partition(mainloop_params.tma_load_K, block_rank_in_cluster, Layout<ClusterShape>{},
                                          group_modes<0, 2>(sK), group_modes<0, 2>(gK));  // (TMA, k), (TMA, PIPE)
        auto [tVgV, tVsV] = tma_partition(mainloop_params.tma_load_V, block_rank_in_cluster, Layout<ClusterShape>{},
                                          group_modes<0, 2>(sV), group_modes<0, 2>(gV));  // (TMA, k), (TMA, PIPE)

        uint16_t mcast_mask_kv = 0;
        if constexpr (cute::is_same_v<GmemTiledCopyKV, SM90_TMA_LOAD_MULTICAST>) {
            auto block_layout = Layout<ClusterShape>{};  // (m,n) -> block_id
            for (int m = 0; m < size<0>(block_layout); ++m) {
                mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{}));
            }
        }
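
        // Example of the multicast mask (assuming a hypothetical 2x1x1 cluster): block_layout
        // enumerates the two CTAs along the M mode, so both CTA ranks are set and
        // mcast_mask_kv == 0b11, i.e. each K/V TMA load is broadcast to both CTAs that share
        // this column of K/V tiles.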

        int n_block_max = get_n_block_max(mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
        int n_block = n_block_max - 1;

        int lane_predicate = cute::elect_one_sync();
        if (lane_predicate) {
            pipeline_k.producer_acquire(smem_pipe_write_k);
            copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv),
                 tKgK(_, n_block), tKsK(_, smem_pipe_write_k.index()));
            ++smem_pipe_write_k;
        }

        // Wait for the MMA warpgroups to signal that smem_q is ready
        cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);

        if (lane_predicate) {
            shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
            copy(mainloop_params.tma_load_Q.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ);
        }

        // Wait for warp 1 to signal that smem_v is ready and V can be copied from gmem.
        // This needs a ClusterBarrier, not just a NamedBarrier: otherwise CTA 0 could finish its
        // TMA store on O and issue the TMA multicast load on V before CTA 1 has finished its TMA store on O.
        shared_storage.barrier_O.wait((work_idx + 1) % 2);

        if (lane_predicate) {
            // CUTLASS_PRAGMA_NO_UNROLL
            #pragma unroll 2
            for (; n_block > 0; --n_block) {
                pipeline_k.producer_acquire(smem_pipe_write_k);
                copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv),
                     tKgK(_, n_block - 1), tKsK(_, smem_pipe_write_k.index()));
                ++smem_pipe_write_k;
                pipeline_v.producer_acquire(smem_pipe_write_v);
                copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv),
                     tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
                ++smem_pipe_write_v;
            }
        }
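
        // Note on the load order above: the iteration for n_block issues the K tile for
        // n_block - 1 together with the V tile for n_block, so K stays one block ahead of V;
        // the consumer can start the Q*K^T GEMM on a block while that block's V tile is still in flight.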

        scheduler.prefetch_next_work(scheduler_params, work_tile_info);

        if (lane_predicate) {
            pipeline_v.producer_acquire(smem_pipe_write_v);
            copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv),
                 tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
            ++smem_pipe_write_v;
        }

        scheduler.broadcast_next_work(work_tile_info);
    }

    template <typename Scheduler, typename SharedStorage>
    CUTLASS_DEVICE void
    load_fp8(Params const& mainloop_params,
         MainloopPipeline pipeline_k,
         MainloopPipeline pipeline_v,
         MainloopPipelineNoTMA pipeline_vt,
         PipelineState& smem_pipe_write,
         PipelineState& smem_pipe_read,
         SharedStorage &shared_storage,
         Scheduler& scheduler,
         typename Scheduler::Params const& scheduler_params,
         typename Scheduler::WorkTileInfo& work_tile_info,
         cute::tuple<int32_t, int32_t, int32_t> block_coord,
         int work_idx,
         const Seqlen_traits& seqlen_traits_q,
         const Seqlen_traits& seqlen_traits_k
         ) {
        using SmemLayoutTransposeV = typename Ktraits::SmemLayoutTransposeV;
        using SmemLayoutTransposeVt = typename Ktraits::SmemLayoutTransposeVt;

        Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
        Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
        Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});

        Tensor sV_divide = as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutTransposeV{}));
        Tensor sVt_divide = as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.smem_v_out.data()), SmemLayoutTransposeVt{}));

        auto smem_transpose_V = SmemTransposeFp8_64x64();
        auto do_transpose_V = [&](int stage) {
            CUTLASS_PRAGMA_UNROLL
            for (int j = 0; j < shape<2>(SmemLayoutTransposeV{}); ++j) {
                CUTLASS_PRAGMA_UNROLL
                for (int i = 0; i < shape<1>(SmemLayoutTransposeV{}); ++i) {
                    smem_transpose_V(flatten(sV_divide(_, i, j, stage)),
                                     flatten(sVt_divide(_, i, j, stage)));
                }
            }
            cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, static_cast<int>(FwdNamedBarriers::ProducerWG) /*id*/);
        };
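
        // do_transpose_V covers one (kBlockN, kHeadDim) V stage with a grid of 64x64 sub-tiles and
        // transposes each sub-tile in shared memory via SmemTransposeFp8_64x64, writing into
        // smem_v_out (sVt_divide); a hypothetical 128x128 stage would be a 2x2 grid of such
        // sub-tiles. The producer warpgroup then syncs on ProducerWG before the caller publishes
        // the transposed stage to the consumers through pipeline_vt.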

        Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape());
        Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape());
        Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape());

        auto [m_block, bidh, bidb] = block_coord;
        int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);

        // Prepare the TMA loads
        uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
        constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
        uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
        Tensor gQ = seqlen_traits_q.get_local_tile_tensor(
            mQ, select<0, 2>(TileShape_MNK{}), bidh, bidb)(_, _, m_block);  // (M, K)
        Tensor gK = seqlen_traits_k.get_local_tile_tensor(
            mK, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb);  // (N, K, _)
        Tensor gV = seqlen_traits_k.get_local_tile_tensor(
            mV, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb);  // (N, K, _)

        Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
        Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
        auto [tQgQ, tQsQ] = tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{},
                                          group_modes<0, 2>(sQ_x), group_modes<0, 2>(gQ_x));  // (TMA), (TMA)
        auto [tKgK, tKsK] = tma_partition(mainloop_params.tma_load_K, block_rank_in_cluster, Layout<ClusterShape>{},
                                          group_modes<0, 2>(sK), group_modes<0, 2>(gK));  // (TMA, k), (TMA, PIPE)
        auto [tVgV, tVsV] = tma_partition(mainloop_params.tma_load_V, block_rank_in_cluster, Layout<ClusterShape>{},
                                          group_modes<0, 2>(sV), group_modes<0, 2>(gV));  // (TMA, k), (TMA, PIPE)

        uint16_t mcast_mask_kv = 0;
        if constexpr (cute::is_same_v<GmemTiledCopyKV, SM90_TMA_LOAD_MULTICAST>) {
            auto block_layout = Layout<ClusterShape>{};  // (m,n) -> block_id
            for (int m = 0; m < size<0>(block_layout); ++m) {
                mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{}));
            }
        }

        int n_block_max = get_n_block_max(mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
        int n_block = n_block_max - 1;

        int lane_predicate = cute::elect_one_sync();
        int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
        if (warp_idx_in_warpgroup == 0 && lane_predicate) {
            pipeline_k.producer_acquire(smem_pipe_write);
            copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
                 tKgK(_, n_block), tKsK(_, smem_pipe_write.index()));
        }

        // Wait for the MMA warpgroups to signal that smem_q is ready.
        // For fp8, use NumThreadsPerWarpGroup here instead of NumThreadsPerWarp.
        cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarpGroup, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);

        if constexpr (Is_causal) {
            if (warp_idx_in_warpgroup == 0 && lane_predicate) {
                shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
                copy(mainloop_params.tma_load_Q.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ);
                pipeline_v.producer_acquire(smem_pipe_write);
                copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
                     tVgV(_, n_block), tVsV(_, smem_pipe_write.index()));
            }

            shared_storage.barrier_O.wait((work_idx + 1) % 2);

            CUTLASS_PRAGMA_UNROLL
            for (int iter = 0; iter < kStages && n_block > 0; ++iter, --n_block) {
                pipeline_v.consumer_wait(smem_pipe_read);
                // pipeline_vt.producer_acquire(smem_pipe_write);
                do_transpose_V(smem_pipe_read.index());
                pipeline_vt.producer_commit(smem_pipe_write);
                pipeline_v.consumer_release(smem_pipe_read);
                ++smem_pipe_write;
                ++smem_pipe_read;

                if (warp_idx_in_warpgroup == 0 && lane_predicate) {
                    pipeline_k.producer_acquire(smem_pipe_write);
                    copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
                         tKgK(_, n_block - 1), tKsK(_, smem_pipe_write.index()));
                    pipeline_v.producer_acquire(smem_pipe_write);
                    copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
                         tVgV(_, n_block - 1), tVsV(_, smem_pipe_write.index()));
                }
            }

            #pragma unroll 2
            for (; n_block > 0; --n_block) {
                pipeline_v.consumer_wait(smem_pipe_read);
                pipeline_vt.producer_acquire(smem_pipe_write);
                do_transpose_V(smem_pipe_read.index());
                pipeline_vt.producer_commit(smem_pipe_write);
                pipeline_v.consumer_release(smem_pipe_read);
                ++smem_pipe_write;
                ++smem_pipe_read;

                if (warp_idx_in_warpgroup == 0 && lane_predicate) {
                    pipeline_k.producer_acquire(smem_pipe_write);
                    copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
                         tKgK(_, n_block - 1), tKsK(_, smem_pipe_write.index()));
                    pipeline_v.producer_acquire(smem_pipe_write);
                    copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
                         tVgV(_, n_block - 1), tVsV(_, smem_pipe_write.index()));
                }
            }

            scheduler.prefetch_next_work(scheduler_params, work_tile_info);
            scheduler.broadcast_next_work(work_tile_info);

            pipeline_v.consumer_wait(smem_pipe_read);
            if (n_block_max > kStages) { pipeline_vt.producer_acquire(smem_pipe_write); }
            do_transpose_V(smem_pipe_read.index());
            pipeline_vt.producer_commit(smem_pipe_write);
            pipeline_v.consumer_release(smem_pipe_read);
            ++smem_pipe_write;
            ++smem_pipe_read;
        } else {
            if (warp_idx_in_warpgroup == 0 && lane_predicate) {
                shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
                copy(mainloop_params.tma_load_Q.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ);
                pipeline_v.producer_acquire(smem_pipe_write);
                copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
                     tVgV(_, n_block), tVsV(_, smem_pipe_write.index()));
            }

            // With the fp8 kernel, smem_o shares a union with smem_v_out, so a NamedBarrier could be
            // used here instead of a ClusterBarrier, but this doesn't appear to have any benefit.
            shared_storage.barrier_O.wait((work_idx + 1) % 2);

            pipeline_v.consumer_wait(smem_pipe_read);
            // pipeline_vt.producer_acquire(smem_pipe_write);
            do_transpose_V(smem_pipe_read.index());
            pipeline_vt.producer_commit(smem_pipe_write);
            pipeline_v.consumer_release(smem_pipe_read);
            ++smem_pipe_write;
            ++smem_pipe_read;
            --n_block;

            constexpr int extra_iterations = kStages - 1;
            CUTLASS_PRAGMA_UNROLL
            for (int iter = 0; iter < extra_iterations && n_block >= 0; ++iter) {
                if (warp_idx_in_warpgroup == 0 && lane_predicate) {
                    pipeline_k.producer_acquire(smem_pipe_write);
                    copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
                         tKgK(_, n_block), tKsK(_, smem_pipe_write.index()));
                    pipeline_v.producer_acquire(smem_pipe_write);
                    copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
                         tVgV(_, n_block), tVsV(_, smem_pipe_write.index()));
                }

                pipeline_v.consumer_wait(smem_pipe_read);
                // pipeline_vt.producer_acquire(smem_pipe_write);
                do_transpose_V(smem_pipe_read.index());
                pipeline_vt.producer_commit(smem_pipe_write);
                pipeline_v.consumer_release(smem_pipe_read);
                ++smem_pipe_write;
                ++smem_pipe_read;
                --n_block;
            }

            // CUTLASS_PRAGMA_NO_UNROLL
            #pragma unroll 2
            for (; n_block >= 0; --n_block) {
                if (warp_idx_in_warpgroup == 0 && lane_predicate) {
                    pipeline_k.producer_acquire(smem_pipe_write);
                    copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
                         tKgK(_, n_block), tKsK(_, smem_pipe_write.index()));
                    pipeline_v.producer_acquire(smem_pipe_write);
                    copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
                         tVgV(_, n_block), tVsV(_, smem_pipe_write.index()));
                }

                pipeline_v.consumer_wait(smem_pipe_read);
                pipeline_vt.producer_acquire(smem_pipe_write);
                do_transpose_V(smem_pipe_read.index());
                pipeline_vt.producer_commit(smem_pipe_write);
                pipeline_v.consumer_release(smem_pipe_read);
                ++smem_pipe_write;
                ++smem_pipe_read;
            }

            // scheduler.prefetch_next_work(scheduler_params, work_tile_info);
            // scheduler.broadcast_next_work(work_tile_info);
        }
    }

    /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
    CUTLASS_DEVICE void
    load_tail(MainloopPipeline pipeline_k, MainloopPipeline pipeline_v,
              PipelineState& smem_pipe_write_k, PipelineState& smem_pipe_write_v) {
        int lane_predicate = cute::elect_one_sync();
        int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
        // Issue the epilogue waits
        if (warp_idx_in_warpgroup == 0 && lane_predicate) {
            /* This helps avoid early exit of blocks in a Cluster.
             * Waits for all stages to either be released (all Consumer UNLOCKs), or, if a stage was
             * never used, to simply be acquired, since its phase is still inverted from make_producer_start_state.
             */
            pipeline_k.producer_tail(smem_pipe_write_k);
            pipeline_v.producer_tail(smem_pipe_write_v);
        }
    }

    /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
    CUTLASS_DEVICE void
    load_tail_one_write(MainloopPipeline pipeline_k, MainloopPipeline pipeline_v,
                        PipelineState& smem_pipe_write) {
        int lane_predicate = cute::elect_one_sync();
        int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
        // Issue the epilogue waits
        if (warp_idx_in_warpgroup == 0 && lane_predicate) {
            /* This helps avoid early exit of blocks in a Cluster.
             * Waits for all stages to either be released (all Consumer UNLOCKs), or, if a stage was
             * never used, to simply be acquired, since its phase is still inverted from make_producer_start_state.
             */
            pipeline_k.producer_tail(smem_pipe_write);
            pipeline_v.producer_tail(smem_pipe_write);
        }
    }

    CUTLASS_DEVICE void
    warp_scheduler_barrier_sync() {
        if constexpr (UseSchedulerBarrier) {
            cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + cutlass::canonical_warp_group_idx() /*id*/);
        }
    }

    CUTLASS_DEVICE void
    warp_scheduler_barrier_arrive() {
        if constexpr (!UseSchedulerBarrier) { return; }
        static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
        if constexpr (NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup) {
            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (3 - cutlass::canonical_warp_group_idx()) /*id*/);
        } else {
            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 2 ? cutlass::canonical_warp_group_idx() + 1 : cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/);
            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 1 ? cutlass::canonical_warp_group_idx() + 2 : cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/);
        }
    }
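
    // Example of the scheduler-barrier rotation with two MMA warpgroups (canonical indices 1 and 2):
    // warpgroup i syncs on barrier id (WarpSchedulerWG1 - 1 + i) and, after issuing its GEMMs,
    // arrives on the other warpgroup's barrier, id (WarpSchedulerWG1 - 1 + (3 - i)). The two
    // warpgroups therefore alternate tensor-core issue; mma_init() below seeds the rotation by
    // having every warpgroup other than the first arrive on warpgroup 1's barrier.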

    CUTLASS_DEVICE void
    mma_init() {
        // Tell producer (warp 0) that smem_q is ready
        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + Ktraits::NumProducerThreads, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
        if constexpr (!UseSchedulerBarrier) { return; }
        static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
        if (cutlass::canonical_warp_group_idx() > 1) {
            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 1 /*id*/);
        }
        if constexpr (NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup) {
            if (cutlass::canonical_warp_group_idx() > 2) {
                cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 2 /*id*/);
            }
        }
    }

    template <typename SharedStorage, typename FrgTensorO, typename Softmax>
    CUTLASS_DEVICE void
    mma(Params const& mainloop_params,
        MainloopPipeline pipeline_k,
        MainloopPipeline pipeline_v,
        PipelineState& smem_pipe_read_k,
        PipelineState& smem_pipe_read_v,
        FrgTensorO& tOrO,
        Softmax& softmax,
        int n_block_count,
        int thread_idx,
        int work_idx,
        int m_block,
        SharedStorage& shared_storage,
        const Seqlen_traits& seqlen_traits_q,
        const Seqlen_traits& seqlen_traits_k
        ) {
        static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");

        static constexpr int kBlockM = get<0>(TileShape_MNK{});
        static constexpr int kBlockN = get<1>(TileShape_MNK{});

        Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
        Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
        Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVt{});

        typename Ktraits::TiledMma0 tiled_mma0;
        typename Ktraits::TiledMma1 tiled_mma1;
        auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx);
        auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx);

        // Allocate "fragments/descriptors" for first matmul.
        Tensor tSrQ = threadMma0.partition_fragment_A(sQ);
        Tensor tSrK = threadMma0.partition_fragment_B(sK);
        // Allocate "fragments/descriptors" for second matmul.
        // Note: S becomes P.
        Tensor tOrV = threadMma1.partition_fragment_B(sVt);

        auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
            auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
            pipeline.consumer_wait(smem_pipe_read, barrier_token);
        };

        tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero;

        int const seqlen_q = seqlen_traits_q.actual_seq_len;
        int const seqlen_k = seqlen_traits_k.actual_seq_len;
        int n_block = n_block_count - 1;

        cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.barrier_Q.try_wait(work_idx % 2));
        if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(work_idx % 2); }

        Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
        consumer_wait(pipeline_k, smem_pipe_read_k);
        warp_scheduler_barrier_sync();
        flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
        warp_scheduler_barrier_arrive();

        if (work_idx != 0) {
            int lane_predicate = cute::elect_one_sync();
            if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) {
                tma_store_wait<0>();
                #pragma unroll
                for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
                    shared_storage.barrier_O.arrive(cta_id, lane_predicate);
                }
            }
        }

        warpgroup_wait<0>();
        pipeline_k.consumer_release(smem_pipe_read_k);
        ++smem_pipe_read_k;

        auto col_limit_causal = [&](int row, int n_block) {
            return row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM;
        };
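
        // Worked example of col_limit_causal (hypothetical sizes): with seqlen_q == seqlen_k,
        // kBlockM == kBlockN == 64, m_block == 1, n_block == 1 and row == 10 (global row 74),
        // the limit is 10 + 1 + 0 - 64 + 64 == 11, so local columns >= 11 (global columns >= 75)
        // in this diagonal block are masked, matching the causal condition global_col <= global_row.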

        {
            Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
            Tensor tScS = threadMma0.partition_C(cS);
            #pragma unroll
            for (int i = 0; i < size(tSrS); ++i) {
                if constexpr (!Is_causal) {  // Just masking based on col
                    if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; }
                } else {  // mask based on both row and col
                    // using std::min is faster than doing col >= limit0 or col >= limit1
                    // Need to cast get<1>(tScS(i)) to (signed) int since by default it's unsigned, and the
                    // right hand side can be negative and might be converted to a very large unsigned integer.
                    if (int(get<1>(tScS(i))) >= std::min(seqlen_k - n_block * kBlockN,
                                                         col_limit_causal(int(get<0>(tScS(i))), n_block))) {
                        tSrS(i) = -INFINITY;
                    }
                }
            }
        }

        softmax.template online_softmax</*Is_first=*/true>(tSrS);

        Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout()));
        Tensor scores_scale = make_fragment_like(softmax.row_max);
        clear(scores_scale);

        constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1;
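        // For example (hypothetical tile sizes): with kBlockM == 128 and kBlockN == 64, a causal
        // kernel uses ceil_div(128, 64) + 1 == 3 masking steps; at most that many K/V blocks can
        // intersect the causal diagonal of one M block, and the remaining blocks need no mask.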
        // Only go through these if Is_causal, since n_masking_steps = 1 when !Is_causal
        #pragma unroll
        for (int masking_step = 0; masking_step < n_masking_steps - 1 && n_block > 0; ++masking_step, --n_block) {
            Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
            consumer_wait(pipeline_k, smem_pipe_read_k);
            warp_scheduler_barrier_sync();
            flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
            if (masking_step > 0) { softmax.rescale_o(tOrO, scores_scale); }
            consumer_wait(pipeline_v, smem_pipe_read_v);
            flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
            warp_scheduler_barrier_arrive();
            warpgroup_wait<1>();
            pipeline_k.consumer_release(smem_pipe_read_k);  // release K
            Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
            Tensor tScS = threadMma0.partition_C(cS);
            #pragma unroll
            for (int i = 0; i < size(tSrS); ++i) {
                if (int(get<1>(tScS(i))) >= col_limit_causal(int(get<0>(tScS(i))), n_block - 1)) {
                    tSrS(i) = -INFINITY;
                }
            }
            cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/true>(tSrS), scores_scale);
            softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/true>(tSrS);
            warpgroup_wait<0>();
            pipeline_v.consumer_release(smem_pipe_read_v);  // release V
            ++smem_pipe_read_k;
            ++smem_pipe_read_v;
            cute::copy(make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout())), tOrP);
        }

        #pragma unroll 1
        for (; n_block > 0; --n_block) {
            Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
            consumer_wait(pipeline_k, smem_pipe_read_k);
            warp_scheduler_barrier_sync();
            flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
            softmax.rescale_o(tOrO, scores_scale);
            consumer_wait(pipeline_v, smem_pipe_read_v);
            flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
            warp_scheduler_barrier_arrive();
            warpgroup_wait<1>();
            pipeline_k.consumer_release(smem_pipe_read_k);  // release K
            // auto scores_scale = softmax.template max</*Is_first=*/false>(tSrS);
            cute::copy(softmax.template max</*Is_first=*/false>(tSrS), scores_scale);
            softmax.template online_softmax</*Is_first=*/false>(tSrS);
            warpgroup_wait<0>();
            pipeline_v.consumer_release(smem_pipe_read_v);  // release V
            ++smem_pipe_read_k;
            ++smem_pipe_read_v;
            // softmax.rescale_o(tOrO, scores_scale);
            cute::copy(make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout())), tOrP);
        }

        // Tell warp 0 that smem_q is ready
        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
        softmax.rescale_o(tOrO, scores_scale);
        consumer_wait(pipeline_v, smem_pipe_read_v);
        flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
        cute::copy(softmax.template finalize</*Check_inf=*/Is_causal>(tSrS), scores_scale);
        warpgroup_wait<0>();
        pipeline_v.consumer_release(smem_pipe_read_v);  // release V, otherwise producers will hang
        ++smem_pipe_read_v;
        softmax.rescale_o(tOrO, scores_scale);
        return;
    }

    template <bool Delay_V_release = false, typename SharedStorage, typename FrgTensorO, typename Softmax>
    CUTLASS_DEVICE void
    mma_fp8(Params const& mainloop_params,
        MainloopPipeline pipeline_k,
        MainloopPipelineNoTMA pipeline_vt,
        PipelineState& smem_pipe_read,
        PipelineState& smem_pipe_release,
        FrgTensorO& tOrO,
        Softmax& softmax,
        int n_block_count,
        int thread_idx,
        int work_idx,
        int m_block,
        SharedStorage& shared_storage,
        const Seqlen_traits& seqlen_traits_q,
        const Seqlen_traits& seqlen_traits_k
        ) {
        static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");

        static constexpr int kBlockM = get<0>(TileShape_MNK{});
        static constexpr int kBlockN = get<1>(TileShape_MNK{});

        Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
        Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
        Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v_out.data()), SmemLayoutVt{});

        typename Ktraits::TiledMma0 tiled_mma0;
        typename Ktraits::TiledMma1 tiled_mma1;
        auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx);
        auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx);

        // Allocate "fragments/descriptors" for first matmul.
        Tensor tSrQ = threadMma0.partition_fragment_A(sQ);
        Tensor tSrK = threadMma0.partition_fragment_B(sK);
        // Allocate "fragments/descriptors" for second matmul.
        Tensor tOrV = threadMma1.partition_fragment_B(sVt);

        auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
            auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
            pipeline.consumer_wait(smem_pipe_read, barrier_token);
        };

        tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero;

        // Workaround for an fp8-only perf regression, pending a change to the seqlen traits class.
        int const seqlen_q = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_q.actual_seq_len : shape<0>(mainloop_params.layout_Q);
        int const seqlen_k = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_k.actual_seq_len : shape<0>(mainloop_params.layout_K);
        int n_block = n_block_count - 1;

        cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.barrier_Q.try_wait(work_idx % 2));
        if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(work_idx % 2); }

        Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
        consumer_wait(pipeline_k, smem_pipe_read);
        warp_scheduler_barrier_sync();
        flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);

        if (work_idx != 0) {
            int lane_predicate = cute::elect_one_sync();
            if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) {
                tma_store_wait<0>();
                #pragma unroll
                for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
                    shared_storage.barrier_O.arrive(cta_id, lane_predicate);
                }
            }
        }

        warpgroup_wait<0>();
        warp_scheduler_barrier_arrive();
        pipeline_k.consumer_release(smem_pipe_read);

        auto col_limit_causal = [&](int row, int n_block) {
            return row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM;
        };
        {
            Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
            Tensor tScS = threadMma0.partition_C(cS);
            #pragma unroll
            for (int i = 0; i < size(tSrS); ++i) {
                if constexpr (!Is_causal) {  // Just masking based on col
                    if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; }
                } else {  // mask based on both row and col
                    if (int(get<1>(tScS(i))) >= std::min(seqlen_k - n_block * kBlockN,
                                                         col_limit_causal(int(get<0>(tScS(i))), n_block))) {
                        tSrS(i) = -INFINITY;
                    }
                }
            }
        }

        softmax.template online_softmax</*Is_first=*/true>(tSrS);
        Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
        permute_regs_A_to_C(tOrP);

        Tensor scores_scale = make_fragment_like(softmax.row_max);
        clear(scores_scale);

        consumer_wait(pipeline_vt, smem_pipe_read);
        flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
        if constexpr (!Delay_V_release) { pipeline_vt.consumer_release(smem_pipe_read); }

        ++smem_pipe_read;
        --n_block;

        constexpr int extra_iterations = !Is_causal ? kStages - 1 : cute::ceil_div(kBlockM, kBlockN);
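        // For example (hypothetical configuration): a non-causal kernel with kStages == 3 peels
        // kStages - 1 == 2 extra iterations to fill the software pipeline, while a causal kernel
        // with kBlockM == 128 and kBlockN == 64 peels ceil_div(128, 64) == 2 masked iterations so
        // that the loops below need no causal mask.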

        if constexpr (Is_causal) {
            CUTLASS_PRAGMA_UNROLL
            for (int iter = 0; iter < extra_iterations && n_block >= 0; ++iter, --n_block) {
                Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
                consumer_wait(pipeline_k, smem_pipe_read);
                warp_scheduler_barrier_sync();
                flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);

                Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
                Tensor tScS = threadMma0.partition_C(cS);
                #pragma unroll
                for (int i = 0; i < size(tSrS); ++i) {
                    if (int(get<1>(tScS(i))) >= col_limit_causal(int(get<0>(tScS(i))), n_block)) {
                        tSrS(i) = -INFINITY;
                    }
                }

                warp_scheduler_barrier_arrive();
                pipeline_k.consumer_release(smem_pipe_read);
                if constexpr (Delay_V_release) {
                    pipeline_vt.consumer_release(smem_pipe_release);
                    ++smem_pipe_release;
                }
                consumer_wait(pipeline_vt, smem_pipe_read);

                cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/true>(tSrS), scores_scale);
                softmax.rescale_o(tOrO, scores_scale);
                softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/true>(tSrS);
                Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
                permute_regs_A_to_C(tOrP);

                flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
                if constexpr (!Delay_V_release) { pipeline_vt.consumer_release(smem_pipe_read); }
                ++smem_pipe_read;
            }
        } else {
            CUTLASS_PRAGMA_UNROLL
            for (int iter = 0; iter < extra_iterations && n_block >= 0; ++iter, --n_block) {
                Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
                consumer_wait(pipeline_k, smem_pipe_read);
                if constexpr (Delay_V_release) {
                    pipeline_vt.consumer_release(smem_pipe_release);
                    ++smem_pipe_release;
                }
                warp_scheduler_barrier_sync();
                flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
                warp_scheduler_barrier_arrive();

                if constexpr (!Delay_V_release) { pipeline_k.consumer_release(smem_pipe_read); }
                else { consumer_wait(pipeline_vt, smem_pipe_read); }

                cute::copy(softmax.template max</*Is_first=*/false>(tSrS), scores_scale);
                softmax.rescale_o(tOrO, scores_scale);
                softmax.template online_softmax</*Is_first=*/false>(tSrS);
                Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
                permute_regs_A_to_C(tOrP);

                if constexpr (Delay_V_release) { pipeline_k.consumer_release(smem_pipe_read); }
                else { consumer_wait(pipeline_vt, smem_pipe_read); }

                flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
                if constexpr (!Delay_V_release) { pipeline_vt.consumer_release(smem_pipe_read); }
                ++smem_pipe_read;
            }
        }

        if constexpr (Delay_V_release) {
            warp_scheduler_barrier_sync();
            CUTLASS_PRAGMA_NO_UNROLL
            for (; n_block >= 0; --n_block) {
                Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
                consumer_wait(pipeline_k, smem_pipe_read);
                flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
                warp_scheduler_barrier_arrive();
                pipeline_k.consumer_release(smem_pipe_read);
                pipeline_vt.consumer_release(smem_pipe_release);

                cute::copy(softmax.template max</*Is_first=*/false>(tSrS), scores_scale);
                softmax.rescale_o(tOrO, scores_scale);
                softmax.template online_softmax</*Is_first=*/false>(tSrS);
                Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
                permute_regs_A_to_C(tOrP);

                consumer_wait(pipeline_vt, smem_pipe_read);
                flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
                warp_scheduler_barrier_sync();
                ++smem_pipe_read;
                ++smem_pipe_release;
            }
            warp_scheduler_barrier_arrive();
            pipeline_vt.consumer_release(smem_pipe_release);
            ++smem_pipe_release;
        } else {
            if constexpr (kHeadDim == 128) { warp_scheduler_barrier_sync(); }
            CUTLASS_PRAGMA_NO_UNROLL
            for (; n_block >= 0; --n_block) {
                Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
                consumer_wait(pipeline_k, smem_pipe_read);
                if constexpr (kHeadDim == 256) { warp_scheduler_barrier_sync(); }
                flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
                warp_scheduler_barrier_arrive();
                pipeline_k.consumer_release(smem_pipe_read);

                cute::copy(softmax.template max</*Is_first=*/false>(tSrS), scores_scale);
                softmax.rescale_o(tOrO, scores_scale);
                softmax.template online_softmax</*Is_first=*/false>(tSrS);
                Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
                permute_regs_A_to_C(tOrP);

                consumer_wait(pipeline_vt, smem_pipe_read);
                if constexpr (kHeadDim == 128) { warp_scheduler_barrier_sync(); }
                flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
                pipeline_vt.consumer_release(smem_pipe_read);
                ++smem_pipe_read;
            }
            if constexpr (kHeadDim == 128) { warp_scheduler_barrier_arrive(); }
        }

        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarpGroup, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
        cute::copy(softmax.template finalize</*Check_inf=*/Is_causal>(tSrS, shared_storage.descale_v), scores_scale);
        softmax.rescale_o(tOrO, scores_scale);
        return;
    }

};

}  // namespace flash