mainloop_fwd_sm90_tma_gmma_ws.hpp

/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include "cutlass/pipeline/pipeline.hpp"

#include "cute/tensor.hpp"

#include "cutlass/gemm/collective/collective_builder.hpp"

#include "named_barrier.hpp"
#include "utils.h"

namespace flash {

using namespace cute;
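
// Transposes one 64x64 tile of fp8 V in shared memory: LDSM.T loads into registers,
// a __byte_perm shuffle to reorder the packed fp8 values, then STSM stores into the
// output tile. Used by load_fp8() below, where the producer warpgroup (4 warps) builds
// the transposed copy of V that the second GEMM consumes.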
// 4 warps
struct SmemTransposeFp8_64x64 {

    using Element = cutlass::float_e4m3_t;

    using ldsm_thread_shape = Shape<_4, _1, _8, _4>;
    using ldsm_value_shape = Shape<_2, _8, _2, _1>;
    using ldsm_value_stride = Stride<_2, _4, _1, _0>;
    using TiledCopyLDSM = decltype(make_tiled_copy(
        Copy_Atom<SM75_U16x8_LDSM_T, Element>{}, Layout<ldsm_thread_shape>{},
        Layout<ldsm_value_shape, ldsm_value_stride>{}));
    TiledCopyLDSM tiled_copy_ldsm;

    using stsm_thread_shape = Shape<_4, _1, _8, _4>;
    // using stsm_thread_stride = Stride<_1, _0, _4, _32>;
#ifndef NO_FP8_COLUMN_PERMUTE
    using stsm_value_shape = Shape<_4, _4, _1, _2>;
    using stsm_value_stride = Stride<_1, _8, _0, _4>;
#else
    using stsm_value_shape = Shape<_4, _4, _2, _1>;
    using stsm_value_stride = Stride<_1, _8, _4, _0>;
#endif

    using TiledCopySTSM =
        decltype(make_tiled_copy(Copy_Atom<SM90_U32x4_STSM_N, Element>{},
                                 Layout<stsm_thread_shape>{},
                                 Layout<stsm_value_shape, stsm_value_stride>{}));
    TiledCopySTSM tiled_copy_stsm;

    template <class SmemTensor, class SmemTensorOut>
    CUTLASS_DEVICE void operator()(SmemTensor &&s_in, SmemTensorOut &&s_out) {
        using namespace cute;

        auto tid = threadIdx.x;
        auto thr_copy_ldsm = tiled_copy_ldsm.get_thread_slice(tid);
        auto thr_copy_stsm = tiled_copy_stsm.get_thread_slice(tid);

        auto tXsX = thr_copy_ldsm.partition_S(s_in);
        auto tXrX = make_tensor<Element>(shape(tXsX));
        auto tXsX_out = thr_copy_stsm.partition_D(s_out);

        cute::copy(tiled_copy_ldsm, tXsX, tXrX);

        auto data = tXrX.data();
        // size(tXrX) == 32
        CUTLASS_PRAGMA_UNROLL
        for (int n = 0; n < size(tXrX); n += 8) {
            uint32_t *data_32bit = reinterpret_cast<uint32_t *>(&data[n]);
            auto upper = data_32bit[0];
            auto lower = data_32bit[1];
            data_32bit[0] = __byte_perm(upper, lower, 0x6420);
            data_32bit[1] = __byte_perm(upper, lower, 0x7531);
        }

        cute::copy(tiled_copy_stsm, tXrX, tXsX_out);
    }
};

template <typename Ktraits, bool Is_causal, bool Is_local, typename Seqlen_traits, typename Seqlen_traits_Q = Seqlen_traits>
struct CollectiveMainloopFwd {

    using Element = typename Ktraits::Element;
    using TileShape_MNK = typename Ktraits::TileShape_MNK;
    using ClusterShape = typename Ktraits::ClusterShape_MNK;

    static constexpr int kStages = Ktraits::kStages;
    static constexpr int kHeadDim = Ktraits::kHeadDim;
    // static constexpr int kBlockM = Ktraits::kBlockM;
    // static constexpr int kBlockN = Ktraits::kBlockN;
    // static constexpr int kBlockH = Ktraits::kBlockH;
    static constexpr bool Is_split = Ktraits::Is_split;
    static constexpr bool No_smem_O = Ktraits::No_smem_O;

    using GmemTiledCopyQ = cute::SM90_TMA_LOAD;
    using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{})));

    using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
    using SmemLayoutQCopy = typename Ktraits::SmemLayoutQCopy;
    using TileShapeQCopy = typename Ktraits::TileShapeQCopy;
    using SmemLayoutK = typename Ktraits::SmemLayoutK;
    using SmemLayoutV = typename Ktraits::SmemLayoutV;
    using SmemLayoutVt = typename Ktraits::SmemLayoutVt;

    using TMA_Q = decltype(make_tma_copy(
        GmemTiledCopyQ{},
        make_tensor(
            make_gmem_ptr(static_cast<Element const*>(nullptr)),
            repeat_like(typename Seqlen_traits_Q::StrideT{}, int32_t(0)),
            typename Seqlen_traits_Q::StrideT{}
        ),
        SmemLayoutQCopy{},
        TileShapeQCopy{},
        _1{}));  // no mcast for Q

    using TMA_K = decltype(make_tma_copy(
        GmemTiledCopyKV{},
        make_tensor(
            make_gmem_ptr(static_cast<Element const*>(nullptr)),
            repeat_like(typename Seqlen_traits::StrideT{}, int32_t(0)),
            typename Seqlen_traits::StrideT{}
        ),
        take<0, 2>(SmemLayoutK{}),
        select<1, 2>(TileShape_MNK{}),
        size<0>(ClusterShape{})));  // mcast along M mode for this N load, if any

    // TMA_V may differ from TMA_K for fp8 kernel (e.g. swizzling mode)
    using TMA_V = decltype(make_tma_copy(
        GmemTiledCopyKV{},
        make_tensor(
            make_gmem_ptr(static_cast<Element const*>(nullptr)),
            repeat_like(typename Seqlen_traits::StrideT{}, int32_t(0)),
            typename Seqlen_traits::StrideT{}
        ),
        take<0, 2>(SmemLayoutV{}),
        select<1, 2>(TileShape_MNK{}),
        size<0>(ClusterShape{})));  // mcast along M mode for this N load, if any

    static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
    using MainloopPipeline = typename Ktraits::MainloopPipeline;
    using MainloopPipelineNoTMA = typename Ktraits::MainloopPipelineNoTMA;
    using PipelineParams = typename MainloopPipeline::Params;
    using PipelineState = typename MainloopPipeline::PipelineState;

    // Set the bytes transferred in this TMA transaction (may involve multiple issues)
    static constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<Element> / 8);
    static constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<Element> / 8);

    // static constexpr bool UseSchedulerBarrier = kHeadDim <= 128;
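    // Use the warp-scheduler named barriers (see warp_scheduler_barrier_sync/arrive below)
    // only when there are at least 12 warps, i.e. two or more MMA warpgroups: for fp8
    // with hdim >= 128, and for fp16/bf16 with hdim <= 128.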
    static constexpr bool UseSchedulerBarrier = Ktraits::kNWarps >= 12 &&
        (cutlass::sizeof_bits_v<Element> == 8 ? kHeadDim >= 128 : kHeadDim <= 128);

    // Host side kernel arguments
    struct Arguments {
        Element const* ptr_Q;
        typename Seqlen_traits_Q::LayoutT layout_Q;
        Element const* ptr_K;
        typename Seqlen_traits::LayoutT layout_K;
        Element const* ptr_V;
        typename Seqlen_traits::LayoutT layout_V;
        float const softmax_scale_log2;
        float const* descale_q_ptr;
        float const* descale_k_ptr;
        float const* descale_v_ptr;
        int window_size_left;
        int window_size_right;
        int const qhead_per_khead;
        int const* cache_batch_idx;
        int const num_splits;
    };

    // Device side kernel params
    struct Params {
        typename Seqlen_traits_Q::LayoutT layout_Q;
        typename Seqlen_traits::LayoutT layout_K;
        typename Seqlen_traits::LayoutT layout_V;
        cutlass::FastDivmod qhead_per_khead_divmod;
        TMA_Q tma_load_Q;
        TMA_K tma_load_K;
        TMA_V tma_load_V;
        float const softmax_scale_log2;
        float const* descale_q_ptr;
        float const* descale_k_ptr;
        float const* descale_v_ptr;
        int window_size_left;
        int window_size_right;
        int const* cache_batch_idx;
        cutlass::FastDivmod num_splits_divmod;
    };

    static Params
    to_underlying_arguments(Arguments const& args) {
        Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.layout_Q);
        TMA_Q tma_load_Q = make_tma_copy(
            GmemTiledCopyQ{},
            mQ,
            SmemLayoutQCopy{},
            TileShapeQCopy{},
            _1{});  // no mcast for Q
        Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.layout_K);
        TMA_K tma_load_K = make_tma_copy(
            GmemTiledCopyKV{},
            mK,
            SmemLayoutK{}(_, _, _0{}),
            select<1, 2>(TileShape_MNK{}),
            size<0>(ClusterShape{}));  // mcast along M mode for this N load, if any
        Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.layout_V);
        TMA_V tma_load_V = make_tma_copy(
            GmemTiledCopyKV{},
            mV,
            SmemLayoutV{}(_, _, _0{}),
            select<1, 2>(TileShape_MNK{}),
            size<0>(ClusterShape{}));  // mcast along M mode for this N load, if any
        return {args.layout_Q, args.layout_K, args.layout_V,
                cutlass::FastDivmod(args.qhead_per_khead),
                tma_load_Q, tma_load_K, tma_load_V,
                args.softmax_scale_log2,
                args.descale_q_ptr, args.descale_k_ptr, args.descale_v_ptr,
                args.window_size_left, args.window_size_right,
                args.cache_batch_idx,
                cutlass::FastDivmod(args.num_splits)};
    }

    /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
    CUTLASS_DEVICE
    static void prefetch_tma_descriptors(Params const& mainloop_params) {
        cute::prefetch_tma_descriptor(mainloop_params.tma_load_Q.get_tma_descriptor());
        cute::prefetch_tma_descriptor(mainloop_params.tma_load_K.get_tma_descriptor());
        cute::prefetch_tma_descriptor(mainloop_params.tma_load_V.get_tma_descriptor());
    }

    CUTLASS_DEVICE
    void get_n_block_min_max(
          Params const& mainloop_params,
          int m_block,
          int n_split_idx,
          const Seqlen_traits_Q& seqlen_traits_q,
          const Seqlen_traits& seqlen_traits_k,
          int& n_block_min,
          int& n_block_max
        ) {
        // static constexpr int kBlockM = get<0>(TileShape_MNK{});
        static constexpr int kBlockN = get<1>(TileShape_MNK{});
        static constexpr int kBlockM_div_H = get<0>(TileShape_MNK{}) / Ktraits::kBlockH;
        int const seqlen_q = seqlen_traits_q.actual_seq_len;
        int const seqlen_k = seqlen_traits_k.actual_seq_len;
        n_block_max = cute::ceil_div(seqlen_k, kBlockN);

        if constexpr (Is_split) {
            int const n_blocks_per_split
                = mainloop_params.num_splits_divmod.divide(n_block_max + int(mainloop_params.num_splits_divmod) - 1);
            n_block_min = n_split_idx * n_blocks_per_split;
            n_block_max = std::min(n_block_max, (n_split_idx + 1) * n_blocks_per_split);
        }

        if constexpr (Is_causal) {
            n_block_max = std::min(
                n_block_max,
                cute::ceil_div((m_block + 1) * kBlockM_div_H + seqlen_k - seqlen_q, kBlockN));
        } else if constexpr (Is_local) {
            n_block_max = std::min(
                n_block_max,
                cute::ceil_div((m_block + 1) * kBlockM_div_H + seqlen_k - seqlen_q + mainloop_params.window_size_right, kBlockN));
            n_block_min = std::max(
                n_block_min,
                (m_block * kBlockM_div_H + seqlen_k - seqlen_q - mainloop_params.window_size_left) / kBlockN);
        }
    }

    CUTLASS_DEVICE
    void get_n_block_max(
          Params const& mainloop_params,
          int m_block,
          const Seqlen_traits_Q& seqlen_traits_q,
          const Seqlen_traits& seqlen_traits_k,
          int& n_block_max
        ) {
        // static constexpr int kBlockM = get<0>(TileShape_MNK{});
        static constexpr int kBlockN = get<1>(TileShape_MNK{});
        static constexpr int kBlockM_div_H = get<0>(TileShape_MNK{}) / Ktraits::kBlockH;
        int const seqlen_q = seqlen_traits_q.actual_seq_len;
        int const seqlen_k = seqlen_traits_k.actual_seq_len;
        n_block_max = cute::ceil_div(seqlen_k, kBlockN);
        if constexpr (Is_causal) {
            n_block_max = std::min(n_block_max,
                cute::ceil_div((m_block + 1) * kBlockM_div_H + seqlen_k - seqlen_q, kBlockN));
        }
    }
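
    // Producer: a single elected lane issues the TMA loads. Q is loaded once per tile;
    // K and V tiles are streamed from n_block_max - 1 down to n_block_min through the
    // K/V smem pipelines, multicast across the cluster when GmemTiledCopyKV is a
    // multicast TMA atom.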
    template <typename Scheduler, typename SharedStorage>
    CUTLASS_DEVICE void
    load(Params const& mainloop_params,
         MainloopPipeline pipeline_k,
         MainloopPipeline pipeline_v,
         PipelineState& smem_pipe_write_k,
         PipelineState& smem_pipe_write_v,
         SharedStorage &shared_storage,
         Scheduler& scheduler,
         typename Scheduler::Params const& scheduler_params,
         typename Scheduler::WorkTileInfo& work_tile_info,
         cute::tuple<int32_t, int32_t, int32_t, int32_t> block_coord,
         int work_idx,
         const Seqlen_traits_Q& seqlen_traits_q,
         const Seqlen_traits& seqlen_traits_k,
         int n_block_min,
         int n_block_max
         ) {
        Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQCopy{});
        Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
        Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});

        Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape());
        Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape());
        Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape());

        auto [m_block, n_split_idx, bidh, bidb] = block_coord;
        const int bidb_cache = mainloop_params.cache_batch_idx == nullptr ? bidb : mainloop_params.cache_batch_idx[bidb];
        const int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);

        // Prepare the TMA loads
        uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
        constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
        uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};

        Tensor gQ = [&] {
            // Need this inside lambda to capture structured binding
            auto [m_block, n_split_idx, bidh, bidb] = block_coord;
            if constexpr (Seqlen_traits_Q::UseGQAPacking) {
                return seqlen_traits_q.get_local_tile_tensor(
                    mQ, TileShapeQCopy{}, bidh_kv, bidb)
                    (_, _, _, m_block, bidh % int(mainloop_params.qhead_per_khead_divmod));  // (M/H, H, K)
            } else {
                return seqlen_traits_q.get_local_tile_tensor(
                    mQ, TileShapeQCopy{}, bidh, bidb)(_, _, m_block);  // (M, K)
            }
        }();
        Tensor gK = seqlen_traits_k.get_local_tile_tensor(
            mK, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb_cache);  // (N, K, _)
        Tensor gV = seqlen_traits_k.get_local_tile_tensor(
            mV, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb_cache);  // (N, K, _)

        Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
        Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
        auto [tQgQ, tQsQ] = tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{},
                                          group_modes<0, 2>(sQ_x), group_modes<0, 2>(gQ_x));  // (TMA), (TMA)
        auto [tKgK, tKsK] = tma_partition(mainloop_params.tma_load_K, block_rank_in_cluster, Layout<ClusterShape>{},
                                          group_modes<0, 2>(sK), group_modes<0, 2>(gK));  // (TMA, k), (TMA, PIPE)
        auto [tVgV, tVsV] = tma_partition(mainloop_params.tma_load_V, block_rank_in_cluster, Layout<ClusterShape>{},
                                          group_modes<0, 2>(sV), group_modes<0, 2>(gV));  // (TMA, k), (TMA, PIPE)

        uint16_t mcast_mask_kv = 0;
        if constexpr (cute::is_same_v<GmemTiledCopyKV, SM90_TMA_LOAD_MULTICAST>) {
            auto block_layout = Layout<ClusterShape>{};  // (m,n) -> block_id
            for (int m = 0; m < size<0>(block_layout); ++m) {
                mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{}));
            }
        }

        int n_block = n_block_max - 1;

        int lane_predicate = cute::elect_one_sync();
        if (lane_predicate) {
            pipeline_k.producer_acquire(smem_pipe_write_k);
            copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv),
                 tKgK(_, n_block), tKsK(_, smem_pipe_write_k.index()));
            ++smem_pipe_write_k;
        }

        // Wait for the MMA warpgroups to say that smem_q is ready
        cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);

        if (lane_predicate) {
            shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
            copy(mainloop_params.tma_load_Q.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ);
        }
        // Wait for warp 1 to signal that smem_v is ready and V can be copied from gmem.
        // Need a ClusterBarrier, not just a NamedBarrier. Otherwise CTA 0 might finish its
        // TMA store on O and issue the TMA multicast load of V before CTA 1 has finished its TMA store on O.
        if constexpr (!No_smem_O) { shared_storage.barrier_O.wait((work_idx + 1) % 2); }

        if (lane_predicate) {
            // CUTLASS_PRAGMA_NO_UNROLL
            #pragma unroll 2
            for (; n_block > n_block_min; --n_block) {
                pipeline_k.producer_acquire(smem_pipe_write_k);
                copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv),
                     tKgK(_, n_block - 1), tKsK(_, smem_pipe_write_k.index()));
                ++smem_pipe_write_k;
                pipeline_v.producer_acquire(smem_pipe_write_v);
                copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv),
                     tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
                ++smem_pipe_write_v;
            }
        }
        scheduler.prefetch_next_work(scheduler_params, work_tile_info);
        if (lane_predicate) {
            pipeline_v.producer_acquire(smem_pipe_write_v);
            copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv),
                 tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
            ++smem_pipe_write_v;
        }
        scheduler.broadcast_next_work(work_tile_info);
    }
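
    // fp8 producer: same TMA streaming of Q/K/V as load(), but the producer warpgroup
    // additionally transposes each V stage in smem (SmemTransposeFp8_64x64) into
    // smem_v_out and publishes it through pipeline_vt for the second GEMM.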
    template <typename Scheduler, typename SharedStorage>
    CUTLASS_DEVICE void
    load_fp8(Params const& mainloop_params,
         MainloopPipeline pipeline_k,
         MainloopPipeline pipeline_v,
         MainloopPipelineNoTMA pipeline_vt,
         PipelineState& smem_pipe_write,
         PipelineState& smem_pipe_read,
         SharedStorage &shared_storage,
         Scheduler& scheduler,
         typename Scheduler::Params const& scheduler_params,
         typename Scheduler::WorkTileInfo& work_tile_info,
         cute::tuple<int32_t, int32_t, int32_t, int32_t> block_coord,
         int work_idx,
         const Seqlen_traits_Q& seqlen_traits_q,
         const Seqlen_traits& seqlen_traits_k,
         int n_block_min,
         int n_block_max
         ) {
        using SmemLayoutTransposeV = typename Ktraits::SmemLayoutTransposeV;
        using SmemLayoutTransposeVt = typename Ktraits::SmemLayoutTransposeVt;

        Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQCopy{});
        Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
        Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});

        Tensor sV_divide = as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutTransposeV{}));
        Tensor sVt_divide = as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.smem_v_out.data()), SmemLayoutTransposeVt{}));

        auto smem_transpose_V = SmemTransposeFp8_64x64();
        auto do_transpose_V = [&](int stage) {
            CUTLASS_PRAGMA_UNROLL
            for (int j = 0; j < shape<2>(SmemLayoutTransposeV{}); ++j) {
                CUTLASS_PRAGMA_UNROLL
                for (int i = 0; i < shape<1>(SmemLayoutTransposeV{}); ++i) {
                    smem_transpose_V(flatten(sV_divide(_, i, j, stage)),
                                     flatten(sVt_divide(_, i, j, stage)));
                }
            }
            cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, static_cast<int>(FwdNamedBarriers::ProducerWG) /*id*/);
        };

        Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape());
        Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape());
        Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape());

        auto [m_block, split_idx, bidh, bidb] = block_coord;
        const int bidb_cache = mainloop_params.cache_batch_idx == nullptr ? bidb : mainloop_params.cache_batch_idx[bidb];
        const int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);

        // Prepare the TMA loads
        uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
        constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
        uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};

        Tensor gQ = [&] {
            // Need this inside lambda to capture structured binding
            auto [m_block, n_split_idx, bidh, bidb] = block_coord;
            if constexpr (Seqlen_traits_Q::UseGQAPacking) {
                return seqlen_traits_q.get_local_tile_tensor(
                    mQ, TileShapeQCopy{}, bidh_kv, bidb)
                    (_, _, _, m_block, bidh % int(mainloop_params.qhead_per_khead_divmod));  // (M/H, H, K)
            } else {
                return seqlen_traits_q.get_local_tile_tensor(
                    mQ, TileShapeQCopy{}, bidh, bidb)(_, _, m_block);  // (M, K)
            }
        }();
        Tensor gK = seqlen_traits_k.get_local_tile_tensor(
            mK, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb_cache);  // (N, K, _)
        Tensor gV = seqlen_traits_k.get_local_tile_tensor(
            mV, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb_cache);  // (N, K, _)

        Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
        Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
        auto [tQgQ, tQsQ] = tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{},
                                          group_modes<0, 2>(sQ_x), group_modes<0, 2>(gQ_x));  // (TMA), (TMA)
        auto [tKgK, tKsK] = tma_partition(mainloop_params.tma_load_K, block_rank_in_cluster, Layout<ClusterShape>{},
                                          group_modes<0, 2>(sK), group_modes<0, 2>(gK));  // (TMA, k), (TMA, PIPE)
        auto [tVgV, tVsV] = tma_partition(mainloop_params.tma_load_V, block_rank_in_cluster, Layout<ClusterShape>{},
                                          group_modes<0, 2>(sV), group_modes<0, 2>(gV));  // (TMA, k), (TMA, PIPE)

        uint16_t mcast_mask_kv = 0;
        if constexpr (cute::is_same_v<GmemTiledCopyKV, SM90_TMA_LOAD_MULTICAST>) {
            auto block_layout = Layout<ClusterShape>{};  // (m,n) -> block_id
            for (int m = 0; m < size<0>(block_layout); ++m) {
                mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{}));
            }
        }

        int n_block = n_block_max - 1;

        int lane_predicate = cute::elect_one_sync();
        int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
        if (warp_idx_in_warpgroup == 0 && lane_predicate) {
            pipeline_k.producer_acquire(smem_pipe_write);
            copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
                 tKgK(_, n_block), tKsK(_, smem_pipe_write.index()));
        }

        // Wait for the MMA warpgroups to say that smem_q is ready
        // for fp8, change from NumThreadsPerWarp to NumThreadsPerWarpGroup
        cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarpGroup, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);

        if (warp_idx_in_warpgroup == 0 && lane_predicate) {
            shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
            copy(mainloop_params.tma_load_Q.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ);
            if constexpr (!Ktraits::VO_union_all) {
                pipeline_v.producer_acquire(smem_pipe_write);
                copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
                     tVgV(_, n_block), tVsV(_, smem_pipe_write.index()));
            }
        }
        // In the fp8 kernel, smem_o shares a union with smem_v_out (except for the split kernel
        // with hdim 256), so a NamedBarrier could be used here instead of a ClusterBarrier.
        // However, that doesn't appear to have any benefit.
        if constexpr (!No_smem_O) { shared_storage.barrier_O.wait((work_idx + 1) % 2); }

        if constexpr (Ktraits::VO_union_all) {
            if (warp_idx_in_warpgroup == 0 && lane_predicate) {
                pipeline_v.producer_acquire(smem_pipe_write);
                copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
                     tVgV(_, n_block), tVsV(_, smem_pipe_write.index()));
            }
        }

        #pragma unroll 2
        for (; n_block > n_block_min; --n_block) {
            pipeline_v.consumer_wait(smem_pipe_read);
            pipeline_vt.producer_acquire(smem_pipe_write);
            do_transpose_V(smem_pipe_read.index());
            pipeline_vt.producer_commit(smem_pipe_write);
            pipeline_v.consumer_release(smem_pipe_read);

            ++smem_pipe_write;
            ++smem_pipe_read;

            if (warp_idx_in_warpgroup == 0 && lane_predicate) {
                pipeline_k.producer_acquire(smem_pipe_write);
                copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
                     tKgK(_, n_block - 1), tKsK(_, smem_pipe_write.index()));
                pipeline_v.producer_acquire(smem_pipe_write);
                copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
                     tVgV(_, n_block - 1), tVsV(_, smem_pipe_write.index()));
            }
        }

        scheduler.prefetch_next_work(scheduler_params, work_tile_info);
        scheduler.broadcast_next_work(work_tile_info);

        pipeline_v.consumer_wait(smem_pipe_read);
        pipeline_vt.producer_acquire(smem_pipe_write);
        do_transpose_V(smem_pipe_read.index());
        pipeline_vt.producer_commit(smem_pipe_write);
        pipeline_v.consumer_release(smem_pipe_read);

        ++smem_pipe_write;
        ++smem_pipe_read;
    }

    /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
    CUTLASS_DEVICE void
    load_tail(MainloopPipeline pipeline_k, MainloopPipeline pipeline_v,
              PipelineState& smem_pipe_write_k, PipelineState& smem_pipe_write_v) {
        int lane_predicate = cute::elect_one_sync();
        int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
        // Issue the epilogue waits
        if (warp_idx_in_warpgroup == 0 && lane_predicate) {
            /* This helps avoid early exit of blocks in Cluster
             * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used
             * then would just be acquired since the phase was still inverted from make_producer_start_state
             */
            pipeline_k.producer_tail(smem_pipe_write_k);
            pipeline_v.producer_tail(smem_pipe_write_v);
        }
    }

    /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
    CUTLASS_DEVICE void
    load_tail_one_write(MainloopPipeline pipeline_k, MainloopPipeline pipeline_v,
                        PipelineState& smem_pipe_write) {
        int lane_predicate = cute::elect_one_sync();
        int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
        // Issue the epilogue waits
        if (warp_idx_in_warpgroup == 0 && lane_predicate) {
            /* This helps avoid early exit of blocks in Cluster
             * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used
             * then would just be acquired since the phase was still inverted from make_producer_start_state
             */
            pipeline_k.producer_tail(smem_pipe_write);
            pipeline_v.producer_tail(smem_pipe_write);
        }
    }
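
    // The WarpSchedulerWG* named barriers stagger the MMA warpgroups: each warpgroup
    // syncs on its own barrier before issuing a GEMM and then arrives on the other
    // warpgroup(s)' barriers, so the warpgroups issue their GEMMs in round-robin order.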
    CUTLASS_DEVICE void
    warp_scheduler_barrier_sync() {
        if constexpr (UseSchedulerBarrier) {
            cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + cutlass::canonical_warp_group_idx() /*id*/);
        }
    }

    CUTLASS_DEVICE void
    warp_scheduler_barrier_arrive() {
        if constexpr (!UseSchedulerBarrier) {
            return;
        } else {
            static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
            if constexpr (NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup) {
                cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (3 - cutlass::canonical_warp_group_idx()) /*id*/);
            } else {
                cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 2 ? cutlass::canonical_warp_group_idx() + 1 : cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/);
                cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 1 ? cutlass::canonical_warp_group_idx() + 2 : cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/);
            }
        }
    }

    CUTLASS_DEVICE void
    mma_init() {
        // Tell producer (warp 0) that smem_q is ready
        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + Ktraits::NumProducerThreads, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
        if constexpr (!UseSchedulerBarrier) {
            return;
        } else {
            static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
            if (cutlass::canonical_warp_group_idx() > 1) {
                cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 1 /*id*/);
            }
            if constexpr (NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup) {
                if (cutlass::canonical_warp_group_idx() > 2) {
                    cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 2 /*id*/);
                }
            }
        }
    }
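
    // Consumer (fp16/bf16): for each K/V stage, compute S = Q @ K^T with TiledMma0,
    // apply causal/local masking and the online softmax, then accumulate O += P @ V
    // with TiledMma1, overlapping the two GEMMs across pipeline stages.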
    template <typename SharedStorage, typename FrgTensorO, typename Softmax>
    CUTLASS_DEVICE void
    mma(Params const& mainloop_params,
        MainloopPipeline pipeline_k,
        MainloopPipeline pipeline_v,
        PipelineState& smem_pipe_read_k,
        PipelineState& smem_pipe_read_v,
        FrgTensorO& tOrO,
        Softmax& softmax,
        int n_block_min,
        int n_block_max,
        int thread_idx,
        int work_idx,
        int m_block,
        SharedStorage& shared_storage,
        const Seqlen_traits_Q& seqlen_traits_q,
        const Seqlen_traits& seqlen_traits_k
        ) {
        static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");

        static constexpr int kBlockN = get<1>(TileShape_MNK{});
        static constexpr int kBlockH = Ktraits::kBlockH;
        static constexpr int kBlockM_div_H = get<0>(TileShape_MNK{}) / kBlockH;

        Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
        Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
        Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVt{});

        typename Ktraits::TiledMma0 tiled_mma0;
        typename Ktraits::TiledMma1 tiled_mma1;
        auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx);
        auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx);

        // Allocate "fragments/descriptors" for first matmul.
        Tensor tSrQ = threadMma0.partition_fragment_A(sQ);
        Tensor tSrK = threadMma0.partition_fragment_B(sK);
        // Allocate "fragments/descriptors" for second matmul.
        // Note: S becomes P.
        Tensor tOrV = threadMma1.partition_fragment_B(sVt);

        auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
            auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
            pipeline.consumer_wait(smem_pipe_read, barrier_token);
        };

        tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero;

        int const seqlen_q = seqlen_traits_q.actual_seq_len;
        int const seqlen_k = seqlen_traits_k.actual_seq_len;
        int n_block = n_block_max - 1;

        cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.barrier_Q.try_wait(work_idx % 2));
        if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(work_idx % 2); }

        Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
        consumer_wait(pipeline_k, smem_pipe_read_k);

        warp_scheduler_barrier_sync();
        flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
        warp_scheduler_barrier_arrive();

        if constexpr (!No_smem_O) {
            if (work_idx != 0) {
                int lane_predicate = cute::elect_one_sync();
                if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) {
                    tma_store_wait<0>();
                    #pragma unroll
                    for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
                        shared_storage.barrier_O.arrive(cta_id, lane_predicate);
                    }
                }
            }
        }

        warpgroup_wait<0>();
        pipeline_k.consumer_release(smem_pipe_read_k);
        ++smem_pipe_read_k;

        auto col_limit_right = [&](int row, int n_block) {
            int col_limit_base = row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM_div_H;
            if constexpr (Is_local)
                return col_limit_base + mainloop_params.window_size_right;
            else
                return col_limit_base;
        };
        auto col_limit_left = [&](int row, int n_block) {
            return std::max(
                0,
                row + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM_div_H - mainloop_params.window_size_left
            );
        };
        {
            Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
            Tensor tScS = threadMma0.partition_C(cS);
            #pragma unroll
            for (int i = 0; i < size(tSrS); ++i) {
                if constexpr (!Is_causal && !Is_local) {  // Just masking based on col
                    if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; }
                } else {  // mask based on both row and col
                    // using std::min is faster than doing col >= limit0 or col >= limit1
                    // Need to cast get<1>(tScS(i)) to (signed) int since by default it's unsigned, and the
                    // right hand side can be negative and might be converted to a very large unsigned integer.
                    int row = int(get<0>(tScS(i))) / kBlockH;
                    if (int(get<1>(tScS(i))) >= std::min(seqlen_k - n_block * kBlockN, col_limit_right(row, n_block))) {
                        tSrS(i) = -INFINITY;
                    } else if constexpr (Is_local) {
                        if (int(get<1>(tScS(i))) < col_limit_left(row, n_block)) {
                            tSrS(i) = -INFINITY;
                        }
                    }
                }
            }
        }

        softmax.template online_softmax</*Is_first=*/true>(tSrS);

        Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout()));
        Tensor scores_scale = make_fragment_like(softmax.row_max);
        clear(scores_scale);

        constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM_div_H, kBlockN) + 1;
        // Only go through these if Is_causal, since n_masking_steps = 1 when !Is_causal
        #pragma unroll
        for (int masking_step = 0; masking_step < n_masking_steps - 1 && n_block > n_block_min; ++masking_step, --n_block) {
            Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
            consumer_wait(pipeline_k, smem_pipe_read_k);
            warp_scheduler_barrier_sync();
            flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
            if (masking_step > 0) { softmax.rescale_o(tOrO, scores_scale); }
            consumer_wait(pipeline_v, smem_pipe_read_v);
            flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
            warp_scheduler_barrier_arrive();
            warpgroup_wait<1>();
            pipeline_k.consumer_release(smem_pipe_read_k);  // release K

            Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
            Tensor tScS = threadMma0.partition_C(cS);
            #pragma unroll
            for (int i = 0; i < size(tSrS); ++i) {
                int row = int(get<0>(tScS(i))) / kBlockH;
                if (int(get<1>(tScS(i))) >= col_limit_right(row, n_block - 1)) {
                    tSrS(i) = -INFINITY;
                }
            }

            cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/true>(tSrS), scores_scale);
            softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/true>(tSrS);

            warpgroup_wait<0>();
            pipeline_v.consumer_release(smem_pipe_read_v);  // release V
            ++smem_pipe_read_k;
            ++smem_pipe_read_v;

            cute::copy(make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout())), tOrP);
        }
        #pragma unroll 1
        for (; n_block > n_block_min; --n_block) {
            Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
            consumer_wait(pipeline_k, smem_pipe_read_k);
            warp_scheduler_barrier_sync();
            flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
            softmax.rescale_o(tOrO, scores_scale);
            consumer_wait(pipeline_v, smem_pipe_read_v);
            flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
            warp_scheduler_barrier_arrive();
            warpgroup_wait<1>();
            pipeline_k.consumer_release(smem_pipe_read_k);  // release K

            if constexpr (Is_local) {
                Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
                Tensor tScS = threadMma0.partition_C(cS);
                #pragma unroll
                for (int i = 0; i < size(tSrS); ++i) {
                    int row = int(get<0>(tScS(i))) / kBlockH;
                    if (
                        int(get<1>(tScS(i))) >= col_limit_right(row, n_block - 1) ||
                        int(get<1>(tScS(i))) < col_limit_left(row, n_block - 1)
                    ) {
                        tSrS(i) = -INFINITY;
                    }
                }
            }

            // auto scores_scale = softmax.template max</*Is_first=*/false>(tSrS);
            cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/Is_local>(tSrS), scores_scale);
            softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/Is_local>(tSrS);

            warpgroup_wait<0>();
            pipeline_v.consumer_release(smem_pipe_read_v);  // release V
            ++smem_pipe_read_k;
            ++smem_pipe_read_v;

            // softmax.rescale_o(tOrO, scores_scale);
            cute::copy(make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout())), tOrP);
        }

        // Tell warp 0 that smem_q is ready
        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
        softmax.rescale_o(tOrO, scores_scale);
        consumer_wait(pipeline_v, smem_pipe_read_v);
        flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
        cute::copy(softmax.template finalize</*Is_dropout=*/false, Is_split>(tSrS), scores_scale);
        warpgroup_wait<0>();
        pipeline_v.consumer_release(smem_pipe_read_v);  // release V, otherwise producers will hang
        ++smem_pipe_read_v;
        softmax.rescale_o(tOrO, scores_scale);
        return;
    }
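
    // fp8 consumer: same structure as mma(), but the second GEMM reads the transposed V
    // produced by load_fp8() through pipeline_vt, and the P fragment is re-permuted
    // (permute_regs_A_to_C) to match the fp8 MMA operand layout. Delay_V_release defers
    // releasing each V^T stage until the following iteration.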
    template <bool Delay_V_release = false, typename SharedStorage, typename FrgTensorO, typename Softmax>
    CUTLASS_DEVICE void
    mma_fp8(Params const& mainloop_params,
        MainloopPipeline pipeline_k,
        MainloopPipelineNoTMA pipeline_vt,
        PipelineState& smem_pipe_read,
        PipelineState& smem_pipe_release,
        FrgTensorO& tOrO,
        Softmax& softmax,
        int n_block_min,
        int n_block_max,
        int thread_idx,
        int work_idx,
        int m_block,
        SharedStorage& shared_storage,
        const Seqlen_traits_Q& seqlen_traits_q,
        const Seqlen_traits& seqlen_traits_k
        ) {
        static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");

        // static constexpr int kBlockM = get<0>(TileShape_MNK{});
        static constexpr int kBlockN = get<1>(TileShape_MNK{});
        static constexpr int kBlockH = Ktraits::kBlockH;
        static constexpr int kBlockM_div_H = get<0>(TileShape_MNK{}) / kBlockH;

        Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
        Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
        Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v_out.data()), SmemLayoutVt{});

        typename Ktraits::TiledMma0 tiled_mma0;
        typename Ktraits::TiledMma1 tiled_mma1;
        auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx);
        auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx);

        // Allocate "fragments/descriptors" for first matmul.
        Tensor tSrQ = threadMma0.partition_fragment_A(sQ);
        Tensor tSrK = threadMma0.partition_fragment_B(sK);
        // Allocate "fragments/descriptors" for second matmul.
        Tensor tOrV = threadMma1.partition_fragment_B(sVt);

        auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
            auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
            pipeline.consumer_wait(smem_pipe_read, barrier_token);
        };

        tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero;

        int const seqlen_q = seqlen_traits_q.actual_seq_len;
        int const seqlen_k = seqlen_traits_k.actual_seq_len;
        int n_block = n_block_max - 1;

        cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.barrier_Q.try_wait(work_idx % 2));
        if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(work_idx % 2); }

        Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
        consumer_wait(pipeline_k, smem_pipe_read);

        warp_scheduler_barrier_sync();
        flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);

        if constexpr (!No_smem_O) {
            if (work_idx != 0) {
                int lane_predicate = cute::elect_one_sync();
                if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) {
                    tma_store_wait<0>();
                    #pragma unroll
                    for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
                        shared_storage.barrier_O.arrive(cta_id, lane_predicate);
                    }
                }
            }
        }

        warpgroup_wait<0>();
        warp_scheduler_barrier_arrive();
        pipeline_k.consumer_release(smem_pipe_read);

        auto col_limit_right = [&](int row, int n_block) {
            int col_limit_base = row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM_div_H;
            if constexpr (Is_local)
                return col_limit_base + mainloop_params.window_size_right;
            else
                return col_limit_base;
        };
        auto col_limit_left = [&](int row, int n_block) {
            return std::max(
                0,
                row + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM_div_H - mainloop_params.window_size_left
            );
        };
        {
            Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
            Tensor tScS = threadMma0.partition_C(cS);
            #pragma unroll
            for (int i = 0; i < size(tSrS); ++i) {
                if constexpr (!Is_causal && !Is_local) {  // Just masking based on col
                    if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; }
                } else {  // mask based on both row and col
                    int row = int(get<0>(tScS(i))) / kBlockH;
                    if (int(get<1>(tScS(i))) >= std::min(seqlen_k - n_block * kBlockN, col_limit_right(row, n_block))) {
                        tSrS(i) = -INFINITY;
                    } else if constexpr (Is_local) {
                        if (int(get<1>(tScS(i))) < col_limit_left(row, n_block)) {
                            tSrS(i) = -INFINITY;
                        }
                    }
                }
            }
        }

        softmax.template online_softmax</*Is_first=*/true>(tSrS);
        Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
        permute_regs_A_to_C(tOrP);

        Tensor scores_scale = make_fragment_like(softmax.row_max);
        clear(scores_scale);

        consumer_wait(pipeline_vt, smem_pipe_read);
        flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
        if constexpr (!Delay_V_release) { pipeline_vt.consumer_release(smem_pipe_read); }

        ++smem_pipe_read;
        --n_block;
        constexpr int extra_iterations = !Is_causal ? kStages - 1 : cute::ceil_div(kBlockM_div_H, kBlockN);

        if constexpr (Is_causal) {
            CUTLASS_PRAGMA_UNROLL
            for (int iter = 0; iter < extra_iterations && n_block >= n_block_min; ++iter, --n_block) {
                Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
                consumer_wait(pipeline_k, smem_pipe_read);
                warp_scheduler_barrier_sync();
                flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);

                Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
                Tensor tScS = threadMma0.partition_C(cS);
                #pragma unroll
                for (int i = 0; i < size(tSrS); ++i) {
                    int row = int(get<0>(tScS(i))) / kBlockH;
                    if (int(get<1>(tScS(i))) >= col_limit_right(row, n_block)) {
                        tSrS(i) = -INFINITY;
                    }
                }

                warp_scheduler_barrier_arrive();
                pipeline_k.consumer_release(smem_pipe_read);
                if constexpr (Delay_V_release) {
                    pipeline_vt.consumer_release(smem_pipe_release);
                    ++smem_pipe_release;
                }
                consumer_wait(pipeline_vt, smem_pipe_read);

                cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/true>(tSrS), scores_scale);
                softmax.rescale_o(tOrO, scores_scale);
                softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/true>(tSrS);
                Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
                permute_regs_A_to_C(tOrP);

                flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
                if constexpr (!Delay_V_release) { pipeline_vt.consumer_release(smem_pipe_read); }
                ++smem_pipe_read;
            }
        } else if constexpr (!Is_local) {
            CUTLASS_PRAGMA_UNROLL
            for (int iter = 0; iter < extra_iterations && n_block >= n_block_min; ++iter, --n_block) {
                Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
                consumer_wait(pipeline_k, smem_pipe_read);
                if constexpr (Delay_V_release) {
                    pipeline_vt.consumer_release(smem_pipe_release);
                    ++smem_pipe_release;
                }
                warp_scheduler_barrier_sync();
                flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
                warp_scheduler_barrier_arrive();

                if constexpr (!Delay_V_release) { pipeline_k.consumer_release(smem_pipe_read); }
                else { consumer_wait(pipeline_vt, smem_pipe_read); }

                cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/Is_local>(tSrS), scores_scale);
                softmax.rescale_o(tOrO, scores_scale);
                softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/Is_local>(tSrS);
                Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
                permute_regs_A_to_C(tOrP);

                if constexpr (Delay_V_release) { pipeline_k.consumer_release(smem_pipe_read); }
                else { consumer_wait(pipeline_vt, smem_pipe_read); }
                flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
                if constexpr (!Delay_V_release) { pipeline_vt.consumer_release(smem_pipe_read); }
                ++smem_pipe_read;
            }
        }
        if constexpr (Delay_V_release) {
            warp_scheduler_barrier_sync();
            CUTLASS_PRAGMA_NO_UNROLL
            for (; n_block >= n_block_min; --n_block) {
                Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
                consumer_wait(pipeline_k, smem_pipe_read);
                flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);

                if constexpr (Is_local) {
                    Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
                    Tensor tScS = threadMma0.partition_C(cS);
                    #pragma unroll
                    for (int i = 0; i < size(tSrS); ++i) {
                        int row = int(get<0>(tScS(i))) / kBlockH;
                        if (
                            int(get<1>(tScS(i))) >= col_limit_right(row, n_block) ||
                            int(get<1>(tScS(i))) < col_limit_left(row, n_block)
                        ) {
                            tSrS(i) = -INFINITY;
                        }
                    }
                }

                warp_scheduler_barrier_arrive();
                pipeline_k.consumer_release(smem_pipe_read);
                pipeline_vt.consumer_release(smem_pipe_release);

                cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/Is_local>(tSrS), scores_scale);
                softmax.rescale_o(tOrO, scores_scale);
                softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/Is_local>(tSrS);
                Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
                permute_regs_A_to_C(tOrP);

                consumer_wait(pipeline_vt, smem_pipe_read);
                flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
                warp_scheduler_barrier_sync();
                ++smem_pipe_read;
                ++smem_pipe_release;
            }
            warp_scheduler_barrier_arrive();
            pipeline_vt.consumer_release(smem_pipe_release);
            ++smem_pipe_release;
        } else {
            if constexpr (kHeadDim == 128) { warp_scheduler_barrier_sync(); }
            CUTLASS_PRAGMA_NO_UNROLL
            for (; n_block >= n_block_min; --n_block) {
                Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
                consumer_wait(pipeline_k, smem_pipe_read);
                if constexpr (kHeadDim == 256) { warp_scheduler_barrier_sync(); }
                flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);

                if constexpr (Is_local) {
                    Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
                    Tensor tScS = threadMma0.partition_C(cS);
                    #pragma unroll
                    for (int i = 0; i < size(tSrS); ++i) {
                        int row = int(get<0>(tScS(i))) / kBlockH;
                        if (
                            int(get<1>(tScS(i))) >= col_limit_right(row, n_block) ||
                            int(get<1>(tScS(i))) < col_limit_left(row, n_block)
                        ) {
                            tSrS(i) = -INFINITY;
                        }
                    }
                }

                warp_scheduler_barrier_arrive();
                pipeline_k.consumer_release(smem_pipe_read);

                cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/Is_local>(tSrS), scores_scale);
                softmax.rescale_o(tOrO, scores_scale);
                softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/Is_local>(tSrS);
                Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
                permute_regs_A_to_C(tOrP);

                consumer_wait(pipeline_vt, smem_pipe_read);
                if constexpr (kHeadDim == 128) { warp_scheduler_barrier_sync(); }
                flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
                pipeline_vt.consumer_release(smem_pipe_read);
                ++smem_pipe_read;
            }
            if constexpr (kHeadDim == 128) { warp_scheduler_barrier_arrive(); }
        }

        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarpGroup, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
        cute::copy(softmax.template finalize</*Is_dropout=*/false, Is_split>(tSrS, shared_storage.descale_v), scores_scale);
        softmax.rescale_o(tOrO, scores_scale);
        return;
    }

};

} // namespace flash