mainloop_fwd_sm90_tma_gmma_ws.hpp

  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include <cutlass/cutlass.h>
  6. #include <cutlass/array.h>
  7. #include <cutlass/numeric_types.h>
  8. #include <cutlass/numeric_conversion.h>
  9. #include "cutlass/pipeline/pipeline.hpp"
  10. #include "cute/tensor.hpp"
  11. #include "cutlass/gemm/collective/collective_builder.hpp"
  12. #include "named_barrier.hpp"
  13. #include "utils.h"
  14. #include "copy_paged_sm90_tma.hpp"
  15. namespace flash {
  16. using namespace cute;
  17. // 4 warps
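// In-smem transpose of a 64x64 fp8 tile, executed cooperatively by one warpgroup (the 4 warps above).
// The fp8 WGMMA consumes V transposed, but TMA delivers V in its K-block layout, so the producer
// transposes V tiles in shared memory: LDSM.T loads into registers, __byte_perm reorders the packed
// fp8 bytes, and STSM writes the transposed tile back out.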
  18. struct SmemTransposeFp8_64x64 {
  19. using Element = cutlass::float_e4m3_t;
  20. using ldsm_thread_shape = Shape<_4, _1, _8, _4>;
  21. using ldsm_value_shape = Shape<_2, _8, _2, _1>;
  22. using ldsm_value_stride = Stride<_2, _4, _1, _0>;
  23. using TiledCopyLDSM = decltype(make_tiled_copy(
  24. Copy_Atom<SM75_U16x8_LDSM_T, Element>{}, Layout<ldsm_thread_shape>{},
  25. Layout<ldsm_value_shape, ldsm_value_stride>{}));
  26. TiledCopyLDSM tiled_copy_ldsm;
  27. using stsm_thread_shape = Shape<_4, _1, _8, _4>;
  28. // using stsm_thread_stride = Stride<_1, _0, _4, _32>;
  29. #ifndef NO_FP8_COLUMN_PERMUTE
  30. using stsm_value_shape = Shape<_4, _4, _1, _2>;
  31. using stsm_value_stride = Stride<_1, _8, _0, _4>;
  32. #else
  33. using stsm_value_shape = Shape<_4, _4, _2, _1>;
  34. using stsm_value_stride = Stride<_1, _8, _4, _0>;
  35. #endif
  36. using TiledCopySTSM =
  37. decltype(make_tiled_copy(Copy_Atom<SM90_U32x4_STSM_N, Element>{},
  38. Layout<stsm_thread_shape>{},
  39. Layout<stsm_value_shape, stsm_value_stride>{}));
  40. TiledCopySTSM tiled_copy_stsm;
  41. template <class SmemTensor, class SmemTensorOut>
  42. CUTLASS_DEVICE void operator()(SmemTensor &&s_in, SmemTensorOut &&s_out) {
  43. using namespace cute;
  44. auto tid = threadIdx.x;
  45. auto thr_copy_ldsm = tiled_copy_ldsm.get_thread_slice(tid);
  46. auto thr_copy_stsm = tiled_copy_stsm.get_thread_slice(tid);
  47. auto tXsX = thr_copy_ldsm.partition_S(s_in);
  48. auto tXrX = make_tensor<Element>(shape(tXsX));
  49. auto tXsX_out = thr_copy_stsm.partition_D(s_out);
  50. cute::copy(tiled_copy_ldsm, tXsX, tXrX);
  51. auto data = tXrX.data();
  52. // size(tXrX) == 32
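// The two __byte_perm selectors interleave the packed fp8 values of adjacent 32-bit words:
// 0x6420 gathers the even-indexed bytes of {upper, lower} and 0x7531 the odd-indexed ones; since
// LDSM.T transposes at 16-bit granularity, this extra byte-level interleave finishes the 8-bit transpose.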
  53. CUTLASS_PRAGMA_UNROLL
  54. for (int n = 0; n < size(tXrX); n += 8) {
  55. uint32_t *data_32bit = reinterpret_cast<uint32_t *>(&data[n]);
  56. auto upper = data_32bit[0];
  57. auto lower = data_32bit[1];
  58. data_32bit[0] = __byte_perm(upper, lower, 0x6420);
  59. data_32bit[1] = __byte_perm(upper, lower, 0x7531);
  60. }
  61. cute::copy(tiled_copy_stsm, tXrX, tXsX_out);
  62. }
  63. };
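// Warp-specialized forward mainloop for SM90: a producer warp (a full warpgroup for fp8) issues TMA
// loads of Q, K and V (optionally through a paged-KV block table) into multi-stage smem pipelines,
// while the consumer warpgroups run the two GEMMs (S = Q K^T and O += P V) with online softmax in between.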
  64. template <typename Ktraits, bool Is_causal, bool Is_local, typename Seqlen_traits, typename Seqlen_traits_Q = Seqlen_traits>
  65. struct CollectiveMainloopFwd {
  66. using Element = typename Ktraits::Element;
  67. using TileShape_MNK = typename Ktraits::TileShape_MNK;
  68. using ClusterShape = typename Ktraits::ClusterShape_MNK;
  69. static constexpr int kStages = Ktraits::kStages;
  70. static constexpr int kHeadDim = Ktraits::kHeadDim;
  71. // static constexpr int kBlockM = Ktraits::kBlockM;
  72. // static constexpr int kBlockN = Ktraits::kBlockN;
  73. // static constexpr int kBlockH = Ktraits::kBlockH;
  74. static constexpr bool Is_split = Ktraits::Is_split;
  75. static constexpr bool No_smem_O = Ktraits::No_smem_O;
  76. using GmemTiledCopyQ = cute::SM90_TMA_LOAD;
  77. using GmemTiledCopyKVNopage = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{})));
78. // Use SM90_TMA_LOAD_MULTICAST_PAGED if we would use SM90_TMA_LOAD_MULTICAST in the unpaged scenario; otherwise use SM90_TMA_LOAD_PAGED.
  79. using GmemTiledCopyKV = typename std::conditional<
  80. std::is_same<GmemTiledCopyKVNopage, cute::SM90_TMA_LOAD_MULTICAST>::value,
  81. SM90_TMA_LOAD_MULTICAST_PAGED,
  82. SM90_TMA_LOAD_PAGED>::type;
  83. using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
  84. using SmemLayoutQCopy = typename Ktraits::SmemLayoutQCopy;
  85. using TileShapeQCopy = typename Ktraits::TileShapeQCopy;
  86. using SmemLayoutK = typename Ktraits::SmemLayoutK;
  87. using SmemLayoutV = typename Ktraits::SmemLayoutV;
  88. using SmemLayoutVt = typename Ktraits::SmemLayoutVt;
  89. using TMA_Q = decltype(make_tma_copy(
  90. GmemTiledCopyQ{},
  91. make_tensor(
  92. make_gmem_ptr(static_cast<Element const*>(nullptr)),
  93. repeat_like(typename Seqlen_traits_Q::StrideT{}, int32_t(0)),
  94. typename Seqlen_traits_Q::StrideT{}
  95. ),
  96. SmemLayoutQCopy{},
  97. TileShapeQCopy{},
  98. _1{})); // no mcast for Q
  99. using TMA_K = decltype(make_virtualized_tma_copy(
  100. GmemTiledCopyKV{},
  101. make_tensor(
  102. make_gmem_ptr(static_cast<Element const*>(nullptr)),
  103. repeat_like(typename Seqlen_traits::StrideT{}, int32_t(0)),
  104. typename Seqlen_traits::StrideT{}
  105. ),
  106. typename Seqlen_traits::ShapeT{},
  107. take<0, 2>(SmemLayoutK{}),
  108. select<1, 2>(TileShape_MNK{}),
  109. size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
  110. // TMA_V may differ from TMA_K for fp8 kernel (e.g. swizzling mode)
  111. using TMA_V = decltype(make_virtualized_tma_copy(
  112. GmemTiledCopyKV{},
  113. make_tensor(
  114. make_gmem_ptr(static_cast<Element const*>(nullptr)),
  115. repeat_like(typename Seqlen_traits::StrideT{}, int32_t(0)),
  116. typename Seqlen_traits::StrideT{}
  117. ),
  118. typename Seqlen_traits::ShapeT{},
  119. take<0, 2>(SmemLayoutV{}),
  120. select<1, 2>(TileShape_MNK{}),
  121. size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
  122. static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
  123. using MainloopPipeline = typename Ktraits::MainloopPipeline;
  124. using MainloopPipelineNoTMA = typename Ktraits::MainloopPipelineNoTMA;
  125. using PipelineParams = typename MainloopPipeline::Params;
  126. using PipelineState = typename MainloopPipeline::PipelineState;
  127. // Set the bytes transferred in this TMA transaction (may involve multiple issues)
  128. static constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<Element> / 8);
  129. static constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<Element> / 8);
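// For example (the actual numbers depend on Ktraits): a 128x128 tile of a 16-bit element type gives
// TmaTransactionBytesQ = 128 * 128 * 16 / 8 = 32768 bytes. K counts only a single pipeline stage of
// SmemLayoutK, hence the take<0, 2>.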
  130. // static constexpr bool UseSchedulerBarrier = kHeadDim <= 128;
  131. static constexpr bool UseSchedulerBarrier = Ktraits::kNWarps >= 12 &&
  132. (cutlass::sizeof_bits_v<Element> == 8 ? kHeadDim >= 128 : kHeadDim <= 128);
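// The scheduler barrier (ping-ponging GEMM issue between consumer warpgroups) is only enabled with
// at least 12 warps, i.e. two consumer warpgroups, and for head dimensions where it pays off:
// hdim <= 128 for 16-bit element types, hdim >= 128 for fp8.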
  133. // Host side kernel arguments
  134. struct Arguments {
  135. Element const* ptr_Q;
  136. typename Seqlen_traits_Q::LayoutT layout_Q;
  137. Element const* ptr_K;
  138. typename Seqlen_traits::LayoutT layout_K;
  139. Element const* ptr_V;
  140. typename Seqlen_traits::LayoutT layout_V;
  141. typename Seqlen_traits::ShapeT shape_KV;
  142. float const softmax_scale_log2;
  143. float const* descale_q_ptr;
  144. float const* descale_k_ptr;
  145. float const* descale_v_ptr;
  146. int window_size_left;
  147. int window_size_right;
  148. int const qhead_per_khead;
  149. int const* cache_batch_idx;
  150. int const num_splits;
  151. // Paged Attention block table data
  152. int * block_table; // may be nullptr if not paged
  153. int64_t block_table_batch_stride;
  154. int page_block_size;
  155. int num_blocks;
  156. };
  157. // Device side kernel params
  158. struct Params {
  159. typename Seqlen_traits_Q::LayoutT layout_Q;
  160. typename Seqlen_traits::LayoutT layout_K;
  161. typename Seqlen_traits::LayoutT layout_V;
  162. typename Seqlen_traits::ShapeT shape_KV;
  163. cutlass::FastDivmod qhead_per_khead_divmod;
  164. TMA_Q tma_load_Q;
  165. TMA_K tma_load_K;
  166. TMA_V tma_load_V;
  167. float const softmax_scale_log2;
  168. float const* descale_q_ptr;
  169. float const* descale_k_ptr;
  170. float const* descale_v_ptr;
  171. int window_size_left;
  172. int window_size_right;
  173. int const* cache_batch_idx;
  174. cutlass::FastDivmod num_splits_divmod;
  175. // Paged Attention block table data
  176. const PagedCopyArgs paged_copy_args;
  177. };
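// Convert host-side Arguments into device-side Params: build the TMA descriptors for Q, K and V on
// the host, precompute the FastDivmod helpers, and bundle the paged-KV copy arguments.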
  178. static Params
  179. to_underlying_arguments(Arguments const& args) {
  180. Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.layout_Q);
  181. TMA_Q tma_load_Q = make_tma_copy(
  182. GmemTiledCopyQ{},
  183. mQ,
  184. SmemLayoutQCopy{},
  185. TileShapeQCopy{},
  186. _1{}); // no mcast for Q
  187. Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.layout_K);
  188. TMA_K tma_load_K = make_virtualized_tma_copy(
  189. GmemTiledCopyKV{},
  190. mK,
  191. args.shape_KV,
  192. SmemLayoutK{}(_, _, _0{}),
  193. select<1, 2>(TileShape_MNK{}),
  194. size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
  195. Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.layout_V);
  196. TMA_V tma_load_V = make_virtualized_tma_copy(
  197. GmemTiledCopyKV{},
  198. mV,
  199. args.shape_KV,
  200. SmemLayoutV{}(_, _, _0{}),
  201. select<1, 2>(TileShape_MNK{}),
  202. size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
  203. return {args.layout_Q, args.layout_K, args.layout_V, args.shape_KV,
  204. cutlass::FastDivmod(args.qhead_per_khead),
  205. tma_load_Q, tma_load_K, tma_load_V,
  206. args.softmax_scale_log2,
  207. args.descale_q_ptr, args.descale_k_ptr, args.descale_v_ptr,
  208. args.window_size_left, args.window_size_right,
  209. args.cache_batch_idx,
  210. cutlass::FastDivmod(args.num_splits),
  211. {args.block_table_batch_stride, args.page_block_size, args.block_table }};
  212. }
  213. /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
  214. CUTLASS_DEVICE
  215. static void prefetch_tma_descriptors(Params const& mainloop_params) {
  216. cute::prefetch_tma_descriptor(mainloop_params.tma_load_Q.get_tma_descriptor());
  217. cute::prefetch_tma_descriptor(mainloop_params.tma_load_K.get_tma_descriptor());
  218. cute::prefetch_tma_descriptor(mainloop_params.tma_load_V.get_tma_descriptor());
  219. }
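// Compute the [n_block_min, n_block_max) range of K/V blocks that this query block must visit,
// accounting for the split index (split-KV), causal masking and local (sliding-window) masking.
// Illustrative example: with seqlen_q == seqlen_k, kBlockN == kBlockM_div_H and causal masking,
// m_block == 0 yields n_block_max == 1, i.e. a single K/V block.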
  220. CUTLASS_DEVICE
  221. void get_n_block_min_max(
  222. Params const& mainloop_params,
  223. int m_block,
  224. int n_split_idx,
  225. const Seqlen_traits_Q& seqlen_traits_q,
  226. const Seqlen_traits& seqlen_traits_k,
  227. int& n_block_min,
  228. int& n_block_max
  229. ) {
  230. // static constexpr int kBlockM = get<0>(TileShape_MNK{});
  231. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  232. static constexpr int kBlockM_div_H = get<0>(TileShape_MNK{})/Ktraits::kBlockH;
  233. int const seqlen_q = seqlen_traits_q.actual_seq_len;
  234. int const seqlen_k = seqlen_traits_k.actual_seq_len;
  235. n_block_max = cute::ceil_div(seqlen_k, kBlockN);
  236. if constexpr(Is_split) {
  237. int const n_blocks_per_split
  238. = mainloop_params.num_splits_divmod.divide(n_block_max + int(mainloop_params.num_splits_divmod) - 1);
  239. n_block_min = n_split_idx * n_blocks_per_split;
  240. n_block_max = std::min(n_block_max, (n_split_idx + 1) * n_blocks_per_split);
  241. }
  242. if constexpr (Is_causal) {
  243. n_block_max = std::min(
  244. n_block_max,
  245. cute::ceil_div((m_block + 1) * kBlockM_div_H + seqlen_k - seqlen_q, kBlockN));
  246. } else if constexpr (Is_local) {
  247. n_block_max = std::min(
  248. n_block_max,
  249. cute::ceil_div((m_block + 1) * kBlockM_div_H + seqlen_k - seqlen_q + mainloop_params.window_size_right, kBlockN));
  250. n_block_min = std::max(
  251. n_block_min,
  252. (m_block * kBlockM_div_H + seqlen_k - seqlen_q - mainloop_params.window_size_left) / kBlockN);
  253. }
  254. }
  255. CUTLASS_DEVICE
  256. void get_n_block_max(
  257. Params const& mainloop_params,
  258. int m_block,
  259. const Seqlen_traits_Q& seqlen_traits_q,
  260. const Seqlen_traits& seqlen_traits_k,
  261. int& n_block_max
  262. ) {
  263. // static constexpr int kBlockM = get<0>(TileShape_MNK{});
  264. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  265. static constexpr int kBlockM_div_H = get<0>(TileShape_MNK{})/Ktraits::kBlockH;
  266. int const seqlen_q = seqlen_traits_q.actual_seq_len;
  267. int const seqlen_k = seqlen_traits_k.actual_seq_len;
  268. n_block_max = cute::ceil_div(seqlen_k, kBlockN);
  269. if constexpr (Is_causal) {
  270. n_block_max = std::min(n_block_max,
  271. cute::ceil_div((m_block + 1) * kBlockM_div_H + seqlen_k - seqlen_q, kBlockN));
  272. }
  273. }
  274. template <typename Scheduler, typename SharedStorage>
  275. CUTLASS_DEVICE void
  276. load(Params const& mainloop_params,
  277. MainloopPipeline pipeline_k,
  278. MainloopPipeline pipeline_v,
  279. PipelineState& smem_pipe_write_k,
  280. PipelineState& smem_pipe_write_v,
  281. SharedStorage &shared_storage,
  282. Scheduler& scheduler,
  283. typename Scheduler::Params const& scheduler_params,
  284. typename Scheduler::WorkTileInfo& work_tile_info,
  285. cute::tuple<int32_t, int32_t, int32_t, int32_t> block_coord,
  286. int work_idx,
  287. const Seqlen_traits_Q& seqlen_traits_q,
  288. const Seqlen_traits& seqlen_traits_k,
  289. int n_block_min,
  290. int n_block_max
  291. ) {
  292. Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQCopy{});
  293. Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
  294. Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});
  295. Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape());
  296. Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.shape_KV);
  297. Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.shape_KV);
  298. auto [m_block, n_split_idx, bidh, bidb] = block_coord;
  299. const int bidb_cache = mainloop_params.cache_batch_idx == nullptr ? bidb : mainloop_params.cache_batch_idx[bidb];
  300. const int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);
  301. // Prepare the TMA loads
  302. uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
  303. constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
  304. uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
  305. Tensor gQ = [&] {
306. // Re-destructure inside the lambda: structured bindings cannot be captured by lambdas (pre-C++20)
  307. auto [m_block, n_split_idx, bidh, bidb] = block_coord;
  308. if constexpr(Seqlen_traits_Q::UseGQAPacking) {
  309. return seqlen_traits_q.get_local_tile_tensor(
  310. mQ, TileShapeQCopy{}, bidh_kv, bidb)
  311. (_, _, _, m_block, bidh % int(mainloop_params.qhead_per_khead_divmod)); // (M/H, H, K)
  312. } else {
  313. return seqlen_traits_q.get_local_tile_tensor(
  314. mQ, TileShapeQCopy{}, bidh, bidb)(_, _, m_block); // (M, K)
  315. }
  316. }();
  317. Tensor gK = seqlen_traits_k.get_local_tile_tensor(
  318. mK, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb_cache); // (N, K, _)
  319. Tensor gV = seqlen_traits_k.get_local_tile_tensor(
  320. mV, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb_cache); // (N, K, _)
  321. Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
  322. Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
  323. auto [tQgQ, tQsQ] = tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{},
  324. group_modes<0, 2>(sQ_x), group_modes<0, 2>(gQ_x)); // (TMA), (TMA)
  325. auto [tKgK, tKsK] = tma_partition(mainloop_params.tma_load_K, block_rank_in_cluster, Layout<ClusterShape>{},
  326. group_modes<0, 2>(sK), group_modes<0, 2>(gK)); // (TMA, k), (TMA, PIPE)
  327. auto [tVgV, tVsV] = tma_partition(mainloop_params.tma_load_V, block_rank_in_cluster, Layout<ClusterShape>{},
  328. group_modes<0, 2>(sV), group_modes<0, 2>(gV)); // (TMA, k), (TMA, PIPE)
  329. uint16_t mcast_mask_kv = 0;
  330. if constexpr (cute::is_same_v<GmemTiledCopyKV, SM90_TMA_LOAD_MULTICAST> || cute::is_same_v<GmemTiledCopyKV, SM90_TMA_LOAD_MULTICAST_PAGED>) {
  331. auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id
  332. for (int m = 0; m < size<0>(block_layout); ++m) {
  333. mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{}));
  334. }
  335. }
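// Each set bit of mcast_mask_kv selects a CTA rank in the cluster that receives the TMA multicast
// of the K/V tile: all CTAs along the M mode of the cluster share the same K/V blocks.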
  336. int n_block = n_block_max - 1;
  337. int lane_predicate = cute::elect_one_sync();
  338. if (lane_predicate) {
  339. pipeline_k.producer_acquire(smem_pipe_write_k);
  340. copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv, mainloop_params.paged_copy_args),
  341. tKgK(_, n_block), tKsK(_, smem_pipe_write_k.index()));
  342. ++smem_pipe_write_k;
  343. }
  344. // Wait for the MMA warpgroups to say that smem_q is ready
  345. cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
  346. if (lane_predicate) {
  347. shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
  348. copy(mainloop_params.tma_load_Q.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ);
  349. }
350. // Wait for warp 1 to signal that smem_v is ready and V can be copied from gmem.
351. // Need a ClusterBarrier, not just a NamedBarrier. Otherwise CTA 0 might finish its TMA store
352. // on O first and issue the TMA multicast load on V before CTA 1 has finished its TMA store on O.
  353. if constexpr (!No_smem_O) { shared_storage.barrier_O.wait((work_idx + 1) % 2); }
  354. if (lane_predicate) {
  355. // CUTLASS_PRAGMA_NO_UNROLL
  356. #pragma unroll 2
  357. for (; n_block > n_block_min; --n_block) {
  358. pipeline_k.producer_acquire(smem_pipe_write_k);
  359. copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv, mainloop_params.paged_copy_args),
  360. tKgK(_, n_block - 1), tKsK(_, smem_pipe_write_k.index()));
  361. ++smem_pipe_write_k;
  362. pipeline_v.producer_acquire(smem_pipe_write_v);
  363. copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv, mainloop_params.paged_copy_args),
  364. tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
  365. ++smem_pipe_write_v;
  366. }
  367. }
  368. scheduler.prefetch_next_work(scheduler_params, work_tile_info);
  369. if (lane_predicate) {
  370. pipeline_v.producer_acquire(smem_pipe_write_v);
  371. copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv, mainloop_params.paged_copy_args),
  372. tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
  373. ++smem_pipe_write_v;
  374. }
  375. scheduler.broadcast_next_work(work_tile_info);
  376. }
  377. template <typename Scheduler, typename SharedStorage>
  378. CUTLASS_DEVICE void
  379. load_fp8(Params const& mainloop_params,
  380. MainloopPipeline pipeline_k,
  381. MainloopPipeline pipeline_v,
  382. MainloopPipelineNoTMA pipeline_vt,
  383. PipelineState& smem_pipe_write,
  384. PipelineState& smem_pipe_read,
  385. SharedStorage &shared_storage,
  386. Scheduler& scheduler,
  387. typename Scheduler::Params const& scheduler_params,
  388. typename Scheduler::WorkTileInfo& work_tile_info,
  389. cute::tuple<int32_t, int32_t, int32_t, int32_t> block_coord,
  390. int work_idx,
  391. const Seqlen_traits_Q& seqlen_traits_q,
  392. const Seqlen_traits& seqlen_traits_k,
  393. int n_block_min,
  394. int n_block_max
  395. ) {
  396. using SmemLayoutTransposeV = typename Ktraits::SmemLayoutTransposeV;
  397. using SmemLayoutTransposeVt = typename Ktraits::SmemLayoutTransposeVt;
  398. Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQCopy{});
  399. Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
  400. Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});
  401. Tensor sV_divide = as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutTransposeV{}));
  402. Tensor sVt_divide = as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.smem_v_out.data()), SmemLayoutTransposeVt{}));
  403. auto smem_transpose_V = SmemTransposeFp8_64x64();
  404. auto do_transpose_V = [&](int stage) {
  405. CUTLASS_PRAGMA_UNROLL
  406. for (int j = 0; j < shape<2>(SmemLayoutTransposeV{}); ++j) {
  407. CUTLASS_PRAGMA_UNROLL
  408. for (int i = 0; i < shape<1>(SmemLayoutTransposeV{}); ++i) {
  409. smem_transpose_V(flatten(sV_divide(_, i, j, stage)),
  410. flatten(sVt_divide(_, i, j, stage)));
  411. }
  412. }
  413. cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, static_cast<int>(FwdNamedBarriers::ProducerWG) /*id*/);
  414. };
  415. Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape());
  416. Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.shape_KV);
  417. Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.shape_KV);
  418. auto [m_block, split_idx, bidh, bidb] = block_coord;
  419. const int bidb_cache = mainloop_params.cache_batch_idx == nullptr ? bidb : mainloop_params.cache_batch_idx[bidb];
  420. const int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);
  421. // Prepare the TMA loads
  422. uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
  423. constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
  424. uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
  425. Tensor gQ = [&] {
426. // Re-destructure inside the lambda: structured bindings cannot be captured by lambdas (pre-C++20)
  427. auto [m_block, n_split_idx, bidh, bidb] = block_coord;
  428. if constexpr(Seqlen_traits_Q::UseGQAPacking) {
  429. return seqlen_traits_q.get_local_tile_tensor(
  430. mQ, TileShapeQCopy{}, bidh_kv, bidb)
  431. (_, _, _, m_block, bidh % int(mainloop_params.qhead_per_khead_divmod)); // (M/H, H, K)
  432. } else {
  433. return seqlen_traits_q.get_local_tile_tensor(
  434. mQ, TileShapeQCopy{}, bidh, bidb)(_, _, m_block); // (M, K)
  435. }
  436. }();
  437. Tensor gK = seqlen_traits_k.get_local_tile_tensor(
  438. mK, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb_cache); // (N, K, _)
  439. Tensor gV = seqlen_traits_k.get_local_tile_tensor(
  440. mV, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb_cache); // (N, K, _)
  441. Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
  442. Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
  443. auto [tQgQ, tQsQ] = tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{},
  444. group_modes<0, 2>(sQ_x), group_modes<0, 2>(gQ_x)); // (TMA), (TMA)
  445. auto [tKgK, tKsK] = tma_partition(mainloop_params.tma_load_K, block_rank_in_cluster, Layout<ClusterShape>{},
  446. group_modes<0, 2>(sK), group_modes<0, 2>(gK)); // (TMA, k), (TMA, PIPE)
  447. auto [tVgV, tVsV] = tma_partition(mainloop_params.tma_load_V, block_rank_in_cluster, Layout<ClusterShape>{},
  448. group_modes<0, 2>(sV), group_modes<0, 2>(gV)); // (TMA, k), (TMA, PIPE)
  449. uint16_t mcast_mask_kv = 0;
  450. if constexpr (cute::is_same_v<GmemTiledCopyKV, SM90_TMA_LOAD_MULTICAST> || cute::is_same_v<GmemTiledCopyKV, SM90_TMA_LOAD_MULTICAST_PAGED>) {
  451. auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id
  452. for (int m = 0; m < size<0>(block_layout); ++m) {
  453. mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{}));
  454. }
  455. }
  456. int n_block = n_block_max - 1;
  457. int lane_predicate = cute::elect_one_sync();
  458. int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
  459. if (warp_idx_in_warpgroup == 0 && lane_predicate) {
  460. pipeline_k.producer_acquire(smem_pipe_write);
  461. copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv, mainloop_params.paged_copy_args),
  462. tKgK(_, n_block), tKsK(_, smem_pipe_write.index()));
  463. }
  464. // Wait for the MMA warpgroups to say that smem_q is ready
465. // For fp8, the whole producer warpgroup (not just one warp) participates, so use NumThreadsPerWarpGroup instead of NumThreadsPerWarp.
  466. cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarpGroup, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
  467. if (warp_idx_in_warpgroup == 0 && lane_predicate) {
  468. shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
  469. copy(mainloop_params.tma_load_Q.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ);
  470. if constexpr(!Ktraits::VO_union_all) {
  471. pipeline_v.producer_acquire(smem_pipe_write);
  472. copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv, mainloop_params.paged_copy_args),
  473. tVgV(_, n_block), tVsV(_, smem_pipe_write.index()));
  474. }
  475. }
476. // With the fp8 kernel, smem_o is in a union with smem_v_out
477. // (except for the split kernel with hdim 256),
478. // so a NamedBarrier could be used instead of a ClusterBarrier.
479. // However, this doesn't appear to have any benefit.
  480. if constexpr (!No_smem_O) { shared_storage.barrier_O.wait((work_idx + 1) % 2); }
  481. if constexpr(Ktraits::VO_union_all) {
  482. if (warp_idx_in_warpgroup == 0 && lane_predicate) {
  483. pipeline_v.producer_acquire(smem_pipe_write);
  484. copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv, mainloop_params.paged_copy_args),
  485. tVgV(_, n_block), tVsV(_, smem_pipe_write.index()));
  486. }
  487. }
  488. #pragma unroll 2
  489. for (; n_block > n_block_min; --n_block) {
  490. pipeline_v.consumer_wait(smem_pipe_read);
  491. pipeline_vt.producer_acquire(smem_pipe_write);
  492. do_transpose_V(smem_pipe_read.index());
  493. pipeline_vt.producer_commit(smem_pipe_write);
  494. pipeline_v.consumer_release(smem_pipe_read);
  495. ++smem_pipe_write;
  496. ++smem_pipe_read;
  497. if (warp_idx_in_warpgroup == 0 && lane_predicate) {
  498. pipeline_k.producer_acquire(smem_pipe_write);
  499. copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv, mainloop_params.paged_copy_args),
  500. tKgK(_, n_block-1), tKsK(_, smem_pipe_write.index()));
  501. pipeline_v.producer_acquire(smem_pipe_write);
  502. copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv, mainloop_params.paged_copy_args),
  503. tVgV(_, n_block-1), tVsV(_, smem_pipe_write.index()));
  504. }
  505. }
  506. scheduler.prefetch_next_work(scheduler_params, work_tile_info);
  507. scheduler.broadcast_next_work(work_tile_info);
  508. pipeline_v.consumer_wait(smem_pipe_read);
  509. pipeline_vt.producer_acquire(smem_pipe_write);
  510. do_transpose_V(smem_pipe_read.index());
  511. pipeline_vt.producer_commit(smem_pipe_write);
  512. pipeline_v.consumer_release(smem_pipe_read);
  513. ++smem_pipe_write;
  514. ++smem_pipe_read;
  515. }
  516. /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
  517. CUTLASS_DEVICE void
  518. load_tail(MainloopPipeline pipeline_k, MainloopPipeline pipeline_v,
  519. PipelineState& smem_pipe_write_k, PipelineState& smem_pipe_write_v) {
  520. int lane_predicate = cute::elect_one_sync();
  521. int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
  522. // Issue the epilogue waits
  523. if (warp_idx_in_warpgroup == 0 && lane_predicate) {
524. /* This helps avoid early exit of blocks in a Cluster.
525. * Waits for all stages to either be released (all consumer UNLOCKs), or, if a stage was never
526. * used, to simply be acquired, since its phase is still inverted from make_producer_start_state.
527. */
  528. pipeline_k.producer_tail(smem_pipe_write_k);
  529. pipeline_v.producer_tail(smem_pipe_write_v);
  530. }
  531. }
  532. /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
  533. CUTLASS_DEVICE void
  534. load_tail_one_write(MainloopPipeline pipeline_k, MainloopPipeline pipeline_v,
  535. PipelineState& smem_pipe_write) {
  536. int lane_predicate = cute::elect_one_sync();
  537. int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
  538. // Issue the epilogue waits
  539. if (warp_idx_in_warpgroup == 0 && lane_predicate) {
540. /* This helps avoid early exit of blocks in a Cluster.
541. * Waits for all stages to either be released (all consumer UNLOCKs), or, if a stage was never
542. * used, to simply be acquired, since its phase is still inverted from make_producer_start_state.
543. */
  544. pipeline_k.producer_tail(smem_pipe_write);
  545. pipeline_v.producer_tail(smem_pipe_write);
  546. }
  547. }
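// Scheduler barriers (active only if UseSchedulerBarrier): each consumer warpgroup syncs on its own
// WarpSchedulerWG* barrier before issuing a GEMM and arrives on the other warpgroups' barriers
// afterwards, staggering the warpgroups so that one warpgroup's softmax can overlap with another's WGMMA.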
  548. CUTLASS_DEVICE void
  549. warp_scheduler_barrier_sync() {
  550. if constexpr (UseSchedulerBarrier) {
  551. cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + cutlass::canonical_warp_group_idx() /*id*/);
  552. }
  553. }
  554. CUTLASS_DEVICE void
  555. warp_scheduler_barrier_arrive() {
  556. if constexpr (!UseSchedulerBarrier) {
  557. return;
  558. } else {
  559. static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
  560. if constexpr (NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup) {
  561. cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (3 - cutlass::canonical_warp_group_idx()) /*id*/);
  562. } else {
  563. cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 2 ? cutlass::canonical_warp_group_idx() + 1 : cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/);
  564. cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 1 ? cutlass::canonical_warp_group_idx() + 2 : cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/);
  565. }
  566. }
  567. }
  568. CUTLASS_DEVICE void
  569. mma_init() {
  570. // Tell producer (warp 0) that smem_q is ready
  571. cutlass::arch::NamedBarrier::arrive(NumMmaThreads + Ktraits::NumProducerThreads, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
  572. if constexpr (!UseSchedulerBarrier) {
  573. return;
  574. } else {
  575. static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
  576. if (cutlass::canonical_warp_group_idx() > 1) {
  577. cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 1 /*id*/);
  578. }
  579. if constexpr (NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup) {
  580. if (cutlass::canonical_warp_group_idx() > 2) {
  581. cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 2 /*id*/);
  582. }
  583. }
  584. }
  585. }
  586. template <typename SharedStorage, typename FrgTensorO, typename Softmax>
  587. CUTLASS_DEVICE void
  588. mma(Params const& mainloop_params,
  589. MainloopPipeline pipeline_k,
  590. MainloopPipeline pipeline_v,
  591. PipelineState& smem_pipe_read_k,
  592. PipelineState& smem_pipe_read_v,
  593. FrgTensorO& tOrO,
  594. Softmax& softmax,
  595. int n_block_min,
  596. int n_block_max,
  597. int thread_idx,
  598. int work_idx,
  599. int m_block,
  600. SharedStorage& shared_storage,
  601. const Seqlen_traits_Q& seqlen_traits_q,
  602. const Seqlen_traits& seqlen_traits_k
  603. ) {
  604. static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
  605. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  606. static constexpr int kBlockH = Ktraits::kBlockH;
  607. static constexpr int kBlockM_div_H = get<0>(TileShape_MNK{}) / kBlockH;
  608. Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
  609. Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
  610. Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVt{});
  611. typename Ktraits::TiledMma0 tiled_mma0;
  612. typename Ktraits::TiledMma1 tiled_mma1;
  613. auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx);
  614. auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx);
  615. // Allocate "fragments/descriptors" for first matmul.
  616. Tensor tSrQ = threadMma0.partition_fragment_A(sQ);
  617. Tensor tSrK = threadMma0.partition_fragment_B(sK);
  618. // Allocate "fragments/descriptors" for second matmul.
  619. // Note: S becomes P.
  620. Tensor tOrV = threadMma1.partition_fragment_B(sVt);
  621. auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
  622. auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
  623. pipeline.consumer_wait(smem_pipe_read, barrier_token);
  624. };
  625. tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero;
  626. int const seqlen_q = seqlen_traits_q.actual_seq_len;
  627. int const seqlen_k = seqlen_traits_k.actual_seq_len;
  628. int n_block = n_block_max - 1;
  629. cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.barrier_Q.try_wait(work_idx % 2));
  630. if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(work_idx % 2); }
  631. Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
  632. consumer_wait(pipeline_k, smem_pipe_read_k);
  633. warp_scheduler_barrier_sync();
  634. flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
  635. warp_scheduler_barrier_arrive();
  636. if constexpr (!No_smem_O) {
  637. if (work_idx != 0) {
  638. int lane_predicate = cute::elect_one_sync();
  639. if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) {
  640. tma_store_wait<0>();
  641. #pragma unroll
  642. for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
  643. shared_storage.barrier_O.arrive(cta_id, lane_predicate);
  644. }
  645. }
  646. }
  647. }
  648. warpgroup_wait<0>();
  649. pipeline_k.consumer_release(smem_pipe_read_k);
  650. ++smem_pipe_read_k;
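// Masking limits in block-local column coordinates (within the current K block of size kBlockN):
// col_limit_right is the exclusive upper bound implied by causal/local masking (shifted by
// seqlen_k - seqlen_q so the sequence ends are aligned); col_limit_left is the inclusive lower
// bound implied by the left window for local attention.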
  651. auto col_limit_right = [&](int row, int n_block) {
  652. int col_limit_base = row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM_div_H;
  653. if constexpr(Is_local)
  654. return col_limit_base + mainloop_params.window_size_right;
  655. else
  656. return col_limit_base;
  657. };
  658. auto col_limit_left = [&](int row, int n_block) {
  659. return std::max(
  660. 0,
  661. row + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM_div_H - mainloop_params.window_size_left
  662. );
  663. };
  664. {
  665. Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
  666. Tensor tScS = threadMma0.partition_C(cS);
  667. #pragma unroll
  668. for (int i = 0; i < size(tSrS); ++i) {
  669. if constexpr (!Is_causal && !Is_local) { // Just masking based on col
  670. if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; }
  671. } else { // mask based on both row and col
  672. // using std::min is faster than doing col >= limit0 or col >= limit1
  673. // Need to cast get<1>(tScS(i)) to (signed) int since by default it's unsigned, and the
  674. // right hand side can be negative and might be converted to a very large unsigned integer.
  675. int row = int(get<0>(tScS(i))) / kBlockH;
  676. if (int(get<1>(tScS(i))) >= std::min(seqlen_k - n_block * kBlockN, col_limit_right(row, n_block))) {
  677. tSrS(i) = -INFINITY;
  678. } else if constexpr(Is_local) {
  679. if (int(get<1>(tScS(i))) < col_limit_left(row, n_block)) {
  680. tSrS(i) = -INFINITY;
  681. }
  682. }
  683. }
  684. }
  685. }
  686. softmax.template online_softmax</*Is_first=*/true>(tSrS);
  687. Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout()));
  688. Tensor scores_scale = make_fragment_like(softmax.row_max);
  689. clear(scores_scale);
  690. constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM_div_H, kBlockN) + 1;
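// Illustrative example: with kBlockM_div_H == kBlockN, causal masking gives n_masking_steps == 2,
// i.e. one more iteration of per-element masking below before the unmasked steady-state loop.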
  691. // Only go through these if Is_causal, since n_masking_steps = 1 when !Is_causal
  692. #pragma unroll
  693. for (int masking_step = 0; masking_step < n_masking_steps - 1 && n_block > n_block_min; ++masking_step, --n_block) {
  694. Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
  695. consumer_wait(pipeline_k, smem_pipe_read_k);
  696. warp_scheduler_barrier_sync();
  697. flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
  698. if (masking_step > 0) { softmax.rescale_o(tOrO, scores_scale); }
  699. consumer_wait(pipeline_v, smem_pipe_read_v);
  700. flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
  701. warp_scheduler_barrier_arrive();
  702. warpgroup_wait<1>();
  703. pipeline_k.consumer_release(smem_pipe_read_k); // release K
  704. Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
  705. Tensor tScS = threadMma0.partition_C(cS);
  706. #pragma unroll
  707. for (int i = 0; i < size(tSrS); ++i) {
  708. int row = int(get<0>(tScS(i))) / kBlockH;
  709. if (int(get<1>(tScS(i))) >= col_limit_right(row, n_block - 1)) {
  710. tSrS(i) = -INFINITY;
  711. }
  712. }
  713. cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/true>(tSrS), scores_scale);
  714. softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/true>(tSrS);
  715. warpgroup_wait<0>();
  716. pipeline_v.consumer_release(smem_pipe_read_v); // release V
  717. ++smem_pipe_read_k;
  718. ++smem_pipe_read_v;
  719. cute::copy(make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout())), tOrP);
  720. }
  721. #pragma unroll 1
  722. for (; n_block > n_block_min; --n_block) {
  723. Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
  724. consumer_wait(pipeline_k, smem_pipe_read_k);
  725. warp_scheduler_barrier_sync();
  726. flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
  727. softmax.rescale_o(tOrO, scores_scale);
  728. consumer_wait(pipeline_v, smem_pipe_read_v);
  729. flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
  730. warp_scheduler_barrier_arrive();
  731. warpgroup_wait<1>();
  732. pipeline_k.consumer_release(smem_pipe_read_k); // release K
  733. if constexpr(Is_local) {
  734. Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
  735. Tensor tScS = threadMma0.partition_C(cS);
  736. #pragma unroll
  737. for (int i = 0; i < size(tSrS); ++i) {
  738. int row = int(get<0>(tScS(i))) / kBlockH;
  739. if (
  740. int(get<1>(tScS(i))) >= col_limit_right(row, n_block - 1) ||
  741. int(get<1>(tScS(i))) < col_limit_left(row, n_block - 1)
  742. ) {
  743. tSrS(i) = -INFINITY;
  744. }
  745. }
  746. }
  747. // auto scores_scale = softmax.template max</*Is_first=*/false>(tSrS);
  748. cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/Is_local>(tSrS), scores_scale);
  749. softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/Is_local>(tSrS);
  750. warpgroup_wait<0>();
  751. pipeline_v.consumer_release(smem_pipe_read_v); // release V
  752. ++smem_pipe_read_k;
  753. ++smem_pipe_read_v;
  754. // softmax.rescale_o(tOrO, scores_scale);
  755. cute::copy(make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout())), tOrP);
  756. }
  757. // Tell warp 0 that smem_q is ready
  758. cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
  759. softmax.rescale_o(tOrO, scores_scale);
  760. consumer_wait(pipeline_v, smem_pipe_read_v);
  761. flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
  762. cute::copy(softmax.template finalize</*Is_dropout=*/false, Is_split>(tSrS), scores_scale);
  763. warpgroup_wait<0>();
  764. pipeline_v.consumer_release(smem_pipe_read_v); // release V, otherwise producers will hang
  765. ++smem_pipe_read_v;
  766. softmax.rescale_o(tOrO, scores_scale);
  767. return;
  768. }
  769. template <bool Delay_V_release = false, typename SharedStorage, typename FrgTensorO, typename Softmax>
  770. CUTLASS_DEVICE void
  771. mma_fp8(Params const& mainloop_params,
  772. MainloopPipeline pipeline_k,
  773. MainloopPipelineNoTMA pipeline_vt,
  774. PipelineState& smem_pipe_read,
  775. PipelineState& smem_pipe_release,
  776. FrgTensorO& tOrO,
  777. Softmax& softmax,
  778. int n_block_min,
  779. int n_block_max,
  780. int thread_idx,
  781. int work_idx,
  782. int m_block,
  783. SharedStorage& shared_storage,
  784. const Seqlen_traits_Q& seqlen_traits_q,
  785. const Seqlen_traits& seqlen_traits_k
  786. ) {
  787. static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
  788. // static constexpr int kBlockM = get<0>(TileShape_MNK{});
  789. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  790. static constexpr int kBlockH = Ktraits::kBlockH;
  791. static constexpr int kBlockM_div_H = get<0>(TileShape_MNK{}) / kBlockH;
  792. Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
  793. Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
  794. Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v_out.data()), SmemLayoutVt{});
  795. typename Ktraits::TiledMma0 tiled_mma0;
  796. typename Ktraits::TiledMma1 tiled_mma1;
  797. auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx);
  798. auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx);
  799. // Allocate "fragments/descriptors" for first matmul.
  800. Tensor tSrQ = threadMma0.partition_fragment_A(sQ);
  801. Tensor tSrK = threadMma0.partition_fragment_B(sK);
  802. // Allocate "fragments/descriptors" for second matmul.
  803. Tensor tOrV = threadMma1.partition_fragment_B(sVt);
  804. auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
  805. auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
  806. pipeline.consumer_wait(smem_pipe_read, barrier_token);
  807. };
  808. tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero;
  809. int const seqlen_q = seqlen_traits_q.actual_seq_len;
  810. int const seqlen_k = seqlen_traits_k.actual_seq_len;
  811. int n_block = n_block_max - 1;
  812. cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.barrier_Q.try_wait(work_idx % 2));
  813. if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(work_idx % 2); }
  814. Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
  815. consumer_wait(pipeline_k, smem_pipe_read);
  816. warp_scheduler_barrier_sync();
  817. flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
  818. if constexpr (!No_smem_O) {
  819. if (work_idx != 0) {
  820. int lane_predicate = cute::elect_one_sync();
  821. if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) {
  822. tma_store_wait<0>();
  823. #pragma unroll
  824. for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
  825. shared_storage.barrier_O.arrive(cta_id, lane_predicate);
  826. }
  827. }
  828. }
  829. }
  830. warpgroup_wait<0>();
  831. warp_scheduler_barrier_arrive();
  832. pipeline_k.consumer_release(smem_pipe_read);
  833. auto col_limit_right = [&](int row, int n_block) {
  834. int col_limit_base = row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM_div_H;
  835. if constexpr(Is_local)
  836. return col_limit_base + mainloop_params.window_size_right;
  837. else
  838. return col_limit_base;
  839. };
  840. auto col_limit_left = [&](int row, int n_block) {
  841. return std::max(
  842. 0,
  843. row + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM_div_H - mainloop_params.window_size_left
  844. );
  845. };
  846. {
  847. Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
  848. Tensor tScS = threadMma0.partition_C(cS);
  849. #pragma unroll
  850. for (int i = 0; i < size(tSrS); ++i) {
  851. if constexpr (!Is_causal && !Is_local) { // Just masking based on col
  852. if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; }
  853. } else { // mask based on both row and col
  854. int row = int(get<0>(tScS(i))) / kBlockH;
  855. if (int(get<1>(tScS(i))) >= std::min(seqlen_k - n_block * kBlockN, col_limit_right(row, n_block))) {
  856. tSrS(i) = -INFINITY;
  857. } else if constexpr(Is_local) {
  858. if (int(get<1>(tScS(i))) < col_limit_left(row, n_block)) {
  859. tSrS(i) = -INFINITY;
  860. }
  861. }
  862. }
  863. }
  864. }
  865. softmax.template online_softmax</*Is_first=*/true>(tSrS);
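// fp8 path: the fp32 accumulator of the first WGMMA does not have the register layout that the
// second (fp8) WGMMA expects for its A operand, so the converted scores are reinterpreted via
// convert_layout_acc_Aregs_fp8 and then shuffled in registers by permute_regs_A_to_C.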
  866. Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
  867. permute_regs_A_to_C(tOrP);
  868. Tensor scores_scale = make_fragment_like(softmax.row_max);
  869. clear(scores_scale);
  870. consumer_wait(pipeline_vt, smem_pipe_read);
  871. flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
  872. if constexpr(!Delay_V_release) { pipeline_vt.consumer_release(smem_pipe_read); }
  873. ++smem_pipe_read;
  874. --n_block;
  875. constexpr int extra_iterations = !Is_causal ? kStages - 1 : cute::ceil_div(kBlockM_div_H, kBlockN);
  876. if constexpr(Is_causal) {
  877. CUTLASS_PRAGMA_UNROLL
  878. for (int iter = 0; iter < extra_iterations && n_block >= n_block_min; ++iter, --n_block) {
  879. Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
  880. consumer_wait(pipeline_k, smem_pipe_read);
  881. warp_scheduler_barrier_sync();
  882. flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
  883. Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
  884. Tensor tScS = threadMma0.partition_C(cS);
  885. #pragma unroll
  886. for (int i = 0; i < size(tSrS); ++i) {
  887. int row = int(get<0>(tScS(i))) / kBlockH;
  888. if (int(get<1>(tScS(i))) >= col_limit_right(row, n_block)) {
  889. tSrS(i) = -INFINITY;
  890. }
  891. }
  892. warp_scheduler_barrier_arrive();
  893. pipeline_k.consumer_release(smem_pipe_read);
  894. if constexpr(Delay_V_release) {
  895. pipeline_vt.consumer_release(smem_pipe_release);
  896. ++smem_pipe_release;
  897. }
  898. consumer_wait(pipeline_vt, smem_pipe_read);
  899. cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/true>(tSrS), scores_scale);
  900. softmax.rescale_o(tOrO, scores_scale);
  901. softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/true>(tSrS);
  902. Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
  903. permute_regs_A_to_C(tOrP);
  904. flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
  905. if constexpr(!Delay_V_release) { pipeline_vt.consumer_release(smem_pipe_read); }
  906. ++smem_pipe_read;
  907. }
  908. } else if constexpr(!Is_local) {
  909. CUTLASS_PRAGMA_UNROLL
  910. for (int iter = 0; iter < extra_iterations && n_block >= n_block_min; ++iter, --n_block) {
  911. Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
  912. consumer_wait(pipeline_k, smem_pipe_read);
  913. if constexpr(Delay_V_release) {
  914. pipeline_vt.consumer_release(smem_pipe_release);
  915. ++smem_pipe_release;
  916. }
  917. warp_scheduler_barrier_sync();
  918. flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
  919. warp_scheduler_barrier_arrive();
  920. if constexpr(!Delay_V_release) { pipeline_k.consumer_release(smem_pipe_read); }
  921. else { consumer_wait(pipeline_vt, smem_pipe_read); }
  922. cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/Is_local>(tSrS), scores_scale);
  923. softmax.rescale_o(tOrO, scores_scale);
  924. softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/Is_local>(tSrS);
  925. Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
  926. permute_regs_A_to_C(tOrP);
  927. if constexpr (Delay_V_release) { pipeline_k.consumer_release(smem_pipe_read); }
  928. else { consumer_wait(pipeline_vt, smem_pipe_read); }
  929. flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
  930. if constexpr(!Delay_V_release) { pipeline_vt.consumer_release(smem_pipe_read); }
  931. ++smem_pipe_read;
  932. }
  933. }
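// Delay_V_release: instead of releasing the transposed-V (Vt) stage right after its GEMM, hold it
// for one extra iteration and release it at the top of the next one (tracked by smem_pipe_release),
// which changes how the producer's V transpose overlaps with the consumer GEMMs.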
  934. if constexpr(Delay_V_release) {
  935. warp_scheduler_barrier_sync();
  936. CUTLASS_PRAGMA_NO_UNROLL
  937. for (; n_block >= n_block_min; --n_block) {
  938. Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
  939. consumer_wait(pipeline_k, smem_pipe_read);
  940. flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
  941. if constexpr(Is_local) {
  942. Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
  943. Tensor tScS = threadMma0.partition_C(cS);
  944. #pragma unroll
  945. for (int i = 0; i < size(tSrS); ++i) {
  946. int row = int(get<0>(tScS(i))) / kBlockH;
  947. if (
  948. int(get<1>(tScS(i))) >= col_limit_right(row, n_block) ||
  949. int(get<1>(tScS(i))) < col_limit_left(row, n_block)
  950. ) {
  951. tSrS(i) = -INFINITY;
  952. }
  953. }
  954. }
  955. warp_scheduler_barrier_arrive();
  956. pipeline_k.consumer_release(smem_pipe_read);
  957. pipeline_vt.consumer_release(smem_pipe_release);
  958. cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/Is_local>(tSrS), scores_scale);
  959. softmax.rescale_o(tOrO, scores_scale);
  960. softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/Is_local>(tSrS);
  961. Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
  962. permute_regs_A_to_C(tOrP);
  963. consumer_wait(pipeline_vt, smem_pipe_read);
  964. flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
  965. warp_scheduler_barrier_sync();
  966. ++smem_pipe_read;
  967. ++smem_pipe_release;
  968. }
  969. warp_scheduler_barrier_arrive();
  970. pipeline_vt.consumer_release(smem_pipe_release);
  971. ++smem_pipe_release;
  972. } else {
  973. if constexpr (kHeadDim == 128) { warp_scheduler_barrier_sync(); }
  974. CUTLASS_PRAGMA_NO_UNROLL
  975. for (; n_block >= n_block_min; --n_block) {
  976. Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
  977. consumer_wait(pipeline_k, smem_pipe_read);
  978. if constexpr (kHeadDim == 256) { warp_scheduler_barrier_sync(); }
  979. flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
  980. if constexpr(Is_local) {
  981. Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
  982. Tensor tScS = threadMma0.partition_C(cS);
  983. #pragma unroll
  984. for (int i = 0; i < size(tSrS); ++i) {
  985. int row = int(get<0>(tScS(i))) / kBlockH;
  986. if (
  987. int(get<1>(tScS(i))) >= col_limit_right(row, n_block) ||
  988. int(get<1>(tScS(i))) < col_limit_left(row, n_block)
  989. ) {
  990. tSrS(i) = -INFINITY;
  991. }
  992. }
  993. }
  994. warp_scheduler_barrier_arrive();
  995. pipeline_k.consumer_release(smem_pipe_read);
  996. cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/Is_local>(tSrS), scores_scale);
  997. softmax.rescale_o(tOrO, scores_scale);
  998. softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/Is_local>(tSrS);
  999. Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
  1000. permute_regs_A_to_C(tOrP);
  1001. consumer_wait(pipeline_vt, smem_pipe_read);
  1002. if constexpr (kHeadDim == 128) { warp_scheduler_barrier_sync(); }
  1003. flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
  1004. pipeline_vt.consumer_release(smem_pipe_read);
  1005. ++smem_pipe_read;
  1006. }
  1007. if constexpr (kHeadDim == 128) { warp_scheduler_barrier_arrive(); }
  1008. }
  1009. cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarpGroup, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
  1010. cute::copy(softmax.template finalize</*Is_dropout=*/false, Is_split>(tSrS, shared_storage.descale_v), scores_scale);
  1011. softmax.rescale_o(tOrO, scores_scale);
  1012. return;
  1013. }
  1014. };
  1015. } // namespace flash