mainloop_fwd_sm90_tma_gmma_ws.hpp 55 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include <cutlass/cutlass.h>
  6. #include <cutlass/array.h>
  7. #include <cutlass/numeric_types.h>
  8. #include <cutlass/numeric_conversion.h>
  9. #include "cutlass/pipeline/pipeline.hpp"
  10. #include "cute/tensor.hpp"
  11. #include "cutlass/gemm/collective/collective_builder.hpp"
  12. #include "named_barrier.hpp"
  13. #include "utils.h"
  14. namespace flash {
  15. using namespace cute;
  16. // 4 warps
  17. struct SmemTransposeFp8_64x64 {
  18. using Element = cutlass::float_e4m3_t;
  19. using ldsm_thread_shape = Shape<_4, _1, _8, _4>;
  20. using ldsm_value_shape = Shape<_2, _8, _2, _1>;
  21. using ldsm_value_stride = Stride<_2, _4, _1, _0>;
  22. using TiledCopyLDSM = decltype(make_tiled_copy(
  23. Copy_Atom<SM75_U16x8_LDSM_T, Element>{}, Layout<ldsm_thread_shape>{},
  24. Layout<ldsm_value_shape, ldsm_value_stride>{}));
  25. TiledCopyLDSM tiled_copy_ldsm;
  26. using stsm_thread_shape = Shape<_4, _1, _8, _4>;
  27. // using stsm_thread_stride = Stride<_1, _0, _4, _32>;
  28. #ifndef NO_FP8_COLUMN_PERMUTE
  29. using stsm_value_shape = Shape<_4, _4, _1, _2>;
  30. using stsm_value_stride = Stride<_1, _8, _0, _4>;
  31. #else
  32. using stsm_value_shape = Shape<_4, _4, _2, _1>;
  33. using stsm_value_stride = Stride<_1, _8, _4, _0>;
  34. #endif
  35. using TiledCopySTSM =
  36. decltype(make_tiled_copy(Copy_Atom<SM90_U32x4_STSM_N, Element>{},
  37. Layout<stsm_thread_shape>{},
  38. Layout<stsm_value_shape, stsm_value_stride>{}));
  39. TiledCopySTSM tiled_copy_stsm;
  40. template <class SmemTensor, class SmemTensorOut>
  41. CUTLASS_DEVICE void operator()(SmemTensor &&s_in, SmemTensorOut &&s_out) {
  42. using namespace cute;
  43. auto tid = threadIdx.x;
  44. auto thr_copy_ldsm = tiled_copy_ldsm.get_thread_slice(tid);
  45. auto thr_copy_stsm = tiled_copy_stsm.get_thread_slice(tid);
  46. auto tXsX = thr_copy_ldsm.partition_S(s_in);
  47. auto tXrX = make_tensor<Element>(shape(tXsX));
  48. auto tXsX_out = thr_copy_stsm.partition_D(s_out);
  49. cute::copy(tiled_copy_ldsm, tXsX, tXrX);
  50. auto data = tXrX.data();
  51. // size(tXrX) == 32
  52. CUTLASS_PRAGMA_UNROLL
  53. for (int n = 0; n < size(tXrX); n += 8) {
  54. uint32_t *data_32bit = reinterpret_cast<uint32_t *>(&data[n]);
  55. auto upper = data_32bit[0];
  56. auto lower = data_32bit[1];
  57. data_32bit[0] = __byte_perm(upper, lower, 0x6420);
  58. data_32bit[1] = __byte_perm(upper, lower, 0x7531);
  59. }
  60. cute::copy(tiled_copy_stsm, tXrX, tXsX_out);
  61. }
  62. };
  63. template <typename Ktraits, bool Is_causal, bool Is_local, typename Seqlen_traits>
  64. struct CollectiveMainloopFwd {
  // Compile-time configuration pulled from Ktraits, plus the TMA descriptor types used for Q, K and V.
  65. using Element = typename Ktraits::Element;
  66. using TileShape_MNK = typename Ktraits::TileShape_MNK;
  67. using ClusterShape = typename Ktraits::ClusterShape_MNK;
  68. static constexpr int kStages = Ktraits::kStages;
  69. static constexpr int kHeadDim = Ktraits::kHeadDim;
  // Q always uses a plain (non-multicast) TMA load.
  70. using GmemTiledCopyQ = cute::SM90_TMA_LOAD;
  // K/V copy atom is chosen from the cluster extent along M; presumably SM90_TMA_LOAD_MULTICAST
  // when that extent is > 1 (load()/load_fp8() test for that type when building the mcast mask).
  71. using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{})));
  72. using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
  73. using SmemLayoutK = typename Ktraits::SmemLayoutK;
  74. using SmemLayoutV = typename Ktraits::SmemLayoutV;
  75. using SmemLayoutVt = typename Ktraits::SmemLayoutVt;
  // The TMA_* aliases only name the descriptor types: the nullptr-based gmem tensor is a
  // placeholder inside decltype. The real descriptors are built in to_underlying_arguments()
  // with the same copy atom, tile shape and multicast factor. For K/V this alias uses
  // take<0, 2>(SmemLayout) while to_underlying_arguments() uses SmemLayout{}(_, _, _0{}); both
  // describe a single pipeline stage.
  76. using TMA_Q = decltype(make_tma_copy(
  77. GmemTiledCopyQ{},
  78. make_tensor(
  79. make_gmem_ptr(static_cast<Element const*>(nullptr)),
  80. repeat_like(typename Seqlen_traits::StrideT{}, int32_t(0)),
  81. typename Seqlen_traits::StrideT{}
  82. ),
  83. SmemLayoutQ{},
  84. select<0, 2>(TileShape_MNK{}),
  85. _1{})); // no mcast for Q
  86. using TMA_K = decltype(make_tma_copy(
  87. GmemTiledCopyKV{},
  88. make_tensor(
  89. make_gmem_ptr(static_cast<Element const*>(nullptr)),
  90. repeat_like(typename Seqlen_traits::StrideT{}, int32_t(0)),
  91. typename Seqlen_traits::StrideT{}
  92. ),
  93. take<0, 2>(SmemLayoutK{}),
  94. select<1, 2>(TileShape_MNK{}),
  95. size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
  96. // TMA_V may differ from TMA_K for fp8 kernel (e.g. swizzling mode)
  97. using TMA_V = decltype(make_tma_copy(
  98. GmemTiledCopyKV{},
  99. make_tensor(
  100. make_gmem_ptr(static_cast<Element const*>(nullptr)),
  101. repeat_like(typename Seqlen_traits::StrideT{}, int32_t(0)),
  102. typename Seqlen_traits::StrideT{}
  103. ),
  104. take<0, 2>(SmemLayoutV{}),
  105. select<1, 2>(TileShape_MNK{}),
  106. size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
  // Thread count of TiledMma0; used as the MMA-side participant count of the named barriers.
  107. static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
  108. using MainloopPipeline = typename Ktraits::MainloopPipeline;
  109. using MainloopPipelineNoTMA = typename Ktraits::MainloopPipelineNoTMA;
  110. using PipelineParams = typename MainloopPipeline::Params;
  111. using PipelineState = typename MainloopPipeline::PipelineState;
  112. // Set the bytes transferred in this TMA transaction (may involve multiple issues)
  // Q: the whole SmemLayoutQ. K: one stage (the first two modes of SmemLayoutK).
  113. static constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<Element> / 8);
  114. static constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<Element> / 8);
  // Enables the WarpScheduler named barriers between MMA warpgroups (warp_scheduler_barrier_*):
  // for 8-bit Element when kHeadDim >= 128, otherwise when kHeadDim <= 128. The commented-out
  // line below is a dtype-independent variant of this rule.
  115. // static constexpr bool UseSchedulerBarrier = kHeadDim <= 128;
  116. static constexpr bool UseSchedulerBarrier =
  117. cutlass::sizeof_bits_v<Element> == 8 ? kHeadDim >= 128
  118. : kHeadDim <= 128;
  119. // Host side kernel arguments
  120. struct Arguments {
  121. Element const* ptr_Q;
  122. typename Seqlen_traits::LayoutT layout_Q;
  123. Element const* ptr_K;
  124. typename Seqlen_traits::LayoutT layout_K;
  125. Element const* ptr_V;
  126. typename Seqlen_traits::LayoutT layout_V;
  127. float const softmax_scale_log2;
  128. float const* descale_q_ptr;
  129. float const* descale_k_ptr;
  130. float const* descale_v_ptr;
  131. int window_size_left;
  132. int window_size_right;
  133. };
  134. // Device side kernel params
  135. struct Params {
  136. typename Seqlen_traits::LayoutT layout_Q;
  137. typename Seqlen_traits::LayoutT layout_K;
  138. typename Seqlen_traits::LayoutT layout_V;
  139. cutlass::FastDivmod qhead_per_khead_divmod;
  140. TMA_Q tma_load_Q;
  141. TMA_K tma_load_K;
  142. TMA_V tma_load_V;
  143. float const softmax_scale_log2;
  144. float const* descale_q_ptr;
  145. float const* descale_k_ptr;
  146. float const* descale_v_ptr;
  147. int window_size_left;
  148. int window_size_right;
  149. };
  150. static Params
  151. to_underlying_arguments(Arguments const& args) {
  152. Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.layout_Q);
  153. TMA_Q tma_load_Q = make_tma_copy(
  154. GmemTiledCopyQ{},
  155. mQ,
  156. SmemLayoutQ{},
  157. select<0, 2>(TileShape_MNK{}),
  158. _1{}); // no mcast for Q
  159. Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.layout_K);
  160. TMA_K tma_load_K = make_tma_copy(
  161. GmemTiledCopyKV{},
  162. mK,
  163. SmemLayoutK{}(_, _, _0{}),
  164. select<1, 2>(TileShape_MNK{}),
  165. size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
  166. Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.layout_V);
  167. TMA_V tma_load_V = make_tma_copy(
  168. GmemTiledCopyKV{},
  169. mV,
  170. SmemLayoutV{}(_, _, _0{}),
  171. select<1, 2>(TileShape_MNK{}),
  172. size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
  173. return {args.layout_Q, args.layout_K, args.layout_V,
  174. cutlass::FastDivmod(cute::ceil_div(get<2>(args.layout_Q.shape()), get<2>(args.layout_K.shape()))),
  175. tma_load_Q, tma_load_K, tma_load_V,
  176. args.softmax_scale_log2,
  177. args.descale_q_ptr, args.descale_k_ptr, args.descale_v_ptr,
  178. args.window_size_left, args.window_size_right};
  179. }
  180. /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
  181. CUTLASS_DEVICE
  182. static void prefetch_tma_descriptors(Params const& mainloop_params) {
  183. cute::prefetch_tma_descriptor(mainloop_params.tma_load_Q.get_tma_descriptor());
  184. cute::prefetch_tma_descriptor(mainloop_params.tma_load_K.get_tma_descriptor());
  185. cute::prefetch_tma_descriptor(mainloop_params.tma_load_V.get_tma_descriptor());
  186. }
  187. CUTLASS_DEVICE
  188. int get_n_block_max(
  189. Params const& mainloop_params, int m_block,
  190. const Seqlen_traits& seqlen_traits_q,
  191. const Seqlen_traits& seqlen_traits_k
  192. ) {
  193. static constexpr int kBlockM = get<0>(TileShape_MNK{});
  194. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  195. int const seqlen_q = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_q.actual_seq_len : shape<0>(mainloop_params.layout_Q);
  196. int const seqlen_k = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_k.actual_seq_len : shape<0>(mainloop_params.layout_K);
  197. int n_block_max = cute::ceil_div(seqlen_k, kBlockN);
  198. if constexpr (Is_causal || Is_local) {
  199. n_block_max = std::min(
  200. n_block_max,
  201. cute::ceil_div((m_block + 1) * kBlockM + seqlen_k - seqlen_q + mainloop_params.window_size_right, kBlockN));
  202. }
  203. return n_block_max;
  204. }
  205. CUTLASS_DEVICE
  206. int get_n_block_min(
  207. Params const& mainloop_params, int m_block,
  208. const Seqlen_traits& seqlen_traits_q,
  209. const Seqlen_traits& seqlen_traits_k
  210. ) {
  211. static constexpr int kBlockM = get<0>(TileShape_MNK{});
  212. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  213. int const seqlen_q = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_q.actual_seq_len : shape<0>(mainloop_params.layout_Q);
  214. int const seqlen_k = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_k.actual_seq_len : shape<0>(mainloop_params.layout_K);
  215. if constexpr (!Is_local) {
  216. return 0;
  217. } else {
  218. return std::max(
  219. 0,
  220. (m_block * kBlockM + seqlen_k - seqlen_q - mainloop_params.window_size_left) / kBlockN
  221. );
  222. }
  223. }
  // Producer mainloop for one work tile (non-fp8 path). TMA-loads Q once, and K/V for key blocks
  // n_block_max - 1 down to n_block_min through pipeline_k / pipeline_v.
  //   block_coord         : (m_block, bidh, bidb) = query-tile index, query-head index and
  //                         (presumably) batch index; bidh is mapped to bidh_kv via qhead_per_khead_divmod.
  //   smem_pipe_write_k/v : producer pipeline states, advanced once per K/V stage issued.
  //   work_idx            : selects the barrier_O parity waited on, (work_idx + 1) % 2.
  // Only the lane chosen by cute::elect_one_sync() acquires pipeline stages and issues TMA copies.
  // The QueryEmpty barrier below counts NumMmaThreads + one warp, so this is presumably run by a
  // single producer warp.
  // NOTE(review): there is no guard for an empty range (n_block_max <= n_block_min); the first K
  // load uses n_block_max - 1 unconditionally. Presumably callers guarantee at least one block.
  224. template <typename Scheduler, typename SharedStorage>
  225. CUTLASS_DEVICE void
  226. load(Params const& mainloop_params,
  227. MainloopPipeline pipeline_k,
  228. MainloopPipeline pipeline_v,
  229. PipelineState& smem_pipe_write_k,
  230. PipelineState& smem_pipe_write_v,
  231. SharedStorage &shared_storage,
  232. Scheduler& scheduler,
  233. typename Scheduler::Params const& scheduler_params,
  234. typename Scheduler::WorkTileInfo& work_tile_info,
  235. cute::tuple<int32_t, int32_t, int32_t> block_coord,
  236. int work_idx,
  237. const Seqlen_traits& seqlen_traits_q,
  238. const Seqlen_traits& seqlen_traits_k
  239. ) {
  240. Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
  241. Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
  242. Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});
  243. Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape());
  244. Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape());
  245. Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape());
  246. auto [m_block, bidh, bidb] = block_coord;
  247. int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);
  248. // Prepare the TMA loads
  249. uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
  250. constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
  251. uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
  252. Tensor gQ = seqlen_traits_q.get_local_tile_tensor(
  253. mQ, select<0, 2>(TileShape_MNK{}), bidh, bidb)(_, _, m_block); // (M, K)
  254. Tensor gK = seqlen_traits_k.get_local_tile_tensor(
  255. mK, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb); // (N, K, _)
  256. Tensor gV = seqlen_traits_k.get_local_tile_tensor(
  257. mV, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb); // (N, K, _)
  // Q is partitioned for a single, non-multicast TMA (cluster layout _1); K and V are
  // partitioned across the whole ClusterShape.
  258. Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
  259. Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
  260. auto [tQgQ, tQsQ] = tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{},
  261. group_modes<0, 2>(sQ_x), group_modes<0, 2>(gQ_x)); // (TMA), (TMA)
  262. auto [tKgK, tKsK] = tma_partition(mainloop_params.tma_load_K, block_rank_in_cluster, Layout<ClusterShape>{},
  263. group_modes<0, 2>(sK), group_modes<0, 2>(gK)); // (TMA, k), (TMA, PIPE)
  264. auto [tVgV, tVsV] = tma_partition(mainloop_params.tma_load_V, block_rank_in_cluster, Layout<ClusterShape>{},
  265. group_modes<0, 2>(sV), group_modes<0, 2>(gV)); // (TMA, k), (TMA, PIPE)
  // Multicast mask for K/V: with a multicast atom, one bit per CTA of the cluster whose
  // cluster-local y coordinate equals ours (iterating mode 0 / M of ClusterShape). Stays 0 otherwise.
  266. uint16_t mcast_mask_kv = 0;
  267. if constexpr (cute::is_same_v<GmemTiledCopyKV, SM90_TMA_LOAD_MULTICAST>) {
  268. auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id
  269. for (int m = 0; m < size<0>(block_layout); ++m) {
  270. mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{}));
  271. }
  272. }
  273. const int n_block_min = get_n_block_min(mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
  274. const int n_block_max = get_n_block_max(mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
  // Key blocks are walked from the last one backwards. K for the last block is issued before
  // the QueryEmpty wait; V loads are only issued after barrier_O has been waited on.
  275. int n_block = n_block_max - 1;
  276. int lane_predicate = cute::elect_one_sync();
  277. if (lane_predicate) {
  278. pipeline_k.producer_acquire(smem_pipe_write_k);
  279. copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv),
  280. tKgK(_, n_block), tKsK(_, smem_pipe_write_k.index()));
  281. ++smem_pipe_write_k;
  282. }
  283. // Wait for the MMA warpgroups to say that smem_q is ready
  284. cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
  // Q completion is tracked by shared_storage.barrier_Q via an expected-transaction byte count.
  285. if (lane_predicate) {
  286. shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
  287. copy(mainloop_params.tma_load_Q.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ);
  288. }
  289. // Wait for warp 1 to signal that smem_v are ready and V can be copied from gmem
  290. // Need ClusterBarrier, not just NamedBarrier. Otherwise we might have CTA 0 finishing the
  291. // TMA store on O first, call TMA multicast load on V, before CTA 1 can finishing TMA store on O.
  292. shared_storage.barrier_O.wait((work_idx + 1) % 2);
  // Steady state: K runs one block ahead of V (K for n_block - 1, V for n_block).
  293. if (lane_predicate) {
  294. // CUTLASS_PRAGMA_NO_UNROLL
  295. #pragma unroll 2
  296. for (; n_block > n_block_min; --n_block) {
  297. pipeline_k.producer_acquire(smem_pipe_write_k);
  298. copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv),
  299. tKgK(_, n_block - 1), tKsK(_, smem_pipe_write_k.index()));
  300. ++smem_pipe_write_k;
  301. pipeline_v.producer_acquire(smem_pipe_write_v);
  302. copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv),
  303. tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
  304. ++smem_pipe_write_v;
  305. }
  306. }
  // All calling threads prefetch the next work tile before the final V load and broadcast it after.
  307. scheduler.prefetch_next_work(scheduler_params, work_tile_info);
  // Final V tile. After the loop n_block == n_block_min (assuming n_block_max > n_block_min);
  // its K was issued by the last loop iteration, or before the QueryEmpty wait if the loop didn't run.
  308. if (lane_predicate) {
  309. pipeline_v.producer_acquire(smem_pipe_write_v);
  310. copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv),
  311. tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
  312. ++smem_pipe_write_v;
  313. }
  314. scheduler.broadcast_next_work(work_tile_info);
  315. }
  // Producer mainloop for one work tile on the fp8 path. Like load(), it TMA-loads Q once and K/V
  // block by block, from n_block_max - 1 down to 0. In addition, every V stage is transposed in
  // shared memory (smem_v -> smem_v_out, via SmemTransposeFp8_64x64) and published through
  // pipeline_vt.
  //   smem_pipe_write : a single producer state shared by pipeline_k, pipeline_v and pipeline_vt.
  //   smem_pipe_read  : consumer state of pipeline_v (V is consumed here in order to transpose it).
  //   block_coord / work_idx : same meaning as in load().
  // Only the elected lane of warp 0 of the warpgroup issues TMA loads and K/V acquires. Every
  // calling thread runs do_transpose_V, which ends with a ProducerWG named-barrier sync over
  // NumThreadsPerWarpGroup threads, so this is presumably executed by a full producer warpgroup.
  // NOTE(review): get_n_block_min() is never called, so with Is_local both branches still walk
  // down to block 0. Confirm that local attention is not routed to this path.
  316. template <typename Scheduler, typename SharedStorage>
  317. CUTLASS_DEVICE void
  318. load_fp8(Params const& mainloop_params,
  319. MainloopPipeline pipeline_k,
  320. MainloopPipeline pipeline_v,
  321. MainloopPipelineNoTMA pipeline_vt,
  322. PipelineState& smem_pipe_write,
  323. PipelineState& smem_pipe_read,
  324. SharedStorage &shared_storage,
  325. Scheduler& scheduler,
  326. typename Scheduler::Params const& scheduler_params,
  327. typename Scheduler::WorkTileInfo& work_tile_info,
  328. cute::tuple<int32_t, int32_t, int32_t> block_coord,
  329. int work_idx,
  330. const Seqlen_traits& seqlen_traits_q,
  331. const Seqlen_traits& seqlen_traits_k
  332. ) {
  333. using SmemLayoutTransposeV = typename Ktraits::SmemLayoutTransposeV;
  334. using SmemLayoutTransposeVt = typename Ktraits::SmemLayoutTransposeVt;
  335. Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
  336. Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
  337. Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});
  // Views of smem_v (source) and smem_v_out (destination) partitioned into transpose sub-tiles.
  338. Tensor sV_divide = as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutTransposeV{}));
  339. Tensor sVt_divide = as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.smem_v_out.data()), SmemLayoutTransposeVt{}));
  340. auto smem_transpose_V = SmemTransposeFp8_64x64();
  // Transposes every sub-tile (indexed by modes 1 and 2 of SmemLayoutTransposeV) of V stage
  // `stage`, then syncs the producer warpgroup on the ProducerWG named barrier.
  341. auto do_transpose_V = [&](int stage) {
  342. CUTLASS_PRAGMA_UNROLL
  343. for (int j = 0; j < shape<2>(SmemLayoutTransposeV{}); ++j) {
  344. CUTLASS_PRAGMA_UNROLL
  345. for (int i = 0; i < shape<1>(SmemLayoutTransposeV{}); ++i) {
  346. smem_transpose_V(flatten(sV_divide(_, i, j, stage)),
  347. flatten(sVt_divide(_, i, j, stage)));
  348. }
  349. }
  350. cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, static_cast<int>(FwdNamedBarriers::ProducerWG) /*id*/);
  351. };
  352. Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape());
  353. Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape());
  354. Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape());
  355. auto [m_block, bidh, bidb] = block_coord;
  356. int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);
  357. // Prepare the TMA loads
  358. uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
  359. constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
  360. uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
  361. Tensor gQ = seqlen_traits_q.get_local_tile_tensor(
  362. mQ, select<0, 2>(TileShape_MNK{}), bidh, bidb)(_, _, m_block); // (M, K)
  363. Tensor gK = seqlen_traits_k.get_local_tile_tensor(
  364. mK, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb); // (N, K, _)
  365. Tensor gV = seqlen_traits_k.get_local_tile_tensor(
  366. mV, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb); // (N, K, _)
  367. Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
  368. Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
  369. auto [tQgQ, tQsQ] = tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{},
  370. group_modes<0, 2>(sQ_x), group_modes<0, 2>(gQ_x)); // (TMA), (TMA)
  371. auto [tKgK, tKsK] = tma_partition(mainloop_params.tma_load_K, block_rank_in_cluster, Layout<ClusterShape>{},
  372. group_modes<0, 2>(sK), group_modes<0, 2>(gK)); // (TMA, k), (TMA, PIPE)
  373. auto [tVgV, tVsV] = tma_partition(mainloop_params.tma_load_V, block_rank_in_cluster, Layout<ClusterShape>{},
  374. group_modes<0, 2>(sV), group_modes<0, 2>(gV)); // (TMA, k), (TMA, PIPE)
  // Same K/V multicast mask construction as in load().
  375. uint16_t mcast_mask_kv = 0;
  376. if constexpr (cute::is_same_v<GmemTiledCopyKV, SM90_TMA_LOAD_MULTICAST>) {
  377. auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id
  378. for (int m = 0; m < size<0>(block_layout); ++m) {
  379. mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{}));
  380. }
  381. }
  382. int n_block_max = get_n_block_max(mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
  383. int n_block = n_block_max - 1;
  384. int lane_predicate = cute::elect_one_sync();
  // Warp index within the warpgroup, broadcast from lane 0 so it is uniform across the warp.
  385. int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
  // K for the last block is issued into the current write stage; smem_pipe_write is not advanced
  // until after the matching V stage has been transposed.
  386. if (warp_idx_in_warpgroup == 0 && lane_predicate) {
  387. pipeline_k.producer_acquire(smem_pipe_write);
  388. copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
  389. tKgK(_, n_block), tKsK(_, smem_pipe_write.index()));
  390. }
  391. // Wait for the MMA warpgroups to say that smem_q is ready
  392. // for fp8, change from NumThreadsPerWarp to NumThreadsPerWarpGroup
  393. cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarpGroup, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
  // Both branches skip pipeline_vt.producer_acquire for the first min(kStages, n_block_max)
  // transposed stages of the tile: the commented-out acquires and the `n_block_max > kStages`
  // guard in the causal branch; the commented-out acquires in the non-causal prologue.
  // Presumably those smem_v_out stages are known to be free at tile start (cf. barrier_O). Confirm.
  394. if constexpr(Is_causal) {
  395. if (warp_idx_in_warpgroup == 0 && lane_predicate) {
  396. shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
  397. copy(mainloop_params.tma_load_Q.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ);
  398. pipeline_v.producer_acquire(smem_pipe_write);
  399. copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
  400. tVgV(_, n_block), tVsV(_, smem_pipe_write.index()));
  401. }
  402. shared_storage.barrier_O.wait((work_idx + 1) % 2);
  // Prologue: up to kStages iterations without pipeline_vt.producer_acquire. Each iteration
  // transposes the pending V stage, then issues K/V for block n_block - 1 into the next stage.
  403. CUTLASS_PRAGMA_UNROLL
  404. for (int iter = 0; iter < kStages && n_block > 0; ++iter, --n_block) {
  405. pipeline_v.consumer_wait(smem_pipe_read);
  406. // pipeline_vt.producer_acquire(smem_pipe_write);
  407. do_transpose_V(smem_pipe_read.index());
  408. pipeline_vt.producer_commit(smem_pipe_write);
  409. pipeline_v.consumer_release(smem_pipe_read);
  410. ++smem_pipe_write;
  411. ++smem_pipe_read;
  412. if (warp_idx_in_warpgroup == 0 && lane_predicate) {
  413. pipeline_k.producer_acquire(smem_pipe_write);
  414. copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
  415. tKgK(_, n_block-1), tKsK(_, smem_pipe_write.index()));
  416. pipeline_v.producer_acquire(smem_pipe_write);
  417. copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
  418. tVgV(_, n_block-1), tVsV(_, smem_pipe_write.index()));
  419. }
  420. }
  // Steady state: the same as the prologue, but with pipeline_vt.producer_acquire before each transpose.
  421. #pragma unroll 2
  422. for (; n_block > 0; --n_block) {
  423. pipeline_v.consumer_wait(smem_pipe_read);
  424. pipeline_vt.producer_acquire(smem_pipe_write);
  425. do_transpose_V(smem_pipe_read.index());
  426. pipeline_vt.producer_commit(smem_pipe_write);
  427. pipeline_v.consumer_release(smem_pipe_read);
  428. ++smem_pipe_write;
  429. ++smem_pipe_read;
  430. if (warp_idx_in_warpgroup == 0 && lane_predicate) {
  431. pipeline_k.producer_acquire(smem_pipe_write);
  432. copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
  433. tKgK(_, n_block-1), tKsK(_, smem_pipe_write.index()));
  434. pipeline_v.producer_acquire(smem_pipe_write);
  435. copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
  436. tVgV(_, n_block-1), tVsV(_, smem_pipe_write.index()));
  437. }
  438. }
  // Only the causal branch calls the scheduler here; the matching calls at the end of the
  // non-causal branch are commented out.
  439. scheduler.prefetch_next_work(scheduler_params, work_tile_info);
  440. scheduler.broadcast_next_work(work_tile_info);
  // Last V block (n_block == 0). Acquire only when n_block_max > kStages, i.e. when this stage
  // lies beyond the first kStages of the tile.
  441. pipeline_v.consumer_wait(smem_pipe_read);
  442. if (n_block_max > kStages)
  443. pipeline_vt.producer_acquire(smem_pipe_write);
  444. do_transpose_V(smem_pipe_read.index());
  445. pipeline_vt.producer_commit(smem_pipe_write);
  446. pipeline_v.consumer_release(smem_pipe_read);
  447. ++smem_pipe_write;
  448. ++smem_pipe_read;
  449. } else {
  // Non-causal sequence:
  //   1. Issue Q and V for the last block, wait on barrier_O, and transpose that block.
  //   2. Run up to kStages - 1 more blocks without pipeline_vt.producer_acquire.
  //   3. Run the steady-state loop, which acquires before each transpose.
  // Each later iteration first issues K/V for n_block into the write stage, then transposes the
  // previously loaded stage (smem_pipe_read).
  450. if (warp_idx_in_warpgroup == 0 && lane_predicate) {
  451. shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
  452. copy(mainloop_params.tma_load_Q.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ);
  453. pipeline_v.producer_acquire(smem_pipe_write);
  454. copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
  455. tVgV(_, n_block), tVsV(_, smem_pipe_write.index()));
  456. }
  457. // With fp8 kernel, smem_o is in union with smem_v_out,
  458. // so could use NamedBarrier instead of ClusterBarrier.
  459. // But, this doesn't appear to have any benefit.
  460. shared_storage.barrier_O.wait((work_idx + 1) % 2);
  461. pipeline_v.consumer_wait(smem_pipe_read);
  462. // pipeline_vt.producer_acquire(smem_pipe_write);
  463. do_transpose_V(smem_pipe_read.index());
  464. pipeline_vt.producer_commit(smem_pipe_write);
  465. pipeline_v.consumer_release(smem_pipe_read);
  466. ++smem_pipe_write;
  467. ++smem_pipe_read;
  468. --n_block;
  469. constexpr int extra_iterations = kStages - 1;
  470. CUTLASS_PRAGMA_UNROLL
  471. for (int iter = 0; iter < extra_iterations && n_block >= 0; ++iter) {
  472. if (warp_idx_in_warpgroup == 0 && lane_predicate) {
  473. pipeline_k.producer_acquire(smem_pipe_write);
  474. copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
  475. tKgK(_, n_block), tKsK(_, smem_pipe_write.index()));
  476. pipeline_v.producer_acquire(smem_pipe_write);
  477. copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
  478. tVgV(_, n_block), tVsV(_, smem_pipe_write.index()));
  479. }
  480. pipeline_v.consumer_wait(smem_pipe_read);
  481. // pipeline_vt.producer_acquire(smem_pipe_write);
  482. do_transpose_V(smem_pipe_read.index());
  483. pipeline_vt.producer_commit(smem_pipe_write);
  484. pipeline_v.consumer_release(smem_pipe_read);
  485. ++smem_pipe_write;
  486. ++smem_pipe_read;
  487. --n_block;
  488. }
  489. // CUTLASS_PRAGMA_NO_UNROLL
  490. #pragma unroll 2
  491. for (; n_block >= 0; --n_block) {
  492. if (warp_idx_in_warpgroup == 0 && lane_predicate) {
  493. pipeline_k.producer_acquire(smem_pipe_write);
  494. copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
  495. tKgK(_, n_block), tKsK(_, smem_pipe_write.index()));
  496. pipeline_v.producer_acquire(smem_pipe_write);
  497. copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
  498. tVgV(_, n_block), tVsV(_, smem_pipe_write.index()));
  499. }
  500. pipeline_v.consumer_wait(smem_pipe_read);
  501. pipeline_vt.producer_acquire(smem_pipe_write);
  502. do_transpose_V(smem_pipe_read.index());
  503. pipeline_vt.producer_commit(smem_pipe_write);
  504. pipeline_v.consumer_release(smem_pipe_read);
  505. ++smem_pipe_write;
  506. ++smem_pipe_read;
  507. }
  508. // scheduler.prefetch_next_work(scheduler_params, work_tile_info);
  509. // scheduler.broadcast_next_work(work_tile_info);
  510. }
  511. }
  512. /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
  513. CUTLASS_DEVICE void
  514. load_tail(MainloopPipeline pipeline_k, MainloopPipeline pipeline_v,
  515. PipelineState& smem_pipe_write_k, PipelineState& smem_pipe_write_v) {
  516. int lane_predicate = cute::elect_one_sync();
  517. int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
  518. // Issue the epilogue waits
  519. if (warp_idx_in_warpgroup == 0 && lane_predicate) {
  520. /* This helps avoid early exit of blocks in Cluster
  521. * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used
  522. * then would just be acquired since the phase was still inverted from make_producer_start_state
  523. */
  524. pipeline_k.producer_tail(smem_pipe_write_k);
  525. pipeline_v.producer_tail(smem_pipe_write_v);
  526. }
  527. }
  528. /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
  529. CUTLASS_DEVICE void
  530. load_tail_one_write(MainloopPipeline pipeline_k, MainloopPipeline pipeline_v,
  531. PipelineState& smem_pipe_write) {
  532. int lane_predicate = cute::elect_one_sync();
  533. int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
  534. // Issue the epilogue waits
  535. if (warp_idx_in_warpgroup == 0 && lane_predicate) {
  536. /* This helps avoid early exit of blocks in Cluster
  537. * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used
  538. * then would just be acquired since the phase was still inverted from make_producer_start_state
  539. */
  540. pipeline_k.producer_tail(smem_pipe_write);
  541. pipeline_v.producer_tail(smem_pipe_write);
  542. }
  543. }
  544. CUTLASS_DEVICE void
  545. warp_scheduler_barrier_sync() {
  546. if constexpr (UseSchedulerBarrier) {
  547. cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + cutlass::canonical_warp_group_idx() /*id*/);
  548. }
  549. }
  550. CUTLASS_DEVICE void
  551. warp_scheduler_barrier_arrive() {
  552. if constexpr (!UseSchedulerBarrier) { return; }
  553. static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
  554. if constexpr (NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup) {
  555. cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (3 - cutlass::canonical_warp_group_idx()) /*id*/);
  556. } else {
  557. cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 2 ? cutlass::canonical_warp_group_idx() + 1 : cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/);
  558. cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 1 ? cutlass::canonical_warp_group_idx() + 2 : cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/);
  559. }
  560. }
  561. CUTLASS_DEVICE void
  562. mma_init() {
  563. // Tell producer (warp 0) that smem_q is ready
  564. cutlass::arch::NamedBarrier::arrive(NumMmaThreads + Ktraits::NumProducerThreads, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
  565. if constexpr (!UseSchedulerBarrier) { return; }
  566. static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
  567. if (cutlass::canonical_warp_group_idx() > 1) {
  568. cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 1 /*id*/);
  569. }
  570. if constexpr (NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup) {
  571. if (cutlass::canonical_warp_group_idx() > 2) {
  572. cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 2 /*id*/);
  573. }
  574. }
  575. }
  // Consumer main loop (non-FP8 variant) for one work tile.
  //
  // Walks KV blocks from n_block_count - 1 down to n_block_min. For each block it:
  //   - runs the first GEMM (tiled_mma0: Q fragments x current K stage -> tSrS);
  //   - masks out-of-range / causal / local positions to -INFINITY;
  //   - updates the online softmax;
  //   - accumulates the second GEMM (tiled_mma1: tOrP x current V stage) into tOrO.
  // The loops are software-pipelined: the S GEMM for one block is issued together
  // with the P*V GEMM that consumes the previous block's P (tOrP). The last P*V GEMM
  // runs after the loops.
  //
  // Parameters (as used below):
  //   pipeline_k / pipeline_v     - waited on before a K / V stage is read, released after use.
  //   smem_pipe_read_k / _v       - read positions, advanced once per consumed stage.
  //   tOrO                        - register-resident O accumulator (see static_assert).
  //   softmax                     - online-softmax state; finalize() yields the last O scale.
  //   n_block_count, n_block_min  - first block is n_block_count - 1; loops stop at n_block_min.
  //   work_idx                    - barrier_Q phase parity is work_idx % 2; when nonzero, the
  //                                 last warp also drains TMA stores and arrives on barrier_O.
  //   m_block                     - query block index used by the column limits.
  576. template <typename SharedStorage, typename FrgTensorO, typename Softmax>
  577. CUTLASS_DEVICE void
  578. mma(Params const& mainloop_params,
  579. MainloopPipeline pipeline_k,
  580. MainloopPipeline pipeline_v,
  581. PipelineState& smem_pipe_read_k,
  582. PipelineState& smem_pipe_read_v,
  583. FrgTensorO& tOrO,
  584. Softmax& softmax,
  585. int n_block_count,
  586. int n_block_min,
  587. int thread_idx,
  588. int work_idx,
  589. int m_block,
  590. SharedStorage& shared_storage,
  591. const Seqlen_traits& seqlen_traits_q,
  592. const Seqlen_traits& seqlen_traits_k
  593. ) {
  594. static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
  595. static constexpr int kBlockM = get<0>(TileShape_MNK{});
  596. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  597. Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
  598. Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
  599. Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVt{});
  600. typename Ktraits::TiledMma0 tiled_mma0;
  601. typename Ktraits::TiledMma1 tiled_mma1;
  602. auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx);
  603. auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx);
  604. // Allocate "fragments/descriptors" for first matmul.
  605. Tensor tSrQ = threadMma0.partition_fragment_A(sQ);
  606. Tensor tSrK = threadMma0.partition_fragment_B(sK);
  607. // Allocate "fragments/descriptors" for second matmul.
  608. // Note: S becomes P.
  609. Tensor tOrV = threadMma1.partition_fragment_B(sVt);
  // Wait until the producer has filled the stage at smem_pipe_read: try_wait first, then a
  // blocking wait that reuses the returned token.
  610. auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
  611. auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
  612. pipeline.consumer_wait(smem_pipe_read, barrier_token);
  613. };
  614. tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero;
  615. int const seqlen_q = seqlen_traits_q.actual_seq_len;
  616. int const seqlen_k = seqlen_traits_k.actual_seq_len;
  617. int n_block = n_block_count - 1;
  // Wait for this tile's Q on barrier_Q (phase parity work_idx % 2); the try_wait result
  // decides whether a blocking wait is still needed.
  618. cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.barrier_Q.try_wait(work_idx % 2));
  619. if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(work_idx % 2); }
  // First block (n_block_count - 1): issue S = Q x K asynchronously (wg_wait = -1), with the
  // issue bracketed by the warp-scheduler barrier.
  620. Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
  621. consumer_wait(pipeline_k, smem_pipe_read_k);
  622. warp_scheduler_barrier_sync();
  623. flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
  624. warp_scheduler_barrier_arrive();
  // For every tile except work_idx == 0, one elected lane of the last warp waits for all
  // outstanding TMA stores and then arrives on barrier_O for each CTA in the cluster.
  // NOTE(review): presumably this signals that the previous tile's O staging buffer can be
  // reused — confirm against the barrier_O waiter.
  625. if (work_idx != 0) {
  626. int lane_predicate = cute::elect_one_sync();
  627. if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) {
  628. tma_store_wait<0>();
  629. #pragma unroll
  630. for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
  631. shared_storage.barrier_O.arrive(cta_id, lane_predicate);
  632. }
  633. }
  634. }
  // The S GEMM must complete before tSrS is read; after that the K stage goes back to the producer.
  635. warpgroup_wait<0>();
  636. pipeline_k.consumer_release(smem_pipe_read_k);
  637. ++smem_pipe_read_k;
  // Exclusive upper column bound (local to block n_block) for a tile row. It is the smaller of:
  //   - the end of the key sequence;
  //   - the causal/right-window limit, i.e. key index <= query index
  //     + (seqlen_k - seqlen_q) + window_size_right.
  // NOTE(review): the Is_causal path relies on window_size_right encoding plain causal
  // masking (presumably 0) — confirm the caller sets it.
  638. auto col_limit_right = [&](int row, int n_block) {
  639. return std::min(
  640. seqlen_k - n_block * kBlockN,
  641. row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM + mainloop_params.window_size_right
  642. );
  643. };
  // Inclusive lower column bound (local to block n_block) for local attention: columns whose
  // key index is below query index + (seqlen_k - seqlen_q) - window_size_left are masked.
  // Clamped at 0.
  644. auto col_limit_left = [&](int row, int n_block) {
  645. return std::max(
  646. 0,
  647. row + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM - mainloop_params.window_size_left
  648. );
  649. };
  // Mask the first (highest-indexed) block. tScS maps each accumulator element to its
  // (row, col) coordinate within the tile.
  650. {
  651. Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
  652. Tensor tScS = threadMma0.partition_C(cS);
  653. #pragma unroll
  654. for (int i = 0; i < size(tSrS); ++i) {
  655. if constexpr (!Is_causal && !Is_local) { // Just masking based on col
  656. if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; }
  657. } else { // mask based on both row and col
  658. // using std::min is faster than doing col >= limit0 or col >= limit1
  659. // Need to cast get<1>(tScS(i)) to (signed) int since by default it's unsigned, and the
  660. // right hand side can be negative and might be converted to a very large unsigned integer.
  661. if (int(get<1>(tScS(i))) >= col_limit_right(int(get<0>(tScS(i))), n_block)) {
  662. tSrS(i) = -INFINITY;
  663. } else if constexpr (Is_local) {
  664. if (int(get<1>(tScS(i))) < col_limit_left(int(get<0>(tScS(i))), n_block)) {
  665. tSrS(i) = -INFINITY;
  666. }
  667. }
  668. }
  669. }
  670. }
  671. softmax.template online_softmax</*Is_first=*/true>(tSrS);
  // Convert the scores to Element and re-view them with the register layout produced by
  // convert_layout_acc_Aregs<TiledMma1>, so they can be the A operand of the P*V GEMM.
  672. Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout()));
  673. Tensor scores_scale = make_fragment_like(softmax.row_max);
  674. clear(scores_scale);
  675. constexpr int n_masking_steps = (!Is_causal) ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1;
  676. // Only go through these if Is_causal, since n_masking_steps = 1 when !Is_causal
  // Each iteration handles KV block n_block - 1 (n_block is decremented in the loop step):
  //   - issue its S GEMM;
  //   - rescale O by the previous step's scores_scale (skipped when masking_step == 0;
  //     presumably because tOrO holds no accumulated P*V yet);
  //   - issue P*V for the previous block's P;
  //   - apply right-limit masking with Check_inf = true.
  // NOTE(review): col_limit_left is not applied here, so this relies on Is_causal and
  // Is_local never both being true — confirm.
  677. #pragma unroll
  678. for (int masking_step = 0; masking_step < n_masking_steps - 1 && n_block > n_block_min; ++masking_step, --n_block) {
  679. Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
  680. consumer_wait(pipeline_k, smem_pipe_read_k);
  681. warp_scheduler_barrier_sync();
  682. flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
  683. if (masking_step > 0) { softmax.rescale_o(tOrO, scores_scale); }
  684. consumer_wait(pipeline_v, smem_pipe_read_v);
  685. flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
  686. warp_scheduler_barrier_arrive();
  // Wait until at most one GEMM batch (the P*V just issued) is in flight: the S GEMM is done.
  687. warpgroup_wait<1>();
  688. pipeline_k.consumer_release(smem_pipe_read_k); // release K
  689. Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
  690. Tensor tScS = threadMma0.partition_C(cS);
  691. #pragma unroll
  692. for (int i = 0; i < size(tSrS); ++i) {
  693. if (int(get<1>(tScS(i))) >= col_limit_right(int(get<0>(tScS(i))), n_block - 1)) {
  694. tSrS(i) = -INFINITY;
  695. }
  696. }
  697. cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/true>(tSrS), scores_scale);
  698. softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/true>(tSrS);
  699. warpgroup_wait<0>();
  700. pipeline_v.consumer_release(smem_pipe_read_v); // release V
  701. ++smem_pipe_read_k;
  702. ++smem_pipe_read_v;
  // The new P replaces the one just consumed by the P*V GEMM; it is used next iteration.
  703. cute::copy(make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout())), tOrP);
  704. }
  // Remaining blocks down to n_block_min. Same pipelined structure, but O is always rescaled.
  // Masking (both window limits) happens only for Is_local, and Check_inf = Is_local.
  705. #pragma unroll 1
  706. for (; n_block > n_block_min; --n_block) {
  707. Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
  708. consumer_wait(pipeline_k, smem_pipe_read_k);
  709. warp_scheduler_barrier_sync();
  710. flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
  711. softmax.rescale_o(tOrO, scores_scale);
  712. consumer_wait(pipeline_v, smem_pipe_read_v);
  713. flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
  714. warp_scheduler_barrier_arrive();
  715. warpgroup_wait<1>();
  716. pipeline_k.consumer_release(smem_pipe_read_k); // release K
  717. if constexpr(Is_local) {
  718. Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
  719. Tensor tScS = threadMma0.partition_C(cS);
  720. #pragma unroll
  721. for (int i = 0; i < size(tSrS); ++i) {
  722. if (
  723. int(get<1>(tScS(i))) >= col_limit_right(int(get<0>(tScS(i))), n_block - 1) ||
  724. int(get<1>(tScS(i))) < col_limit_left(int(get<0>(tScS(i))), n_block - 1)
  725. ) {
  726. tSrS(i) = -INFINITY;
  727. }
  728. }
  729. }
  730. // auto scores_scale = softmax.template max</*Is_first=*/false>(tSrS);
  731. cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/Is_local>(tSrS), scores_scale);
  732. softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/Is_local>(tSrS);
  733. warpgroup_wait<0>();
  734. pipeline_v.consumer_release(smem_pipe_read_v); // release V
  735. ++smem_pipe_read_k;
  736. ++smem_pipe_read_v;
  737. // softmax.rescale_o(tOrO, scores_scale);
  738. cute::copy(make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout())), tOrP);
  739. }
  // No S GEMM reads sQ (via tSrQ) after this point, so arrive on QueryEmpty.
  // NOTE(review): the count here is NumMmaThreads + NumThreadsPerWarp, while mma_init uses
  // NumMmaThreads + Ktraits::NumProducerThreads. Both must match the producer's sync on
  // QueryEmpty — confirm.
  740. // Tell warp 0 that smem_q is ready
  741. cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
  // Epilogue: rescale O, run the final P*V GEMM for the last P, then apply the scale
  // returned by softmax.finalize to tOrO.
  742. softmax.rescale_o(tOrO, scores_scale);
  743. consumer_wait(pipeline_v, smem_pipe_read_v);
  744. flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
  745. cute::copy(softmax.template finalize</*Check_inf=*/Is_causal || Is_local>(tSrS), scores_scale);
  746. warpgroup_wait<0>();
  747. pipeline_v.consumer_release(smem_pipe_read_v); // release V, otherwise producers will hang
  748. ++smem_pipe_read_v;
  749. softmax.rescale_o(tOrO, scores_scale);
  750. return;
  751. }
  // Consumer main loop for the FP8 path.
  //
  // The second GEMM reads V from smem_v_out through pipeline_vt. Presumably this is the
  // transposed V the producer publishes after do_transpose_V — confirm.
  //
  // KV blocks run from n_block_count - 1 down to 0. One read state (smem_pipe_read) indexes
  // both the K and V^T pipelines; it advances only after a block's V^T stage has been
  // consumed. Unlike mma(), each GEMM after the first S GEMM uses wg_wait = 0, so S and P*V
  // are not overlapped across blocks.
  //
  // Delay_V_release == true: a V^T stage is not released right after its P*V GEMM. Instead
  // smem_pipe_release, which trails smem_pipe_read, is released one step later, and the
  // final pending stage is released after the main loop.
  // NOTE(review): this assumes the caller starts smem_pipe_release at the same position as
  // smem_pipe_read — confirm.
  //
  // NOTE(review): there is no n_block_min parameter and no Is_local (window) masking here;
  // masking covers only the seqlen_k bound and Is_causal.
  752. template <bool Delay_V_release = false, typename SharedStorage, typename FrgTensorO, typename Softmax>
  753. CUTLASS_DEVICE void
  754. mma_fp8(Params const& mainloop_params,
  755. MainloopPipeline pipeline_k,
  756. MainloopPipelineNoTMA pipeline_vt,
  757. PipelineState& smem_pipe_read,
  758. PipelineState& smem_pipe_release,
  759. FrgTensorO& tOrO,
  760. Softmax& softmax,
  761. int n_block_count,
  762. int thread_idx,
  763. int work_idx,
  764. int m_block,
  765. SharedStorage& shared_storage,
  766. const Seqlen_traits& seqlen_traits_q,
  767. const Seqlen_traits& seqlen_traits_k
  768. ) {
  769. static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
  770. static constexpr int kBlockM = get<0>(TileShape_MNK{});
  771. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  772. Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
  773. Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
  774. Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v_out.data()), SmemLayoutVt{});
  775. typename Ktraits::TiledMma0 tiled_mma0;
  776. typename Ktraits::TiledMma1 tiled_mma1;
  777. auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx);
  778. auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx);
  779. // Allocate "fragments/descriptors" for first matmul.
  780. Tensor tSrQ = threadMma0.partition_fragment_A(sQ);
  781. Tensor tSrK = threadMma0.partition_fragment_B(sK);
  782. // Allocate "fragments/descriptors" for second matmul.
  783. Tensor tOrV = threadMma1.partition_fragment_B(sVt);
  // Wait until the producer has filled the stage at smem_pipe_read: try_wait first, then a
  // blocking wait that reuses the returned token.
  784. auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
  785. auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
  786. pipeline.consumer_wait(smem_pipe_read, barrier_token);
  787. };
  788. tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero;
  // Fixed-length inputs take the lengths from the Q/K layouts instead of the seqlen traits.
  789. // workaround for fp8 only perf regression pending change to seqlen traits class
  790. int const seqlen_q = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_q.actual_seq_len : shape<0>(mainloop_params.layout_Q);
  791. int const seqlen_k = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_k.actual_seq_len : shape<0>(mainloop_params.layout_K);
  792. int n_block = n_block_count - 1;
  // Wait for this tile's Q on barrier_Q (phase parity work_idx % 2).
  793. cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.barrier_Q.try_wait(work_idx % 2));
  794. if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(work_idx % 2); }
  // First block (n_block_count - 1): issue S = Q x K asynchronously under the scheduler barrier.
  795. Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
  796. consumer_wait(pipeline_k, smem_pipe_read);
  797. warp_scheduler_barrier_sync();
  798. flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
  // For every tile except work_idx == 0, one elected lane of the last warp drains outstanding
  // TMA stores and arrives on barrier_O for each CTA in the cluster.
  // NOTE(review): presumably this frees the previous tile's O staging buffer — confirm.
  799. if (work_idx != 0) {
  800. int lane_predicate = cute::elect_one_sync();
  801. if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) {
  802. tma_store_wait<0>();
  803. #pragma unroll
  804. for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
  805. shared_storage.barrier_O.arrive(cta_id, lane_predicate);
  806. }
  807. }
  808. }
  // Here the scheduler barrier is released only after the S GEMM has completed. K is released
  // now, but smem_pipe_read is not advanced until the matching V^T stage is consumed.
  809. warpgroup_wait<0>();
  810. warp_scheduler_barrier_arrive();
  811. pipeline_k.consumer_release(smem_pipe_read);
  // Exclusive causal column bound (local to block n_block): key index <= query index
  // + (seqlen_k - seqlen_q).
  812. auto col_limit_causal = [&](int row, int n_block) {
  813. return row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM;
  814. };
  // Mask the first block. Non-causal masks columns past seqlen_k; causal also applies the
  // causal bound.
  815. {
  816. Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
  817. Tensor tScS = threadMma0.partition_C(cS);
  818. #pragma unroll
  819. for (int i = 0; i < size(tSrS); ++i) {
  820. if constexpr (!Is_causal) { // Just masking based on col
  821. if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; }
  822. } else { // mask based on both row and col
  823. if (int(get<1>(tScS(i))) >= std::min(seqlen_k - n_block * kBlockN,
  824. col_limit_causal(int(get<0>(tScS(i))), n_block))) {
  825. tSrS(i) = -INFINITY;
  826. }
  827. }
  828. }
  829. }
  830. softmax.template online_softmax</*Is_first=*/true>(tSrS);
  // Convert the scores to Element in the FP8 A-operand register layout.
  // permute_regs_A_to_C then reorders the registers; presumably this matches the register
  // order the FP8 P*V GEMM expects — see that helper.
  831. Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
  832. permute_regs_A_to_C(tOrP);
  833. Tensor scores_scale = make_fragment_like(softmax.row_max);
  834. clear(scores_scale);
  // First P*V GEMM (zero-initialized, waited on immediately). With Delay_V_release the V^T
  // stage stays held and is released later through smem_pipe_release.
  835. consumer_wait(pipeline_vt, smem_pipe_read);
  836. flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
  837. if constexpr(!Delay_V_release) { pipeline_vt.consumer_release(smem_pipe_read); }
  838. ++smem_pipe_read;
  839. --n_block;
  // Number of peeled iterations after the first block:
  //   - causal: ceil_div(kBlockM, kBlockN) blocks that get causal masking;
  //   - non-causal: kStages - 1 blocks.
  840. constexpr int extra_iterations = !Is_causal ? kStages - 1 : cute::ceil_div(kBlockM, kBlockN);
  841. if constexpr(Is_causal) {
  // Causal peeled blocks: S GEMM under the scheduler barrier, causal masking with
  // Check_inf = true, rescale O, then P*V.
  842. CUTLASS_PRAGMA_UNROLL
  843. for (int iter = 0; iter < extra_iterations && n_block >= 0; ++iter, --n_block) {
  844. Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
  845. consumer_wait(pipeline_k, smem_pipe_read);
  846. warp_scheduler_barrier_sync();
  847. flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
  848. Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
  849. Tensor tScS = threadMma0.partition_C(cS);
  850. #pragma unroll
  851. for (int i = 0; i < size(tSrS); ++i) {
  852. if (int(get<1>(tScS(i))) >= col_limit_causal(int(get<0>(tScS(i))), n_block)) {
  853. tSrS(i) = -INFINITY;
  854. }
  855. }
  856. warp_scheduler_barrier_arrive();
  857. pipeline_k.consumer_release(smem_pipe_read);
  858. if constexpr(Delay_V_release) {
  859. pipeline_vt.consumer_release(smem_pipe_release);
  860. ++smem_pipe_release;
  861. }
  862. consumer_wait(pipeline_vt, smem_pipe_read);
  863. cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/true>(tSrS), scores_scale);
  864. softmax.rescale_o(tOrO, scores_scale);
  865. softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/true>(tSrS);
  866. Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
  867. permute_regs_A_to_C(tOrP);
  868. flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
  869. if constexpr(!Delay_V_release) { pipeline_vt.consumer_release(smem_pipe_read); }
  870. ++smem_pipe_read;
  871. }
  872. } else {
  // Non-causal peeled blocks (no masking).
  //   - With Delay_V_release: the lagging V^T stage is released before the S GEMM, the
  //     V^T wait moves before softmax, and the K release moves after it.
  //   - Without it: K is released right after the S GEMM, and V^T is waited on just
  //     before P*V.
  873. CUTLASS_PRAGMA_UNROLL
  874. for (int iter = 0; iter < extra_iterations && n_block >= 0; ++iter, --n_block) {
  875. Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
  876. consumer_wait(pipeline_k, smem_pipe_read);
  877. if constexpr(Delay_V_release) {
  878. pipeline_vt.consumer_release(smem_pipe_release);
  879. ++smem_pipe_release;
  880. }
  881. warp_scheduler_barrier_sync();
  882. flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
  883. warp_scheduler_barrier_arrive();
  884. if constexpr(!Delay_V_release) { pipeline_k.consumer_release(smem_pipe_read); }
  885. else { consumer_wait(pipeline_vt, smem_pipe_read); }
  886. cute::copy(softmax.template max</*Is_first=*/false>(tSrS), scores_scale);
  887. softmax.rescale_o(tOrO, scores_scale);
  888. softmax.template online_softmax</*Is_first=*/false>(tSrS);
  889. Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
  890. permute_regs_A_to_C(tOrP);
  891. if constexpr (Delay_V_release) { pipeline_k.consumer_release(smem_pipe_read); }
  892. else { consumer_wait(pipeline_vt, smem_pipe_read); }
  893. flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
  894. if constexpr(!Delay_V_release) { pipeline_vt.consumer_release(smem_pipe_read); }
  895. ++smem_pipe_read;
  896. }
  897. }
  898. if constexpr(Delay_V_release) {
  // Steady state with delayed V^T release.
  //   - Scheduler barrier: the sync that pairs with each iteration's arrive is taken
  //     before the loop (first iteration) or at the end of the previous iteration. The
  //     arrive after the loop balances the last sync.
  //   - V^T release: each iteration releases smem_pipe_release (one stage behind). The
  //     last pending stage is released after the loop.
  899. warp_scheduler_barrier_sync();
  900. CUTLASS_PRAGMA_NO_UNROLL
  901. for (; n_block >= 0; --n_block) {
  902. Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
  903. consumer_wait(pipeline_k, smem_pipe_read);
  904. flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
  905. warp_scheduler_barrier_arrive();
  906. pipeline_k.consumer_release(smem_pipe_read);
  907. pipeline_vt.consumer_release(smem_pipe_release);
  908. cute::copy(softmax.template max</*Is_first=*/false>(tSrS), scores_scale);
  909. softmax.rescale_o(tOrO, scores_scale);
  910. softmax.template online_softmax</*Is_first=*/false>(tSrS);
  911. Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
  912. permute_regs_A_to_C(tOrP);
  913. consumer_wait(pipeline_vt, smem_pipe_read);
  914. flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
  915. warp_scheduler_barrier_sync();
  916. ++smem_pipe_read;
  917. ++smem_pipe_release;
  918. }
  919. warp_scheduler_barrier_arrive();
  920. pipeline_vt.consumer_release(smem_pipe_release);
  921. ++smem_pipe_release;
  922. } else {
  // Steady state with immediate V^T release. Scheduler-barrier placement depends on kHeadDim:
  //   - 128: sync before the loop and before each P*V GEMM; arrive after each S GEMM and
  //     once after the loop.
  //   - 256: sync before each S GEMM, arrive after it.
  // NOTE(review): for other head dims the loop arrives with no matching sync. Presumably
  // UseSchedulerBarrier is false for those configurations (making both no-ops) — confirm
  // against Ktraits.
  923. if constexpr (kHeadDim == 128) { warp_scheduler_barrier_sync(); }
  924. CUTLASS_PRAGMA_NO_UNROLL
  925. for (; n_block >= 0; --n_block) {
  926. Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
  927. consumer_wait(pipeline_k, smem_pipe_read);
  928. if constexpr (kHeadDim == 256) { warp_scheduler_barrier_sync(); }
  929. flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
  930. warp_scheduler_barrier_arrive();
  931. pipeline_k.consumer_release(smem_pipe_read);
  932. cute::copy(softmax.template max</*Is_first=*/false>(tSrS), scores_scale);
  933. softmax.rescale_o(tOrO, scores_scale);
  934. softmax.template online_softmax</*Is_first=*/false>(tSrS);
  935. Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
  936. permute_regs_A_to_C(tOrP);
  937. consumer_wait(pipeline_vt, smem_pipe_read);
  938. if constexpr (kHeadDim == 128) { warp_scheduler_barrier_sync(); }
  939. flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
  940. pipeline_vt.consumer_release(smem_pipe_read);
  941. ++smem_pipe_read;
  942. }
  943. if constexpr (kHeadDim == 128) { warp_scheduler_barrier_arrive(); }
  944. }
  // All reads of sQ (via tSrQ) are finished, so arrive on QueryEmpty. The count adds a full
  // warpgroup (NumThreadsPerWarpGroup), whereas mma() adds one warp.
  // NOTE(review): this must match the FP8 producer's sync count on QueryEmpty — confirm.
  945. cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarpGroup, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
  // Final O scale from softmax.finalize, which also receives shared_storage.descale_v
  // (presumably folding the V dequantization factor into the scale — confirm).
  946. cute::copy(softmax.template finalize</*Check_inf=*/Is_causal>(tSrS, shared_storage.descale_v), scores_scale);
  947. softmax.rescale_o(tOrO, scores_scale);
  948. return;
  949. }
  950. };
  951. } // namespace flash