mainloop_fwd_sm80.hpp 49 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843
  1. /******************************************************************************
  2. * Copyright (c) 2024, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include <cutlass/cutlass.h>
  6. #include <cutlass/array.h>
  7. #include <cutlass/numeric_types.h>
  8. #include <cutlass/numeric_conversion.h>
  9. #include "cute/tensor.hpp"
  10. #include "seqlen.h"
  11. #include "block.h"
  12. #include "mask.h"
  13. #include "pack_gqa.h"
  14. #include "paged_kv.h"
  15. #include "rotary.h"
  16. #include "utils.h"
  17. namespace flash {
  18. using namespace cute;
  19. template <int kNWarps, int Stages, bool Q_in_regs, class TileShape_MNK_, int kHeadDimV, class Element_, class ElementAccum_, class ArchTag_,
  20. bool Is_causal_, bool Is_local_, bool Has_softcap_, bool Varlen_, bool PagedKV_, bool AppendKV_,
  21. bool PackGQA_, bool Split_>
  22. struct CollectiveMainloopFwdSm80 {
  23. static constexpr int kStages = Stages;
  24. static_assert(kStages > 0, "kStages must be greater than 0");
  25. using TileShape_MNK = TileShape_MNK_;
  26. using TileShape_MNK_PV = Shape<decltype(get<0>(TileShape_MNK{})), Int<kHeadDimV>, decltype(get<1>(TileShape_MNK{}))>;
  27. using Element = Element_;
  28. using ElementAccum = ElementAccum_;
  29. using ArchTag = ArchTag_;
  30. static constexpr bool Is_FP8 = cute::is_same_v<Element, cutlass::float_e4m3_t> || cute::is_same_v<Element, cutlass::float_e5m2_t>;;
  31. static constexpr bool Is_causal = Is_causal_;
  32. static constexpr bool Is_local = Is_local_;
  33. static constexpr bool Has_softcap = Has_softcap_;
  34. static constexpr bool Varlen = Varlen_;
  35. static constexpr bool PagedKV = PagedKV_;
  36. static constexpr bool AppendKV = AppendKV_;
  37. static constexpr bool PackGQA = PackGQA_;
  38. static constexpr bool Split = Split_;
  39. static constexpr bool Transpose_V = Is_FP8;
  40. static_assert(ArchTag::kMinComputeCapability >= 80);
  41. static constexpr bool Has_cp_async = ArchTag::kMinComputeCapability >= 80;
  42. static constexpr int kBlockM = get<0>(TileShape_MNK{});
  43. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  44. static constexpr int kHeadDim = get<2>(TileShape_MNK{});
  45. using SeqlenInfo_t = flash::SeqlenInfoQKNewK<Varlen, AppendKV>;
  46. using BlockMN_t = flash::BlockMN<SeqlenInfo_t, kBlockM, kBlockN, Is_causal, Is_local, PackGQA, Split>;
  47. using MMA_Atom_Arch = std::conditional_t<
  48. ArchTag::kMinComputeCapability >= 80,
  49. std::conditional_t<
  50. std::is_same_v<Element, cutlass::half_t>,
  51. MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
  52. MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
  53. >,
  54. MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>
  55. >;
  56. using TiledMma = TiledMMA<
  57. MMA_Atom_Arch,
  58. Layout<Shape<Int<kNWarps>,_1,_1>>, // 4x1x1 or 8x1x1 thread group
  59. Tile<Int<16 * kNWarps>, _16, _16>>;
  60. static constexpr int NumMmaThreads = size(TiledMma{});
  61. static constexpr int NumProducerThreads = NumMmaThreads; // For compatibility with TileScheduler
  // Elements moved by one 128-bit (16-byte) vectorized global load.
  static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
  static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
  // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each
  // thread to have 4 loads in the M direction and 2 vectorized loads in the K direction.
  // Concretely: kBlockKGmem is 128 bytes if that divides a head-dim row, else 64 bytes if that divides it,
  // else 32 bytes -- expressed in elements.
  static constexpr int kBytePerRow = kHeadDim * sizeof(Element);
  static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element);
  // Swizzle parameters for the smem atom. 2^kSwizzleBase elements always span 16 bytes (for 1-, 2- and
  // 4-byte Element), so swizzling permutes whole 16-byte chunks.
  // NOTE(review): kSwizzle is keyed on kBlockKGmem in *elements*; it equals log2(16-byte chunks per atom
  // row) for 2-byte Element, but not for 1- or 4-byte Element -- confirm that is intended.
  static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1));
  static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 3 : 4);
  // Swizzled 8 x kBlockKGmem row-major atom shared by the Q, K and V smem tiles.
  using SmemLayoutAtomQKV = decltype(
      composition(Swizzle<kSwizzle, kSwizzleBase, kSwizzleBase>{},
                  Layout<Shape<_8, Int<kBlockKGmem>>,
                         Stride<Int<kBlockKGmem>, _1>>{}));
  // Q tile: (kBlockM, kHeadDim).
  using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQKV{}, select<0, 2>(TileShape_MNK{})));
  // K and V tiles: (kBlockN, kHeadDim, kStages), one slice per pipeline stage.
  // NOTE(review): V uses kHeadDim (shape<2>), not kHeadDimV -- this appears to assume kHeadDimV == kHeadDim.
  using SmemLayoutK = decltype(tile_to_shape(
      SmemLayoutAtomQKV{},
      make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
  using SmemLayoutV = decltype(tile_to_shape(
      SmemLayoutAtomQKV{},
      make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
  // Transposed view of the V tile, modes (kHeadDim, kBlockN, kStages); mma builds sVt over the same
  // smem_v buffer as sV.
  using SmemLayoutVt = decltype(
      composition(SmemLayoutV{},
                  make_ordered_layout(make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{}),
                                      Step<_2, _1, _3>{})));
  // ldmatrix atoms for smem -> register copies: non-transposed for Q and K, transposed for V.
  using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, Element>;
  using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, Element>;
  // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
  // from the same address by the same threadblock. This is slightly faster.
  // With cp.async available, the zero-fill variant is used so predicated-off copies write zeros.
  using GmemCopyAtom = Copy_Atom<std::conditional_t<
      Has_cp_async,
      SM80_CP_ASYNC_CACHEGLOBAL_ZFILL<cute::uint128_t>,
      AutoVectorizingCopyWithAssumedAlignment<128>
  >, Element>;
  // Threads needed to cover one kBlockKGmem-wide chunk of a row with 16-byte loads.
  static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;
  static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow");
  // All NumMmaThreads threads arranged as (NumMmaThreads / kGmemThreadsPerRow) rows x kGmemThreadsPerRow,
  // row-major, so consecutive threads read consecutive 16-byte chunks of a row.
  using GmemLayoutAtom = Layout<Shape <Int<NumMmaThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                Stride<Int<kGmemThreadsPerRow>, _1>>;
  using GmemTiledCopyQKV = decltype(
      make_tiled_copy(GmemCopyAtom{},
                      GmemLayoutAtom{},
                      Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per read
  // So that we don't have to check if we overshot kBlockM when we load Q
  static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0);
  // For AppendKV, we want each thread to have at least 2 loads in the K direction since in the case of
  // non-interleaved rotary (combining elements at indices 0 and rotary_dim/2, 1 and rotary_dim/2+1, etc),
  // each thread will load twice from the same row. Hence the chunk width is chosen from half a row.
  static constexpr int kBytePerHalfRow = kHeadDim / 2 * sizeof(Element);
  static constexpr int kBlockKGmemAppend = (kBytePerHalfRow % 128 == 0 ? 128 : (kBytePerHalfRow % 64 == 0 ? 64 : 32)) / sizeof(Element);
  static constexpr int kGmemThreadsPerRowAppend = kBlockKGmemAppend / kGmemElemsPerLoad;
  static_assert(NumMmaThreads % kGmemThreadsPerRowAppend == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRowAppend");
  // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKV where
  // these threads share the same page table entry and share the work of computing pointers to paged K and paged V.
  static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRowAppend == 0, "kGmemThreadsPerRowAppend must divide NumThreadsPerWarp");
  using GmemLayoutAtomAppend = Layout<Shape <Int<NumMmaThreads / kGmemThreadsPerRowAppend>, Int<kGmemThreadsPerRowAppend>>,
                                      Stride<Int<kGmemThreadsPerRowAppend>, _1>>;
  // If AppendKV, we'll be loading Q for rotary, and we assume divisibility to avoid predication
  static_assert(!AppendKV || kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtomAppend{})) == 0, "kBlockM must be a multiple of NumMmaThreads / kGmemThreadsPerRowAppend");
  // Plain (synchronous) 128-bit-aligned vectorized copies for the AppendKV path.
  using GmemTiledCopyAppendKV = decltype(
      make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
                      GmemLayoutAtomAppend{},
                      Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per store
  122. using ShapeQKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen, d, head, batch)
  123. using StrideQK = cute::Stride<int64_t, _1, int64_t, int64_t>;
  124. using StrideV = StrideQK;
  125. // ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits)
  126. using ShapeQPacked = std::conditional_t<!PackGQA, ShapeQKV, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t>>;
  127. using StrideQPacked = std::conditional_t<!PackGQA, StrideQK, cute::Stride<cute::Stride<int64_t, int64_t>, _1, int64_t, int64_t>>;
  128. using ShapePageTable = cute::Shape<int32_t, int32_t>; // (batch, max_num_pages_per_seq)
  129. using StridePageTable = cute::Stride<int64_t, _1>;
  130. using ShapeRotary = cute::Shape<int32_t, int32_t>; // (seqlen_ro, rotary_dim // 2)
  131. using StrideRotary = cute::Stride<int64_t, _1>;
  132. using StrideDescale = cute::Stride<int64_t, int64_t>;
  // When Q is kept in registers (Q_in_regs), Q and V share one shared-memory buffer: mma copies sQ into
  // registers and __syncthreads() before the first V tile is loaded over it.
  static constexpr bool Share_QV_Smem = Q_in_regs;
  // Shared-memory tensors for Share_QV_Smem: smem_q aliases smem_v through the anonymous union.
  // Member order determines the smem offsets; do not reorder.
  struct TensorStorageSharedQV : cute::aligned_struct<128> {
    union {
      cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
      cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
    };
    cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
  };
  // Shared-memory tensors when Q has its own buffer (Q is re-read from smem during the mainloop).
  struct TensorStorageSeparateQV : cute::aligned_struct<128> {
    cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
    cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
    cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
  };
  using TensorStorage = std::conditional_t<Share_QV_Smem, TensorStorageSharedQV, TensorStorageSeparateQV>;
  147. // Host side kernel arguments
  148. struct Arguments {
  149. Element const* const ptr_Q;
  150. ShapeQKV const shape_Q;
  151. StrideQK const stride_Q;
  152. Element* const ptr_K; // Not Element const* since we might append to KV cache in-place
  153. ShapeQKV const shape_K;
  154. StrideQK const stride_K;
  155. Element* const ptr_V;
  156. int32_t const headdim_v;
  157. StrideV const stride_V;
  158. Element const* const ptr_K_new;
  159. ShapeQKV const shape_K_new;
  160. StrideQK const stride_K_new;
  161. Element const* const ptr_V_new;
  162. StrideV const stride_V_new;
  163. Element const* const ptr_Qv;
  164. StrideQK const stride_Qv;
  165. Element const* const ptr_rotary_cos;
  166. ShapeRotary const shape_rotary;
  167. StrideRotary const stride_rotary_cos;
  168. Element const* const ptr_rotary_sin;
  169. StrideRotary const stride_rotary_sin;
  170. bool const is_rotary_interleaved;
  171. int const* const ptr_pagetable;
  172. ShapePageTable const shape_pagetable;
  173. StridePageTable const stride_pagetable;
  174. float const softmax_scale;
  175. float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale;
  176. StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale;
  177. int const window_size_left = -1, window_size_right = -1;
  178. float const softcap_val;
  179. int const num_splits;
  180. int const* const kv_batch_idx = nullptr;
  181. int const* const cu_seqlens_q = nullptr;
  182. int const* const cu_seqlens_k = nullptr;
  183. int const* const cu_seqlens_k_new = nullptr;
  184. int const* const seqused_q = nullptr;
  185. int const* const seqused_k = nullptr;
  186. int const* const leftpad_k = nullptr;
  187. };
  188. // Device side kernel params
  189. struct Params {
  190. Element const* const ptr_Q;
  191. ShapeQKV const shape_Q;
  192. StrideQK const stride_Q;
  193. ShapeQPacked const shape_Q_packed;
  194. StrideQPacked const stride_Q_packed;
  195. Element* const ptr_K;
  196. ShapeQKV const shape_K;
  197. StrideQK const stride_K;
  198. Element* const ptr_V;
  199. int32_t const headdim_v;
  200. StrideV const stride_V;
  201. Element const* const ptr_K_new;
  202. ShapeQKV const shape_K_new;
  203. StrideQK const stride_K_new;
  204. Element const* const ptr_V_new;
  205. StrideV const stride_V_new;
  206. Element const* const ptr_rotary_cos;
  207. ShapeRotary const shape_rotary;
  208. StrideRotary const stride_rotary_cos;
  209. Element const* const ptr_rotary_sin;
  210. StrideRotary const stride_rotary_sin;
  211. bool const is_rotary_interleaved;
  212. int const* const ptr_pagetable;
  213. ShapePageTable const shape_pagetable;
  214. StridePageTable const stride_pagetable;
  215. cutlass::FastDivmod page_size_divmod;
  216. cutlass::FastDivmod qhead_per_khead_divmod;
  217. float const softmax_scale_log2;
  218. float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale;
  219. StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale;
  220. float const softcap_val;
  221. int const window_size_left, window_size_right;
  222. int const num_splits;
  223. int const* const kv_batch_idx = nullptr;
  224. int const* const cu_seqlens_q = nullptr;
  225. int const* const cu_seqlens_k = nullptr;
  226. int const* const cu_seqlens_k_new = nullptr;
  227. int const* const seqused_q = nullptr;
  228. int const* const seqused_k = nullptr;
  229. int const* const leftpad_k = nullptr;
  230. };
  231. static Params
  232. to_underlying_arguments(Arguments const& args) {
  233. // If PackGQA, reshape Q to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size)
  234. int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K));
  235. auto const shape_Q_packed = cute::conditional_return<!PackGQA>(
  236. args.shape_Q,
  237. make_shape(make_shape(qhead_per_khead, get<0>(args.shape_Q)), get<1>(args.shape_Q), get<2>(args.shape_K), get<3>(args.shape_Q))
  238. );
  239. auto const stride_Q_packed = cute::conditional_return<!PackGQA>(
  240. args.stride_Q,
  241. make_stride(make_stride(get<2>(args.stride_Q), get<0>(args.stride_Q)), get<1>(args.stride_Q), get<2>(args.stride_Q) * qhead_per_khead, get<3>(args.stride_Q))
  242. );
  243. if (get<1>(args.shape_rotary) > 0) {
  244. assert(args.ptr_rotary_cos != nullptr && args.ptr_rotary_sin != nullptr);
  245. }
  246. assert(args.num_splits >= 1);
  247. // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val.
  248. // Right after this, we multiply by log2(e) before applying exp2.
  249. // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val
  250. // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e)
  251. // (assigning it to params.softmax_scale_log2).
  252. return {args.ptr_Q, args.shape_Q, args.stride_Q, shape_Q_packed, stride_Q_packed,
  253. args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.headdim_v, args.stride_V,
  254. args.ptr_K_new, args.shape_K_new, args.stride_K_new, args.ptr_V_new, args.stride_V_new,
  255. args.ptr_rotary_cos, args.shape_rotary, args.stride_rotary_cos,
  256. args.ptr_rotary_sin, args.stride_rotary_sin, args.is_rotary_interleaved,
  257. args.ptr_pagetable, args.shape_pagetable, args.stride_pagetable,
  258. cutlass::FastDivmod(int(get<0>(args.shape_K))),
  259. cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))),
  260. !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E),
  261. args.ptr_q_descale, args.ptr_k_descale, args.ptr_v_descale,
  262. args.stride_q_descale, args.stride_k_descale, args.stride_v_descale,
  263. !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val,
  264. args.window_size_left, args.window_size_right,
  265. !Split ? 1 : args.num_splits,
  266. args.kv_batch_idx,
  267. args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new,
  268. args.seqused_q, args.seqused_k, args.leftpad_k};
  269. }
  270. template <typename SharedStorage, typename FrgTensorO, typename Softmax>
  271. CUTLASS_DEVICE bool
  272. mma(Params const& params,
  273. FrgTensorO& tOrO,
  274. Softmax& softmax,
  275. int const thread_idx,
  276. SeqlenInfo_t const& seqlen_info,
  277. cute::tuple<int32_t, int32_t, int32_t, int32_t> block_coord,
  278. SharedStorage& shared_storage
  279. ) {
  280. static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
  281. static constexpr int kBlockM = get<0>(TileShape_MNK{});
  282. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  283. // can't use auto [m_block, ...] = block_coord since structured binding cannot be captured in lambda
  284. int const m_block = get<0>(block_coord);
  285. int const bidh = get<1>(block_coord);
  286. int const bidb = get<2>(block_coord);
  287. int const split_idx = get<3>(block_coord);
  288. int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh;
  289. auto n_block_min_max = BlockMN_t::get_n_block_min_max(
  290. seqlen_info, m_block, bidb, split_idx, params.num_splits,
  291. params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod);
  292. int const n_block_min = get<0>(n_block_min_max);
  293. int const n_block_max = get<1>(n_block_min_max);
  294. // It's possible to have n_block_max <= n_block_min. We don't want to load Q or change any barrier
  295. if constexpr (Is_causal || Is_local || Varlen || Split) {
  296. if (n_block_max <= n_block_min) { return false; }
  297. }
  298. Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{});
  299. Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{});
  300. Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{});
  301. Tensor sVt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVt{});
  302. bool const is_varlen_q = Varlen && params.cu_seqlens_q;
  303. bool const is_varlen_k = Varlen && params.cu_seqlens_k;
  304. int const bidb_kv = params.kv_batch_idx == nullptr ? bidb : params.kv_batch_idx[bidb];
  305. Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q + seqlen_info.offset_q * get<0>(params.stride_Q)), params.shape_Q_packed, params.stride_Q_packed)(_, _, bidh, !is_varlen_q ? bidb : 0);
  306. Tensor gQ = local_tile(mQ, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K)
  307. Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K + seqlen_info.offset_k * get<0>(params.stride_K)), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0);
  308. Tensor gK = local_tile(mK, select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
  309. Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V + seqlen_info.offset_k * get<0>(params.stride_V)), params.shape_K, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0);
  310. Tensor gV = local_tile(mV, select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
  311. GmemTiledCopyQKV gmem_tiled_copy_QKV;
  312. auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(thread_idx);
  313. auto gmem_thr0_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(_0{}); // For index calculation
  314. Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN)
  315. Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
  316. Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN)
  317. Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
  318. TiledMma tiled_mma;
  319. auto thr_mma = tiled_mma.get_slice(thread_idx);
  320. // Allocate "fragments/descriptors"
  321. Tensor tSrQ = thr_mma.partition_fragment_A(sQ);
  322. // Copy Atom retiling
  323. auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtom{}, tiled_mma);
  324. auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(thread_idx);
  325. auto smem_tiled_copy_K = make_tiled_copy_B(SmemCopyAtom{}, tiled_mma);
  326. auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(thread_idx);
  327. auto smem_tiled_copy_V = make_tiled_copy_B(SmemCopyAtomTransposed{}, tiled_mma);
  328. auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(thread_idx);
  329. Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
  330. Tensor tSsK = smem_thr_copy_K.partition_S(sK);
  331. Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
  332. // Predicates
  333. Tensor cKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{}));
  334. Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV);
  335. Tensor t0KVcKV = gmem_thr0_copy_QKV.partition_S(cKV);
  336. Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
  337. #pragma unroll
  338. for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(_0{}, _0{}, k)) < get<1>(params.shape_K); }
  339. int const seqlen_q = seqlen_info.seqlen_q;
  340. int const seqlen_k = seqlen_info.seqlen_k;
  341. int n_block = n_block_max - 1;
  342. // Prologue: load Q, K, V
  343. // If persistent, we don't need to wait for the previous work_idx to finish
  344. // since we assume that all MMA threads sync in the epilogue before writing to smem_o.
  345. // So any thread gets there, all threads must have finished the previous MMA and at least started
  346. // writing to smem_o.
  347. // If persistent, need to sync to make sure all threads have finished with smem_o before writing to smem_v
  348. if constexpr (Share_QV_Smem) { __syncthreads(); }
  349. if constexpr (!PackGQA) {
  350. Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
  351. Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
  352. Tensor cQ = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}));
  353. Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ);
  354. Tensor t0QcQ = gmem_thr0_copy_QKV.partition_S(cQ);
  355. Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
  356. #pragma unroll
  357. for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(_0{}, _0{}, k)) < get<1>(params.shape_Q); }
  358. // Instead of passing in tQcQ, we pass in t0QcQ and subtract the offset from the limit
  359. // (seqlen_q - m_block * kBlockM). This is because the entries of t0QcQ are known at compile time.
  360. // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
  361. flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/true>(
  362. gmem_tiled_copy_QKV, tQgQ, tQsQ, t0QcQ, tQpQ, seqlen_info.seqlen_q - m_block * kBlockM - get<0>(tQcQ(_0{}, _0{}, _0{}))
  363. );
  364. } else {
  365. using PackGQAt = flash::PackGQAManager<get<0>(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumMmaThreads, Element>;
  366. PackGQAt::load_Q(mQ, sQ, params.qhead_per_khead_divmod, thread_idx, seqlen_q, m_block);
  367. }
  368. cute::cp_async_fence();
  369. using PagedKVManager_t = PagedKVManager<get<1>(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/>;
  370. PagedKVManager_t paged_kv_manager(
  371. params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable,
  372. params.ptr_K, params.shape_K, params.stride_K,
  373. params.ptr_V, params.headdim_v, params.stride_V,
  374. params.page_size_divmod, bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k
  375. );
  376. auto load_K = [&] (int const n_block, int const smem_pipe_write, auto need_seqlenk_masking_type) {
  377. static constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value;
  378. if constexpr (!PagedKV) {
  379. // Do we need bound check to make sure the row doesn't go above kBlockN
  380. static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0;
  381. Tensor tKsK_cur = tKsK(_, _, _, smem_pipe_write);
  382. // Instead of passing in tKVcKV, we pass in t0KVcKV and subtract the offset from the limit
  383. // (seqlen_k - n_block * kBlockN). This is because the entries of t0KVcKV are known at compile time.
  384. int const seqlenk_row_limit = -int(get<0>(tKVcKV(_0{}, _0{}, _0{}))) + (EvenN
  385. ? seqlen_info.seqlen_k - n_block * kBlockN
  386. : (!Seqlenk_mask ? kBlockN : std::min(seqlen_info.seqlen_k - n_block * kBlockN, kBlockN)));
  387. // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
  388. flash::copy</*Is_even_MN=*/!Seqlenk_mask && EvenN, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/true>(
  389. gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK_cur, t0KVcKV, tKVpKV, seqlenk_row_limit);
  390. } else {
  391. paged_kv_manager.template load_page_table<Seqlenk_mask>(n_block);
  392. paged_kv_manager.template load_K<Seqlenk_mask>(n_block, sK(_, _, smem_pipe_write));
  393. }
  394. };
  395. auto load_V = [&] (int const n_block, int const smem_pipe_write, auto need_seqlenk_masking_type) {
  396. static constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value;
  397. if constexpr (!PagedKV) {
  398. // Do we need bound check to make sure the row doesn't go above kBlockN
  399. static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0;
  400. Tensor tVsV_cur = tVsV(_, _, _, smem_pipe_write);
  401. // We don't call flash::copy since it doesn't support bound checking
  402. // to not overshot kBlockN when writing to smem.
  403. Tensor tVgV_cur = tVgV(_, _, _, n_block);
  404. int const seqlenk_row_limit = seqlen_info.seqlen_k - n_block * kBlockN - get<0>(tKVcKV(_0{}, _0{}, _0{}));
  405. #pragma unroll
  406. for (int m = 0; m < size<1>(tVsV); ++m) {
  407. // If kBlockN doesn't evenly divide the tiled copy, only the last `m` needs to be checked
  408. if (EvenN || m < size<1>(tVsV) - 1 || get<0>(tKVcKV(_0{}, m, _0{})) < kBlockN) {
  409. bool const predicate_n = !Seqlenk_mask || get<0>(t0KVcKV(_0{}, m, _0{})) < seqlenk_row_limit;
  410. #pragma unroll
  411. for (int k = 0; k < size<2>(tVsV); ++k) {
  412. cute::copy(gmem_tiled_copy_QKV.with(tKVpKV(k) && predicate_n), tVgV_cur(_, m, k), tVsV_cur(_, m, k));
  413. }
  414. }
  415. }
  416. } else {
  417. paged_kv_manager.template load_V<Seqlenk_mask>(n_block, sV(_, _, smem_pipe_write));
  418. }
  419. };
  420. auto preprocess_Q = [&] {
  421. if constexpr (!AppendKV) {
  422. flash::cp_async_wait<Share_QV_Smem ? 1 : kStages * 2 - 1>();
  423. } else {
  424. if (get<1>(params.shape_rotary) > 0) { // Apply rotary to Q
  425. int const offset_rotary = seqlen_info.seqlen_k_og + seqlen_info.leftpad_k;
  426. using Rotary_t = Rotary<kBlockM, kHeadDim, NumMmaThreads, Element, !(Is_causal || Is_local) /*FixedPosition*/>;
  427. Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos,
  428. params.ptr_rotary_sin, params.stride_rotary_sin,
  429. params.is_rotary_interleaved, thread_idx, seqlen_q, offset_rotary);
  430. int const qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor;
  431. if (params.is_rotary_interleaved) {
  432. auto [tRrCos, tRrSin] = cute::conditional_return<!PackGQA>(
  433. rotary.template load_cos_sin<true /*kInterleaved*/>(m_block),
  434. rotary.template load_cos_sin_packgqa<true /*kInterleaved*/>(m_block, params.qhead_per_khead_divmod)
  435. );
  436. flash::cp_async_wait<Share_QV_Smem ? 1 : kStages * 2 - 1>();
  437. __syncthreads();
  438. rotary.apply_Q_interleaved(sQ, tRrCos, tRrSin, m_block, qhead_per_khead);
  439. } else {
  440. auto [tRrCosCont, tRrSinCont] = cute::conditional_return<!PackGQA>(
  441. rotary.template load_cos_sin<false /*kInterleaved*/>(m_block),
  442. rotary.template load_cos_sin_packgqa<false /*kInterleaved*/>(m_block, params.qhead_per_khead_divmod)
  443. );
  444. flash::cp_async_wait<Share_QV_Smem ? 1 : kStages * 2 - 1>();
  445. __syncthreads();
  446. rotary.apply_Q_contiguous(sQ, tRrCosCont, tRrSinCont, m_block, qhead_per_khead);
  447. }
  448. } else {
  449. flash::cp_async_wait<Share_QV_Smem ? 1 : kStages * 2 - 1>();
  450. }
  451. }
  452. if constexpr (Q_in_regs) {
  453. __syncthreads();
  454. Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
  455. Tensor tSsQ_copy_view = smem_thr_copy_Q.partition_S(sQ);
  456. cute::copy(smem_tiled_copy_Q, tSsQ_copy_view, tSrQ_copy_view);
  457. }
  458. };
  459. // If Share_QV_Smem, we load Q, then load 1 stage of K, then (optionally) rotate Q and
  460. // read from smem_q to registers, then load V.
  461. // If !Share_QV, Smem, we load Q, load all stages of K & V, then (optionally) rotate Q.
  462. if constexpr (Share_QV_Smem) {
  463. load_K(n_block, 0, cute::true_type{} /*Seqlenk_mask*/);
  464. cute::cp_async_fence();
  465. preprocess_Q();
  466. __syncthreads(); // Make sure all threads have read smem_q before loading V
  467. }
  468. // For persistent, make sure all threads have finished reading smem_o
  469. if constexpr (!Share_QV_Smem) { __syncthreads(); }
  470. // Note, using the for_each() function here to ensure `stage` is of type Int<x>.
  471. for_each(make_int_sequence<kStages>{}, [&] (auto stage) {
  472. static constexpr bool Is_first_stage = CUTE_STATIC_V(stage) == 0;
  473. static constexpr bool Is_last_stage = CUTE_STATIC_V(stage) == kStages - 1;
  474. if constexpr (!Share_QV_Smem || !Is_first_stage) {
  475. if (Is_first_stage || n_block - stage >= n_block_min) {
  476. load_K(n_block - stage, stage, cute::bool_constant<Is_first_stage>{} /*Seqlenk_mask*/);
  477. }
  478. // We want the fence outside the if statement to have a fixed number of cp.async commits.
  479. // so that we can wait with the correct number of outstanding commits.
  480. cute::cp_async_fence();
  481. }
  482. if constexpr (!Is_last_stage) {
  483. if (Is_first_stage || n_block - stage >= n_block_min) {
  484. load_V(n_block - stage, stage, cute::bool_constant<Is_first_stage>{} /*Seqlenk_mask*/);
  485. }
  486. cute::cp_async_fence();
  487. }
  488. });
  489. if constexpr (!Share_QV_Smem) { preprocess_Q(); }
  490. flash::Mask<kBlockM, kBlockN, PackGQA, TiledMma> mask(
  491. thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/,
  492. params.qhead_per_khead_divmod
  493. );
  494. float softcap_val = params.softcap_val;
  495. if constexpr (Has_softcap && Is_FP8) {
  496. float const q_descale = params.ptr_q_descale == nullptr ? 1.0f : params.ptr_q_descale[bidb * get<0>(params.stride_q_descale) + bidh_kv * get<1>(params.stride_q_descale)];
  497. float const k_descale = params.ptr_k_descale == nullptr ? 1.0f : params.ptr_k_descale[bidb * get<0>(params.stride_k_descale) + bidh_kv * get<1>(params.stride_k_descale)];
  498. softcap_val *= q_descale * k_descale;
  499. }
  500. // Softcapping needs to happen before masking since if we apply after masking, softcapping can turn
  501. // -inf to e.g. -50.0, which can affect the attention softmax.
  502. auto scoremod_premask_fn = [&](auto& tSrS) {
  503. if constexpr (Has_softcap) { flash::apply_softcap(tSrS, softcap_val); }
  504. };
  505. int smem_pipe_read = 0, smem_pipe_write = kStages - 1;
  506. auto load_K_next = [&] {
  507. if (n_block - kStages >= n_block_min) {
  508. load_K(n_block - kStages, kStages > 1 ? smem_pipe_write : 0, cute::false_type{} /*Seqlenk_mask*/);
  509. }
  510. cute::cp_async_fence();
  511. };
  512. auto sync = [&] {
  513. flash::cp_async_wait<kStages * 2 - 2>();
  514. __syncthreads();
  515. };
  516. clear(tOrO);
  517. auto fwd_step = [&](int const n_block, auto mask_fn, auto is_first_iter_type, auto check_inf_type) {
  518. static constexpr bool Is_first_iter = decltype(is_first_iter_type)::value;
  519. static constexpr bool Check_inf = decltype(check_inf_type)::value;
  520. Tensor tSrS = partition_fragment_C(tiled_mma, select<0, 1>(TileShape_MNK{}));
  521. clear(tSrS);
  522. sync();
  523. auto load_V_next = [&] {
  524. if (n_block - kStages + 1 >= n_block_min) {
  525. load_V(n_block - kStages + 1, kStages > 1 ? smem_pipe_write : 0, cute::bool_constant<Is_first_iter && kStages == 1>{} /*Seqlenk_mask*/);
  526. }
  527. cute::cp_async_fence();
  528. };
  529. Tensor tSrQ_cur = cute::conditional_return<Q_in_regs>(tSrQ, thr_mma.partition_fragment_A(sQ));
  530. Tensor tSrK = thr_mma.partition_fragment_B(sK(_, _, _0{}));
  531. flash::gemm_sm80<Q_in_regs>(
  532. tSrS, tSrQ_cur, tSrK, tSsQ, tSsK(_, _, _, kStages > 1 ? smem_pipe_read : 0),
  533. tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K, load_V_next
  534. );
  535. smem_pipe_write = smem_pipe_write < kStages - 1 ? smem_pipe_write + 1 : 0;
  536. scoremod_premask_fn(tSrS);
  537. // Faster to load_K before gemm if we only have 1 stage
  538. if constexpr (kStages == 1) { sync(); load_K_next(); }
  539. mask_fn(tSrS, n_block);
  540. Tensor scores_scale = softmax.template max_get_scale</*Is_first=*/Is_first_iter, Check_inf>(tSrS);
  541. softmax.template online_softmax</*Is_first=*/Is_first_iter, Check_inf>(tSrS);
  542. if constexpr (Is_FP8) { flash::permute_Cregs_fp8(tSrS); }
  543. Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs<TiledMma>(tSrS.layout()));
  544. Tensor tOrP = make_tensor_like<Element>(tOrP_acc);
  545. convert_type_out(tOrP_acc, tOrP);
  546. if constexpr (!Is_first_iter) { softmax.rescale_o(tOrO, scores_scale); }
  547. if constexpr (kStages > 1) { sync(); }
  548. Tensor tOrV = thr_mma.partition_fragment_B(sVt(_, _, _0{}));
  549. flash::gemm_rs_sm80(tOrO, tOrP, tOrV, tOsVt(_, _, _, kStages > 1 ? smem_pipe_read : 0), tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
  550. if constexpr (kStages > 1) { load_K_next(); }
  551. smem_pipe_read = smem_pipe_read < kStages - 1 ? smem_pipe_read + 1 : 0;
  552. };
  553. auto first_iter_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply<true /*Seqlenk_mask*/, Is_causal, Is_local>(tSrS, m_block, n_block); };
  554. fwd_step(n_block, first_iter_mask_fn, cute::true_type{} /*is_first_iter*/, cute::true_type{} /*check_inf*/);
  555. --n_block;
  556. if constexpr (Is_causal || Is_local) { // Separate iterations with causal or local masking
  557. auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply<false /*Seqlenk_mask*/, Is_causal, Is_local>(tSrS, m_block, n_block); };
  558. int const m_idx_min = !PackGQA ? m_block * kBlockM : params.qhead_per_khead_divmod.divide(m_block * kBlockM);
  559. int const n_block_min_causal_local_mask =
  560. std::max(n_block_min, (m_idx_min + seqlen_k - seqlen_q + params.window_size_right) / kBlockN);
  561. #pragma unroll 1
  562. for (; n_block >= n_block_min_causal_local_mask; --n_block) {
  563. fwd_step(n_block, mask_fn, cute::false_type{} /*is_first_iter*/, cute::true_type{} /*check_inf*/);
  564. }
  565. }
  566. int const m_idx_max = !PackGQA ? (m_block + 1) * kBlockM : params.qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1;
  567. int const n_block_min_before_local_mask = !Is_local
  568. ? n_block_min
  569. : std::max(n_block_min,
  570. cute::ceil_div(m_idx_max + seqlen_k - seqlen_q - params.window_size_left, kBlockN));
  571. auto no_mask_fn = [](auto& tSrS, int n_block) { };
  572. #pragma unroll 1
  573. for (; n_block >= n_block_min_before_local_mask; --n_block) {
  574. fwd_step(n_block, no_mask_fn, cute::false_type{} /*is_first_iter*/, cute::false_type{} /*check_inf*/);
  575. }
  576. // Separate masking iterations on the left for local attention
  577. if constexpr (Is_local) {
  578. auto local_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply<false /*Seqlenk_mask*/, false /*Causal_mask*/, Is_local>(tSrS, m_block, n_block); };
  579. #pragma unroll 1
  580. for (; n_block >= n_block_min; --n_block) {
  581. fwd_step(n_block, local_mask_fn, cute::false_type{} /*is_first_iter*/, cute::bool_constant<Is_local>{} /*check_inf*/);
  582. }
  583. }
  584. float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)];
  585. Tensor scores_scale = softmax.finalize(v_descale);
  586. softmax.rescale_o(tOrO, scores_scale);
  587. if constexpr (Is_FP8) { flash::permute_output_fp8(tOrO); }
  588. return true;
  589. }
  590. template <typename SharedStorage>
  591. CUTLASS_DEVICE bool
  592. store_kv_new(Params const& params,
  593. int const thread_idx,
  594. SharedStorage &shared_storage,
  595. SeqlenInfo_t const& seqlen_info,
  596. cute::tuple<int32_t, int32_t, int32_t, int32_t> block_coord
  597. ) {
  598. auto [m_block, bidh, bidb, split_idx] = block_coord;
  599. auto n_block_new_min_max = BlockMN_t::get_n_block_k_new_min_max(
  600. seqlen_info, m_block, bidb, split_idx, params.num_splits,
  601. params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod);
  602. int const n_block_new_min = get<0>(n_block_new_min_max);
  603. int const n_block_new_max = get<1>(n_block_new_min_max);
  604. if (n_block_new_max <= n_block_new_min) { return false; }
  605. Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{});
  606. Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{});
  607. int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh;
  608. int const bidb_kv = params.kv_batch_idx == nullptr ? bidb : params.kv_batch_idx[bidb];
  609. bool const is_varlen_k_new = Varlen && params.cu_seqlens_k_new;
  610. Tensor mKnew = make_tensor(make_gmem_ptr(params.ptr_K_new), params.shape_K_new, params.stride_K_new)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0);
  611. Tensor mVnew = make_tensor(make_gmem_ptr(params.ptr_V_new), params.shape_K_new, params.stride_V_new)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0);
  612. bool const is_varlen_k = Varlen && params.cu_seqlens_k;
  613. Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0);
  614. Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), params.shape_K, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0);
  615. Tensor gKnew = local_tile(domain_offset(make_coord(seqlen_info.offset_k_new, _0{}), mKnew), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
  616. Tensor gVnew = local_tile(domain_offset(make_coord(seqlen_info.offset_k_new, _0{}), mVnew), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
  617. int const offset_k = seqlen_info.offset_k + seqlen_info.seqlen_k_og;
  618. Tensor gK = local_tile(domain_offset(make_coord(offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
  619. Tensor gV = local_tile(domain_offset(make_coord(offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
  620. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  621. static constexpr int kHeadDim = get<2>(TileShape_MNK{});
  622. int const offset_rotary = seqlen_info.seqlen_k_og + seqlen_info.leftpad_k;
  623. int const seqlen_k_new = seqlen_info.seqlen_k_new;
  624. using Rotary_t = Rotary<kBlockN, kHeadDim, NumMmaThreads, Element>;
  625. Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos,
  626. params.ptr_rotary_sin, params.stride_rotary_sin,
  627. params.is_rotary_interleaved, thread_idx, seqlen_k_new, offset_rotary);
  628. using PagedKVManager_t = PagedKVManager<get<1>(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>;
  629. PagedKVManager_t paged_kv_manager(
  630. params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable,
  631. params.ptr_K, params.shape_K, params.stride_K,
  632. params.ptr_V, params.headdim_v, params.stride_V,
  633. params.page_size_divmod, bidb_kv, bidh_kv, thread_idx, seqlen_k_new, offset_k
  634. // passing offset_k instead of leftpad_k will move the PageTable pointer to the right position
  635. );
  636. static_assert(std::is_same_v<GmemLayoutAtomAppend, typename Rotary_t::LayoutAtom>);
  637. static_assert(!PagedKV || std::is_same_v<GmemLayoutAtomAppend, typename PagedKVManager_t::GmemLayoutAtomKVCpAsync>);
  638. GmemTiledCopyQKV gmem_tiled_copy_kv_g2s;
  639. auto gmem_thr_copy_kv_g2s = gmem_tiled_copy_kv_g2s.get_thread_slice(thread_idx);
  640. auto gmem_thr0_copy_kv_g2s = gmem_tiled_copy_kv_g2s.get_thread_slice(_0{}); // Only for index calculation
  641. GmemTiledCopyAppendKV gmem_tiled_copy_kv_s2g;
  642. auto gmem_thr_copy_kv_s2g = gmem_tiled_copy_kv_s2g.get_thread_slice(thread_idx);
  643. auto gmem_thr0_copy_kv_s2g = gmem_tiled_copy_kv_s2g.get_thread_slice(_0{}); // Only for index calculation
  644. Tensor tKgKnew = gmem_thr_copy_kv_g2s.partition_S(gKnew);
  645. Tensor tKsKg2s = gmem_thr_copy_kv_g2s.partition_S(sK);
  646. Tensor tKsKs2g = gmem_thr_copy_kv_s2g.partition_S(sK); // ((Atom,AtomNum),ATOM_M,ATOM_N)
  647. Tensor tKgK = gmem_thr_copy_kv_s2g.partition_D(gK);
  648. Tensor tVgVnew = gmem_thr_copy_kv_g2s.partition_S(gVnew); // ((Atom,AtomNum),ATOM_M,ATOM_N)
  649. Tensor tVsVg2s = gmem_thr_copy_kv_g2s.partition_S(sV); // ((Atom,AtomNum),ATOM_M,ATOM_N)
  650. Tensor tVsVs2g = gmem_thr_copy_kv_s2g.partition_S(sV); // ((Atom,AtomNum),ATOM_M,ATOM_N)
  651. Tensor tVgV = gmem_thr_copy_kv_s2g.partition_D(gV);
  652. Tensor cK = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_N,BLK_K) -> (blk_n,blk_k)
  653. Tensor tKcKg2s = gmem_thr_copy_kv_g2s.partition_D(cK);
  654. Tensor t0KcKg2s = gmem_thr0_copy_kv_g2s.partition_D(cK);
  655. Tensor tKpKg2s = make_tensor<bool>(make_shape(size<2>(tKsKg2s)));
  656. Tensor tKcKs2g = gmem_thr_copy_kv_s2g.partition_D(cK);
  657. Tensor t0KcKs2g = gmem_thr0_copy_kv_s2g.partition_D(cK);
  658. Tensor tKpKs2g = make_tensor<bool>(make_shape(size<2>(tKsKs2g)));
  659. #pragma unroll
  660. for (int k = 0; k < size(tKpKg2s); ++k) { tKpKg2s(k) = get<1>(tKcKg2s(_0{}, _0{}, k)) < get<1>(params.shape_K); }
  661. #pragma unroll
  662. for (int k = 0; k < size(tKpKs2g); ++k) { tKpKs2g(k) = get<1>(tKcKs2g(_0{}, _0{}, k)) < get<1>(params.shape_K); }
  663. auto load_K_new = [&] (int const n_block, int const smem_pipe_write, auto need_seqlenk_masking_type) {
  664. static constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value;
  665. static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0;
  666. Tensor tKsK_cur = tKsKg2s(_, _, _, smem_pipe_write);
  667. int const seqlenk_row_limit = -int(get<0>(tKcKg2s(_0{}, _0{}, _0{}))) + (EvenN
  668. ? seqlen_k_new - n_block * kBlockN
  669. : (!Seqlenk_mask ? kBlockN : std::min(seqlen_k_new - n_block * kBlockN, kBlockN)));
  670. // We don't need to clear the sK smem tiles since we won't write them out
  671. flash::copy</*Is_even_MN=*/!Seqlenk_mask && EvenN, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/true>(
  672. gmem_tiled_copy_kv_g2s, tKgKnew(_, _, _, n_block), tKsK_cur, t0KcKg2s, tKpKg2s, seqlenk_row_limit);
  673. };
  674. auto load_V_new = [&] (int const n_block, int const smem_pipe_write, auto need_seqlenk_masking_type) {
  675. static constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value;
  676. static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0;
  677. Tensor tVsV_cur = tVsVg2s(_, _, _, smem_pipe_write);
  678. int const seqlenk_row_limit = -int(get<0>(tKcKg2s(_0{}, _0{}, _0{}))) + (EvenN
  679. ? seqlen_k_new - n_block * kBlockN
  680. : (!Seqlenk_mask ? kBlockN : std::min(seqlen_k_new - n_block * kBlockN, kBlockN)));
  681. // We don't need to clear the sV smem tiles since we won't write them out
  682. flash::copy</*Is_even_MN=*/!Seqlenk_mask && EvenN, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/true>(
  683. gmem_tiled_copy_kv_g2s, tVgVnew(_, _, _, n_block), tVsV_cur, t0KcKg2s, tKpKg2s, seqlenk_row_limit);
  684. };
  685. auto store_K = [&] (int const n_block, int const smem_pipe_read) {
  686. int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN);
  687. if (get<1>(params.shape_rotary) <= 0) {
  688. Tensor tKsK_cur = tKsKs2g(_, _, _, smem_pipe_read);
  689. if constexpr (!PagedKV) {
  690. Tensor tKgK_cur = tKgK(_, _, _, n_block);
  691. // Clear_OOB_K must be false since we don't want to write zeros to gmem
  692. flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
  693. gmem_tiled_copy_kv_s2g, tKsK_cur, tKgK_cur, tKcKs2g, tKpKs2g, std::min(seqlen_k_new - n_block * kBlockN, kBlockN)
  694. );
  695. } else {
  696. paged_kv_manager.store_K(n_block, tKsK_cur);
  697. }
  698. } else {
  699. Tensor gK_cur = gK(_, _, n_block);
  700. auto tPrKPtr = cute::conditional_return<PagedKV>(paged_kv_manager.compute_K_ptr(), nullptr);
  701. if (params.is_rotary_interleaved) {
  702. auto [tRrCos, tRrSin] = rotary.template load_cos_sin<true /*kInterleaved*/>(n_block);
  703. rotary.template apply_K_interleaved<PagedKV>(sK(_, _, smem_pipe_read), gK_cur, tKpKs2g, tRrCos, tRrSin, tPrKPtr, n_block);
  704. } else {
  705. auto [tRrCosCont, tRrSinCont] = rotary.template load_cos_sin<false /*kInterleaved*/>(n_block);
  706. rotary.template apply_K_contiguous<PagedKV>(sK(_, _, smem_pipe_read), gK_cur, tKpKs2g, tRrCosCont, tRrSinCont, tPrKPtr, n_block, get<1>(params.shape_K));
  707. }
  708. }
  709. };
  710. auto store_V = [&] (int const n_block, int const smem_pipe_read) {
  711. int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN);
  712. Tensor tVsV_cur = tVsVs2g(_, _, _, smem_pipe_read);
  713. if constexpr (!PagedKV) {
  714. Tensor tVgV_cur = tVgV(_, _, _, n_block);
  715. // Clear_OOB_K must be false since we don't want to write zeros to gmem
  716. flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
  717. gmem_tiled_copy_kv_s2g, tVsV_cur, tVgV_cur, tKcKs2g, tKpKs2g, n_limit);
  718. } else {
  719. paged_kv_manager.store_V(n_block, tVsV_cur);
  720. }
  721. };
  722. int n_block = n_block_new_max - 1;
  723. // Note, using the for_each() function here to ensure `stage` is of type Int<x>.
  724. for_each(make_int_sequence<kStages>{}, [&] (auto stage) {
  725. static constexpr bool Is_first_stage = CUTE_STATIC_V(stage) == 0;
  726. static constexpr bool Is_last_stage = CUTE_STATIC_V(stage) == kStages - 1;
  727. if (Is_first_stage || n_block - stage >= n_block_new_min) {
  728. load_K_new(n_block - stage, stage, cute::bool_constant<Is_first_stage>{} /*Seqlenk_mask*/);
  729. }
  730. cute::cp_async_fence();
  731. // If persistent, need to sync to make sure all threads have finished with smem_o before writing to smem_v
  732. if constexpr (Is_first_stage) { __syncthreads(); }
  733. if constexpr (!Is_last_stage) {
  734. if (Is_first_stage || n_block - stage >= n_block_new_min) {
  735. load_V_new(n_block - stage, stage, cute::bool_constant<Is_first_stage>{} /*Seqlenk_mask*/);
  736. }
  737. cute::cp_async_fence();
  738. }
  739. });
  740. int smem_pipe_read = 0, smem_pipe_write = kStages - 1;
  741. #pragma unroll 1
  742. for (; n_block >= n_block_new_min; --n_block) {
  743. if constexpr (PagedKV) { paged_kv_manager.template load_page_table<true /*Seqlenk_mask*/>(n_block); }
  744. flash::cp_async_wait<kStages * 2 - 2>();
  745. __syncthreads();
  746. store_K(n_block, kStages > 1 ? smem_pipe_read : 0);
  747. if (n_block - kStages + 1 >= n_block_new_min) {
  748. load_V_new(n_block - kStages + 1, kStages > 1 ? smem_pipe_write : 0, cute::bool_constant<kStages == 1>{} /*Seqlenk_mask*/);
  749. }
  750. cute::cp_async_fence();
  751. smem_pipe_write = smem_pipe_write < kStages - 1 ? smem_pipe_write + 1 : 0;
  752. flash::cp_async_wait<kStages * 2 - 2>();
  753. __syncthreads();
  754. store_V(n_block, kStages > 1 ? smem_pipe_read : 0);
  755. smem_pipe_read = smem_pipe_read < kStages - 1 ? smem_pipe_read + 1 : 0;
  756. if (n_block - kStages >= n_block_new_min) {
  757. load_K_new(n_block - kStages, kStages > 1 ? smem_pipe_write : 0, cute::false_type{} /*Seqlenk_mask*/);
  758. }
  759. cute::cp_async_fence();
  760. }
  761. return true;
  762. }
  763. };
  764. } // namespace flash