// mainloop_fwd_sm80.hpp

  1. /******************************************************************************
  2. * Copyright (c) 2024, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include <cutlass/cutlass.h>
  6. #include <cutlass/array.h>
  7. #include <cutlass/numeric_types.h>
  8. #include <cutlass/numeric_conversion.h>
  9. #include "cute/tensor.hpp"
  10. #include "seqlen.h"
  11. #include "mask.h"
  12. #include "pack_gqa.h"
  13. #include "paged_kv.h"
  14. #include "rotary.h"
  15. #include "utils.h"
  16. namespace flash {
  17. using namespace cute;
  18. template <int kNWarps, int Stages, bool Q_in_regs, class TileShape_MNK_, class Element_, class ElementAccum_, class ArchTag_,
  19. bool Is_causal_, bool Is_local_, bool Has_softcap_, bool Varlen_, bool PagedKV_, bool AppendKV_,
  20. bool PackGQA_, bool Split_>
  21. struct CollectiveMainloopFwdSm80 {
  22. static constexpr int kStages = Stages;
  23. static_assert(kStages > 0, "kStages must be greater than 0");
  24. using TileShape_MNK = TileShape_MNK_;
  25. using Element = Element_;
  26. using ElementAccum = ElementAccum_;
  27. using ArchTag = ArchTag_;
static constexpr bool Is_FP8 = cute::is_same_v<Element, cutlass::float_e4m3_t> || cute::is_same_v<Element, cutlass::float_e5m2_t>;
  29. static constexpr bool Is_causal = Is_causal_;
  30. static constexpr bool Is_local = Is_local_;
  31. static constexpr bool Has_softcap = Has_softcap_;
  32. static constexpr bool Varlen = Varlen_;
  33. static constexpr bool PagedKV = PagedKV_;
  34. static constexpr bool AppendKV = AppendKV_;
  35. static constexpr bool PackGQA = PackGQA_;
  36. static constexpr bool Split = Split_;
  37. static constexpr bool Transpose_V = Is_FP8;
  38. using SeqlenInfo_t = flash::SeqlenInfoQKNewK<Varlen, AppendKV>;
  39. static_assert(ArchTag::kMinComputeCapability >= 80);
  40. static constexpr bool Has_cp_async = ArchTag::kMinComputeCapability >= 80;
  41. static constexpr int kBlockM = get<0>(TileShape_MNK{});
  42. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  43. static constexpr int kHeadDim = get<2>(TileShape_MNK{});
  44. using MMA_Atom_Arch = std::conditional_t<
  45. ArchTag::kMinComputeCapability >= 80,
  46. std::conditional_t<
  47. std::is_same_v<Element, cutlass::half_t>,
  48. MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
  49. MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
  50. >,
  51. MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>
  52. >;
  53. using TiledMma = TiledMMA<
  54. MMA_Atom_Arch,
  55. Layout<Shape<Int<kNWarps>,_1,_1>>, // 4x1x1 or 8x1x1 thread group
  56. Tile<Int<16 * kNWarps>, _16, _16>>;
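// Example (assuming kNWarps = 4): the tiled MMA stacks 4 warps along M, each issuing 16x8x16 HMMA
// instructions, giving a 64x16x16 tile per MMA iteration and NumMmaThreads = 128.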
  57. static constexpr int NumMmaThreads = size(TiledMma{});
  58. static constexpr int NumProducerThreads = NumMmaThreads; // For compatibility with TileScheduler
  59. static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
  60. static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
// We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each
// thread to have 4 loads in the M direction and 2 vectorized loads in the K direction.
  63. static constexpr int kBytePerRow = kHeadDim * sizeof(Element);
  64. static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element);
  65. static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1));
  66. static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 3 : 4);
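// Worked example (hypothetical config: Element = cutlass::half_t, kHeadDim = 128): kBytePerRow = 256, so
// kBlockKGmem = 128 / 2 = 64 elements, i.e. one 128-byte cache line per "row". That gives kSwizzle = 3 and
// kSwizzleBase = 3, so the smem atom below is an 8 x 64 tile with Swizzle<3, 3, 3> applied to avoid bank
// conflicts on the 16-byte LDSM accesses.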
  67. using SmemLayoutAtomQKV = decltype(
  68. composition(Swizzle<kSwizzle, kSwizzleBase, kSwizzleBase>{},
  69. Layout<Shape<_8, Int<kBlockKGmem>>,
  70. Stride<Int<kBlockKGmem>, _1>>{}));
  71. using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQKV{}, select<0, 2>(TileShape_MNK{})));
  72. using SmemLayoutK = decltype(tile_to_shape(
  73. SmemLayoutAtomQKV{},
  74. make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
  75. using SmemLayoutV = decltype(tile_to_shape(
  76. SmemLayoutAtomQKV{},
  77. make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
  78. using SmemLayoutVt = decltype(
  79. composition(SmemLayoutV{},
  80. make_ordered_layout(make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{}),
  81. Step<_2, _1, _3>{})));
  82. using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, Element>;
  83. using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, Element>;
// We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since the same threadblock won't be
// reading from the same address more than once. This is slightly faster.
  86. using GmemCopyAtom = Copy_Atom<std::conditional_t<
  87. Has_cp_async,
  88. SM80_CP_ASYNC_CACHEGLOBAL_ZFILL<cute::uint128_t>,
  89. AutoVectorizingCopyWithAssumedAlignment<128>
  90. >, Element>;
  91. static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;
  92. static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow");
  93. using GmemLayoutAtom = Layout<Shape <Int<NumMmaThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
  94. Stride<Int<kGmemThreadsPerRow>, _1>>;
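// Example (assuming kNWarps = 4, half_t, kHeadDim = 128, so NumMmaThreads = 128 and kGmemElemsPerLoad = 8):
// kGmemThreadsPerRow = 64 / 8 = 8 and GmemLayoutAtom is a 16 x 8 arrangement of threads. Each thread issues
// one 128-bit load per value tile, so the 8 threads of a row cover one 64-element kBlockKGmem chunk and the
// tiled copy below advances 16 rows of the (kBlockM or kBlockN) x kHeadDim tile at a time.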
  95. using GmemTiledCopyQKV = decltype(
  96. make_tiled_copy(GmemCopyAtom{},
  97. GmemLayoutAtom{},
  98. Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per read
  99. // So that we don't have to check if we overshot kBlockM when we load Q
  100. static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0);
// For AppendKV, we want each thread to have at least 2 loads in the K direction since in the case of
// non-interleaved rotary (combining elements at indices 0 and rotary_dim/2, 1 and rotary_dim/2+1, etc.),
// each thread will load twice from the same row.
  104. static constexpr int kBytePerHalfRow = kHeadDim / 2 * sizeof(Element);
  105. static constexpr int kBlockKGmemAppend = (kBytePerHalfRow % 128 == 0 ? 128 : (kBytePerHalfRow % 64 == 0 ? 64 : 32)) / sizeof(Element);
  106. static constexpr int kGmemThreadsPerRowAppend = kBlockKGmemAppend / kGmemElemsPerLoad;
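// Example (hypothetical: half_t, kHeadDim = rotary_dim = 128): kBytePerHalfRow = 64 * 2 = 128 bytes, so
// kBlockKGmemAppend = 64 elements and kGmemThreadsPerRowAppend = 8. Each 128-element row then takes two
// K-iterations per thread, so the thread that loads elements [0, 8) also loads elements [64, 72), which are
// exactly the partners it needs for non-interleaved rotary (index i pairs with i + rotary_dim / 2).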
  107. static_assert(NumMmaThreads % kGmemThreadsPerRowAppend == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRowAppend");
  108. // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKV where
  109. // these threads share the same page table entry and share the work of computing pointers to paged K and paged V.
  110. static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRowAppend == 0, "kGmemThreadsPerRowAppend must divide NumThreadsPerWarp");
  111. using GmemLayoutAtomAppend = Layout<Shape <Int<NumMmaThreads / kGmemThreadsPerRowAppend>, Int<kGmemThreadsPerRowAppend>>,
  112. Stride<Int<kGmemThreadsPerRowAppend>, _1>>;
  113. // If AppendKV, we'll be loading Q for rotary, and we assume divisibility to avoid predication
  114. static_assert(!AppendKV || kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtomAppend{})) == 0, "kBlockM must be a multiple of NumMmaThreads / kGmemThreadsPerRowAppend");
  115. using GmemTiledCopyAppendKV = decltype(
  116. make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
  117. GmemLayoutAtomAppend{},
  118. Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per store
  119. using ShapeQKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen, d, head, batch)
  120. using StrideQK = cute::Stride<int64_t, _1, int64_t, int64_t>;
  121. using StrideV = StrideQK;
// ((qhead_per_khead, seqlen_q), d, nheads_kv, batch)
  123. using ShapeQPacked = std::conditional_t<!PackGQA, ShapeQKV, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t>>;
  124. using StrideQPacked = std::conditional_t<!PackGQA, StrideQK, cute::Stride<cute::Stride<int64_t, int64_t>, _1, int64_t, int64_t>>;
  125. using ShapePageTable = cute::Shape<int32_t, int32_t>; // (batch, max_num_pages_per_seq)
  126. using StridePageTable = cute::Stride<int64_t, _1>;
  127. using ShapeRotary = cute::Shape<int32_t, int32_t>; // (seqlen_ro, rotary_dim // 2)
  128. using StrideRotary = cute::Stride<int64_t, _1>;
  129. using StrideDescale = cute::Stride<int64_t, int64_t>;
  130. static constexpr bool Share_QV_Smem = Q_in_regs;
  131. struct TensorStorageSharedQV : cute::aligned_struct<128> {
  132. union {
  133. cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
  134. cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
  135. };
  136. cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
  137. };
  138. struct TensorStorageSeparateQV : cute::aligned_struct<128> {
  139. cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
  140. cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
  141. cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
  142. };
  143. using TensorStorage = std::conditional_t<Share_QV_Smem, TensorStorageSharedQV, TensorStorageSeparateQV>;
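// Shared-memory sizing sketch (hypothetical tile: kBlockM = kBlockN = kHeadDim = 128, kStages = 1, half_t):
// smem_q, smem_k and smem_v are 32 KB each, so the separate layout needs 96 KB while the shared-QV union
// needs 64 KB. Sharing is only enabled when Q_in_regs, since Q must already have been copied from smem_q
// into registers before the V tile is streamed into the same buffer.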
  144. // Host side kernel arguments
  145. struct Arguments {
  146. Element const* const ptr_Q;
  147. ShapeQKV const shape_Q;
  148. StrideQK const stride_Q;
  149. Element* const ptr_K; // Not Element const* since we might append to KV cache in-place
  150. ShapeQKV const shape_K;
  151. StrideQK const stride_K;
  152. Element* const ptr_V;
  153. StrideV const stride_V;
  154. Element const* const ptr_K_new;
  155. ShapeQKV const shape_K_new;
  156. StrideQK const stride_K_new;
  157. Element const* const ptr_V_new;
  158. StrideV const stride_V_new;
  159. Element const* const ptr_rotary_cos;
  160. ShapeRotary const shape_rotary;
  161. StrideRotary const stride_rotary_cos;
  162. Element const* const ptr_rotary_sin;
  163. StrideRotary const stride_rotary_sin;
  164. bool const is_rotary_interleaved;
  165. int const* const ptr_pagetable;
  166. ShapePageTable const shape_pagetable;
  167. StridePageTable const stride_pagetable;
  168. float const softmax_scale;
  169. float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale;
  170. StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale;
  171. int const window_size_left = -1, window_size_right = -1, sink_token_length = 0;
  172. float const softcap_val;
  173. int const num_splits;
  174. int const* const kv_batch_idx = nullptr;
  175. int const* const cu_seqlens_q = nullptr;
  176. int const* const cu_seqlens_k = nullptr;
  177. int const* const cu_seqlens_k_new = nullptr;
  178. int const* const seqused_q = nullptr;
  179. int const* const seqused_k = nullptr;
  180. int const* const leftpad_k = nullptr;
  181. };
  182. // Device side kernel params
  183. struct Params {
  184. Element const* const ptr_Q;
  185. ShapeQKV const shape_Q;
  186. StrideQK const stride_Q;
  187. ShapeQPacked const shape_Q_packed;
  188. StrideQPacked const stride_Q_packed;
  189. Element* const ptr_K;
  190. ShapeQKV const shape_K;
  191. StrideQK const stride_K;
  192. Element* const ptr_V;
  193. StrideV const stride_V;
  194. Element const* const ptr_K_new;
  195. ShapeQKV const shape_K_new;
  196. StrideQK const stride_K_new;
  197. Element const* const ptr_V_new;
  198. StrideV const stride_V_new;
  199. Element const* const ptr_rotary_cos;
  200. ShapeRotary const shape_rotary;
  201. StrideRotary const stride_rotary_cos;
  202. Element const* const ptr_rotary_sin;
  203. StrideRotary const stride_rotary_sin;
  204. bool const is_rotary_interleaved;
  205. int const* const ptr_pagetable;
  206. ShapePageTable const shape_pagetable;
  207. StridePageTable const stride_pagetable;
  208. cutlass::FastDivmod page_size_divmod;
  209. cutlass::FastDivmod qhead_per_khead_divmod;
  210. float const softmax_scale_log2;
  211. float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale;
  212. StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale;
  213. float const softcap_val;
  214. int const window_size_left, window_size_right, sink_token_length;
  215. int const num_splits;
  216. int const* const kv_batch_idx = nullptr;
  217. int const* const cu_seqlens_q = nullptr;
  218. int const* const cu_seqlens_k = nullptr;
  219. int const* const cu_seqlens_k_new = nullptr;
  220. int const* const seqused_q = nullptr;
  221. int const* const seqused_k = nullptr;
  222. int const* const leftpad_k = nullptr;
  223. };
  224. static Params
  225. to_underlying_arguments(Arguments const& args) {
  226. // If PackGQA, reshape Q to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size)
  227. int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K));
  228. auto const shape_Q_packed = cute::conditional_return<!PackGQA>(
  229. args.shape_Q,
  230. make_shape(make_shape(qhead_per_khead, get<0>(args.shape_Q)), get<1>(args.shape_Q), get<2>(args.shape_K), get<3>(args.shape_Q))
  231. );
  232. auto const stride_Q_packed = cute::conditional_return<!PackGQA>(
  233. args.stride_Q,
  234. make_stride(make_stride(get<2>(args.stride_Q), get<0>(args.stride_Q)), get<1>(args.stride_Q), get<2>(args.stride_Q) * qhead_per_khead, get<3>(args.stride_Q))
  235. );
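// Example of the PackGQA reshape (hypothetical sizes): with shape_Q = (seqlen_q, d, 32, batch) and
// shape_K = (seqlen_k, d, 8, batch), qhead_per_khead = 4 and shape_Q_packed = ((4, seqlen_q), d, 8, batch).
// The innermost packed mode walks the 4 query heads that share a KV head (using the head stride of Q) and
// the next mode walks seqlen_q (using the row stride), so kBlockM packed rows cover kBlockM / 4 query
// positions across those 4 heads.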
  236. if (get<1>(args.shape_rotary) > 0) {
  237. assert(args.ptr_rotary_cos != nullptr && args.ptr_rotary_sin != nullptr);
  238. }
  239. assert(args.num_splits >= 1);
  240. // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val.
  241. // Right after this, we multiply by log2(e) before applying exp2.
  242. // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val
  243. // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e)
  244. // (assigning it to params.softmax_scale_log2).
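// Sketch of the resulting algebra, in terms of the member names below:
//   exp2 argument = tanh(scores * softmax_scale / softcap_val) * softcap_val * log2(e)
//                 = tanh(scores * params.softcap_val) * params.softmax_scale_log2
// so the mainloop's apply_softcap only needs one multiply and one tanh per element.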
  245. return {args.ptr_Q, args.shape_Q, args.stride_Q, shape_Q_packed, stride_Q_packed,
  246. args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.stride_V,
  247. args.ptr_K_new, args.shape_K_new, args.stride_K_new, args.ptr_V_new, args.stride_V_new,
  248. args.ptr_rotary_cos, args.shape_rotary, args.stride_rotary_cos,
  249. args.ptr_rotary_sin, args.stride_rotary_sin, args.is_rotary_interleaved,
  250. args.ptr_pagetable, args.shape_pagetable, args.stride_pagetable,
  251. cutlass::FastDivmod(int(get<0>(args.shape_K))),
  252. cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))),
  253. !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E),
  254. args.ptr_q_descale, args.ptr_k_descale, args.ptr_v_descale,
  255. args.stride_q_descale, args.stride_k_descale, args.stride_v_descale,
  256. !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val,
  257. args.window_size_left, args.window_size_right, args.sink_token_length,
  258. !Split ? 1 : args.num_splits,
  259. args.kv_batch_idx,
  260. args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new,
  261. args.seqused_q, args.seqused_k, args.leftpad_k};
  262. }
  263. CUTLASS_DEVICE
  264. cute::tuple<int, int> get_n_block_min_max(Params const& params, SeqlenInfo_t const& seqlen_info,
  265. int m_block, int bidb, int split_idx=0, int num_splits=1) {
  266. static constexpr int kBlockM = get<0>(TileShape_MNK{});
  267. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  268. int const seqlen_k = seqlen_info.seqlen_k;
  269. int const seqlen_q = seqlen_info.seqlen_q;
  270. int n_block_max = cute::ceil_div(seqlen_k, kBlockN);
  271. if constexpr (Is_causal || Is_local) {
  272. int m_idx_max = (m_block + 1) * kBlockM;
  273. if (PackGQA) { m_idx_max = params.qhead_per_khead_divmod.divide(m_idx_max - 1) + 1 ; }
  274. n_block_max = std::min(n_block_max,
  275. cute::ceil_div(m_idx_max + seqlen_k - seqlen_q + params.window_size_right, kBlockN));
  276. }
  277. int n_block_min = 0;
  278. if constexpr (Is_local) {
  279. int m_idx_min = m_block * kBlockM;
  280. if (PackGQA) { m_idx_min = params.qhead_per_khead_divmod.divide(m_idx_min); }
  281. n_block_min = std::max(int(0), (m_idx_min + seqlen_k - seqlen_q - params.window_size_left) / kBlockN);
  282. }
  283. // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }
  284. if constexpr (Split) {
  285. int num_n_blocks_per_split = n_block_max <= n_block_min ? 0 : cute::ceil_div(n_block_max - n_block_min, num_splits);
  286. n_block_min = n_block_min + split_idx * num_n_blocks_per_split;
  287. n_block_max = std::min(n_block_min + num_n_blocks_per_split, n_block_max);
  288. }
  289. // if (threadIdx.x == 128) { printf("After split, inside, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }
  290. return {n_block_min, n_block_max};
  291. }
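// Worked example (hypothetical tile sizes): kBlockM = kBlockN = 64, seqlen_q = seqlen_k = 512, causal with
// window_size_right = 0, no PackGQA. For m_block = 3, m_idx_max = 256, so n_block_max = min(8, 4) = 4 and
// n_block_min = 0. With Split and num_splits = 2, split 0 handles n_blocks [0, 2) and split 1 handles [2, 4);
// a split whose range ends up empty makes the caller return early from mma().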
  292. template <typename SharedStorage, typename FrgTensorO, typename Softmax>
  293. CUTLASS_DEVICE bool
  294. mma(Params const& params,
  295. FrgTensorO& tOrO,
  296. Softmax& softmax,
  297. int const thread_idx,
  298. SeqlenInfo_t const& seqlen_info,
  299. cute::tuple<int32_t, int32_t, int32_t, int32_t> block_coord,
  300. SharedStorage& shared_storage
  301. ) {
  302. static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
  303. static constexpr int kBlockM = get<0>(TileShape_MNK{});
  304. static constexpr int kBlockN = get<1>(TileShape_MNK{});
// Can't use `auto [m_block, ...] = block_coord` since structured bindings cannot be captured in lambdas
  306. int const m_block = get<0>(block_coord);
  307. int const bidh = get<1>(block_coord);
  308. int const bidb = get<2>(block_coord);
  309. int const split_idx = get<3>(block_coord);
  310. int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh;
  311. auto n_block_min_max = get_n_block_min_max(params, seqlen_info, m_block, bidb, split_idx, params.num_splits);
  312. int const n_block_min = get<0>(n_block_min_max);
  313. int const n_block_max = get<1>(n_block_min_max);
// It's possible to have n_block_max <= n_block_min. In that case we don't want to load Q or touch any barriers.
  315. if constexpr (Is_causal || Is_local || Varlen || Split) {
  316. if (n_block_max <= n_block_min) { return false; }
  317. }
  318. Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{});
  319. Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{});
  320. Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{});
  321. Tensor sVt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVt{});
  322. bool const is_varlen_q = Varlen && params.cu_seqlens_q;
  323. bool const is_varlen_k = Varlen && params.cu_seqlens_k;
  324. int const bidb_kv = params.kv_batch_idx == nullptr ? bidb : params.kv_batch_idx[bidb];
  325. Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q + seqlen_info.offset_q * get<0>(params.stride_Q)), params.shape_Q_packed, params.stride_Q_packed)(_, _, bidh, !is_varlen_q ? bidb : 0);
  326. Tensor gQ = local_tile(mQ, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K)
  327. Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K + seqlen_info.offset_k * get<0>(params.stride_K)), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0);
  328. Tensor gK = local_tile(mK, select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
  329. Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V + seqlen_info.offset_k * get<0>(params.stride_V)), params.shape_K, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0);
  330. Tensor gV = local_tile(mV, select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
  331. GmemTiledCopyQKV gmem_tiled_copy_QKV;
  332. auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(thread_idx);
  333. auto gmem_thr0_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(_0{}); // For index calculation
  334. Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN)
  335. Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
  336. Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN)
  337. Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
  338. TiledMma tiled_mma;
  339. auto thr_mma = tiled_mma.get_slice(thread_idx);
  340. // Allocate "fragments/descriptors"
  341. Tensor tSrQ = thr_mma.partition_fragment_A(sQ);
  342. // Copy Atom retiling
  343. auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtom{}, tiled_mma);
  344. auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(thread_idx);
  345. auto smem_tiled_copy_K = make_tiled_copy_B(SmemCopyAtom{}, tiled_mma);
  346. auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(thread_idx);
  347. auto smem_tiled_copy_V = make_tiled_copy_B(SmemCopyAtomTransposed{}, tiled_mma);
  348. auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(thread_idx);
  349. Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
  350. Tensor tSsK = smem_thr_copy_K.partition_S(sK);
  351. Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
  352. // Predicates
  353. Tensor cKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{}));
  354. Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV);
  355. Tensor t0KVcKV = gmem_thr0_copy_QKV.partition_S(cKV);
  356. Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
  357. #pragma unroll
  358. for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(_0{}, _0{}, k)) < get<1>(params.shape_K); }
  359. int const seqlen_q = seqlen_info.seqlen_q;
  360. int const seqlen_k = seqlen_info.seqlen_k;
  361. int n_block = n_block_max - 1;
  362. // Prologue: load Q, K, V
  363. // If persistent, we don't need to wait for the previous work_idx to finish
  364. // since we assume that all MMA threads sync in the epilogue before writing to smem_o.
// So if any thread gets here, all threads must have finished the previous MMA and at least started
// writing to smem_o.
  367. // If persistent, need to sync to make sure all threads have finished with smem_o before writing to smem_v
  368. if constexpr (Share_QV_Smem) { __syncthreads(); }
  369. if constexpr (!PackGQA) {
  370. Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
  371. Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
  372. Tensor cQ = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}));
  373. Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ);
  374. Tensor t0QcQ = gmem_thr0_copy_QKV.partition_S(cQ);
  375. Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
  376. #pragma unroll
  377. for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(_0{}, _0{}, k)) < get<1>(params.shape_Q); }
  378. // Instead of passing in tQcQ, we pass in t0QcQ and subtract the offset from the limit
  379. // (seqlen_q - m_block * kBlockM). This is because the entries of t0QcQ are known at compile time.
  380. // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
  381. flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/true>(
  382. gmem_tiled_copy_QKV, tQgQ, tQsQ, t0QcQ, tQpQ, seqlen_info.seqlen_q - m_block * kBlockM - get<0>(tQcQ(_0{}, _0{}, _0{}))
  383. );
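// Illustration of that trick (hypothetical thread): if this thread's first row coordinate is 5, then
// get<0>(tQcQ(_0{}, _0{}, _0{})) == 5 while the rows of t0QcQ are 0, 16, 32, ... (compile-time constants).
// Checking t0QcQ_row < (limit - 5) is equivalent to checking (5 + t0QcQ_row) < limit, but now the row
// values being compared are constants that the compiler can fold into the predicate computation.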
  384. } else {
  385. using PackGQAt = flash::PackGQAManager<get<0>(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumMmaThreads, Element>;
  386. PackGQAt::load_Q(mQ, sQ, params.qhead_per_khead_divmod, thread_idx, seqlen_q, m_block);
  387. }
  388. cute::cp_async_fence();
  389. using PagedKVManager_t = PagedKVManager<get<1>(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumMmaThreads, Element, true /*KV_Same_Iter*/>;
  390. PagedKVManager_t paged_kv_manager(
  391. params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable,
  392. params.ptr_K, params.shape_K, params.stride_K,
  393. params.ptr_V, params.stride_V,
  394. params.page_size_divmod, bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k
  395. );
  396. auto load_K = [&] (int const n_block, int const smem_pipe_write, auto need_seqlenk_masking_type) {
  397. static constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value;
  398. if constexpr (!PagedKV) {
// Do we need a bound check to make sure rows don't go past kBlockN?
  400. static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0;
  401. Tensor tKsK_cur = tKsK(_, _, _, smem_pipe_write);
  402. // Instead of passing in tKVcKV, we pass in t0KVcKV and subtract the offset from the limit
  403. // (seqlen_k - n_block * kBlockN). This is because the entries of t0KVcKV are known at compile time.
  404. int const seqlenk_row_limit = -int(get<0>(tKVcKV(_0{}, _0{}, _0{}))) + (EvenN
  405. ? seqlen_info.seqlen_k - n_block * kBlockN
  406. : (!Seqlenk_mask ? kBlockN : std::min(seqlen_info.seqlen_k - n_block * kBlockN, kBlockN)));
  407. // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
  408. flash::copy</*Is_even_MN=*/!Seqlenk_mask && EvenN, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/true>(
  409. gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK_cur, t0KVcKV, tKVpKV, seqlenk_row_limit);
  410. } else {
  411. paged_kv_manager.template load_page_table<Seqlenk_mask>(n_block);
  412. paged_kv_manager.template load_K<Seqlenk_mask>(n_block, sK(_, _, smem_pipe_write));
  413. }
  414. };
  415. auto load_V = [&] (int const n_block, int const smem_pipe_write, auto need_seqlenk_masking_type) {
  416. static constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value;
  417. if constexpr (!PagedKV) {
// Do we need a bound check to make sure rows don't go past kBlockN?
  419. static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0;
  420. Tensor tVsV_cur = tVsV(_, _, _, smem_pipe_write);
// We don't call flash::copy since it doesn't support the bound check needed
// to avoid overshooting kBlockN when writing to smem.
  423. Tensor tVgV_cur = tVgV(_, _, _, n_block);
  424. int const seqlenk_row_limit = seqlen_info.seqlen_k - n_block * kBlockN - get<0>(tKVcKV(_0{}, _0{}, _0{}));
  425. #pragma unroll
  426. for (int m = 0; m < size<1>(tVsV); ++m) {
// If the copy-atom rows don't evenly divide kBlockN, only the last `m` iteration needs this check
  428. if (EvenN || m < size<1>(tVsV) - 1 || get<0>(tKVcKV(_0{}, m, _0{})) < kBlockN) {
  429. bool const predicate_n = !Seqlenk_mask || get<0>(t0KVcKV(_0{}, m, _0{})) < seqlenk_row_limit;
  430. #pragma unroll
  431. for (int k = 0; k < size<2>(tVsV); ++k) {
  432. cute::copy(gmem_tiled_copy_QKV.with(tKVpKV(k) && predicate_n), tVgV_cur(_, m, k), tVsV_cur(_, m, k));
  433. }
  434. }
  435. }
  436. } else {
  437. paged_kv_manager.template load_V<Seqlenk_mask>(n_block, sV(_, _, smem_pipe_write));
  438. }
  439. };
  440. auto preprocess_Q = [&] {
  441. if constexpr (!AppendKV) {
  442. flash::cp_async_wait<Share_QV_Smem ? 1 : kStages * 2 - 1>();
  443. } else {
  444. if (get<1>(params.shape_rotary) > 0) { // Apply rotary to Q
  445. int const offset_rotary = seqlen_info.seqlen_k_og + seqlen_info.leftpad_k;
  446. using Rotary_t = Rotary<kBlockM, kHeadDim, NumMmaThreads, Element, !(Is_causal || Is_local) /*FixedPosition*/>;
  447. Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos,
  448. params.ptr_rotary_sin, params.stride_rotary_sin,
  449. params.is_rotary_interleaved, thread_idx, seqlen_q, offset_rotary);
  450. int const qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor;
  451. if (params.is_rotary_interleaved) {
  452. auto [tRrCos, tRrSin] = cute::conditional_return<!PackGQA>(
  453. rotary.template load_cos_sin<true /*kInterleaved*/>(m_block),
  454. rotary.template load_cos_sin_packgqa<true /*kInterleaved*/>(m_block, params.qhead_per_khead_divmod)
  455. );
  456. flash::cp_async_wait<Share_QV_Smem ? 1 : kStages * 2 - 1>();
  457. __syncthreads();
  458. rotary.apply_Q_interleaved(sQ, tRrCos, tRrSin, m_block, qhead_per_khead);
  459. } else {
  460. auto [tRrCosCont, tRrSinCont] = cute::conditional_return<!PackGQA>(
  461. rotary.template load_cos_sin<false /*kInterleaved*/>(m_block),
  462. rotary.template load_cos_sin_packgqa<false /*kInterleaved*/>(m_block, params.qhead_per_khead_divmod)
  463. );
  464. flash::cp_async_wait<Share_QV_Smem ? 1 : kStages * 2 - 1>();
  465. __syncthreads();
  466. rotary.apply_Q_contiguous(sQ, tRrCosCont, tRrSinCont, m_block, qhead_per_khead);
  467. }
  468. } else {
  469. flash::cp_async_wait<Share_QV_Smem ? 1 : kStages * 2 - 1>();
  470. }
  471. }
  472. if constexpr (Q_in_regs) {
  473. __syncthreads();
  474. Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
  475. Tensor tSsQ_copy_view = smem_thr_copy_Q.partition_S(sQ);
  476. cute::copy(smem_tiled_copy_Q, tSsQ_copy_view, tSrQ_copy_view);
  477. }
  478. };
  479. // If Share_QV_Smem, we load Q, then load 1 stage of K, then (optionally) rotate Q and
  480. // read from smem_q to registers, then load V.
// If !Share_QV_Smem, we load Q, load all stages of K & V, then (optionally) rotate Q.
  482. if constexpr (Share_QV_Smem) {
  483. load_K(n_block, 0, cute::true_type{} /*Seqlenk_mask*/);
  484. cute::cp_async_fence();
  485. preprocess_Q();
  486. __syncthreads(); // Make sure all threads have read smem_q before loading V
  487. }
  488. // For persistent, make sure all threads have finished reading smem_o
  489. if constexpr (!Share_QV_Smem) { __syncthreads(); }
  490. // Note, using the for_each() function here to ensure `stage` is of type Int<x>.
  491. for_each(make_int_sequence<kStages>{}, [&] (auto stage) {
  492. static constexpr bool Is_first_stage = CUTE_STATIC_V(stage) == 0;
  493. static constexpr bool Is_last_stage = CUTE_STATIC_V(stage) == kStages - 1;
  494. if constexpr (!Share_QV_Smem || !Is_first_stage) {
  495. if (Is_first_stage || n_block - stage >= n_block_min) {
  496. load_K(n_block - stage, stage, cute::bool_constant<Is_first_stage>{} /*Seqlenk_mask*/);
  497. }
// We want the fence outside the if statement so that we have a fixed number of cp.async commits,
// which lets us wait on the correct number of outstanding commits.
  500. cute::cp_async_fence();
  501. }
  502. if constexpr (!Is_last_stage) {
  503. if (Is_first_stage || n_block - stage >= n_block_min) {
  504. load_V(n_block - stage, stage, cute::bool_constant<Is_first_stage>{} /*Seqlenk_mask*/);
  505. }
  506. cute::cp_async_fence();
  507. }
  508. });
  509. if constexpr (!Share_QV_Smem) { preprocess_Q(); }
  510. flash::Mask<kBlockM, kBlockN, PackGQA, TiledMma> mask(
  511. thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, params.sink_token_length,
  512. params.qhead_per_khead_divmod
  513. );
  514. float softcap_val = params.softcap_val;
  515. if constexpr (Has_softcap && Is_FP8) {
  516. float const q_descale = params.ptr_q_descale == nullptr ? 1.0f : params.ptr_q_descale[bidb * get<0>(params.stride_q_descale) + bidh_kv * get<1>(params.stride_q_descale)];
  517. float const k_descale = params.ptr_k_descale == nullptr ? 1.0f : params.ptr_k_descale[bidb * get<0>(params.stride_k_descale) + bidh_kv * get<1>(params.stride_k_descale)];
  518. softcap_val *= q_descale * k_descale;
  519. }
// Softcapping needs to happen before masking: if we applied it after masking, softcapping would turn
// -inf into e.g. -50.0, which would affect the attention softmax.
  522. auto scoremod_premask_fn = [&](auto& tSrS) {
  523. if constexpr (Has_softcap) { flash::apply_softcap(tSrS, softcap_val); }
  524. };
  525. int smem_pipe_read = 0, smem_pipe_write = kStages - 1;
  526. auto load_K_next = [&] {
  527. if (n_block - kStages >= n_block_min) {
  528. load_K(n_block - kStages, kStages > 1 ? smem_pipe_write : 0, cute::false_type{} /*Seqlenk_mask*/);
  529. }
  530. cute::cp_async_fence();
  531. };
  532. auto sync = [&] {
  533. flash::cp_async_wait<kStages * 2 - 2>();
  534. __syncthreads();
  535. };
  536. clear(tOrO);
  537. auto fwd_step = [&](int const n_block, auto mask_fn, auto is_first_iter_type, auto check_inf_type) {
  538. static constexpr bool Is_first_iter = decltype(is_first_iter_type)::value;
  539. static constexpr bool Check_inf = decltype(check_inf_type)::value;
  540. Tensor tSrS = partition_fragment_C(tiled_mma, select<0, 1>(TileShape_MNK{}));
  541. clear(tSrS);
  542. sync();
  543. auto load_V_next = [&] {
  544. if (n_block - kStages + 1 >= n_block_min) {
  545. load_V(n_block - kStages + 1, kStages > 1 ? smem_pipe_write : 0, cute::bool_constant<Is_first_iter && kStages == 1>{} /*Seqlenk_mask*/);
  546. }
  547. cute::cp_async_fence();
  548. };
  549. Tensor tSrQ_cur = cute::conditional_return<Q_in_regs>(tSrQ, thr_mma.partition_fragment_A(sQ));
  550. Tensor tSrK = thr_mma.partition_fragment_B(sK(_, _, _0{}));
  551. flash::gemm_sm80<Q_in_regs>(
  552. tSrS, tSrQ_cur, tSrK, tSsQ, tSsK(_, _, _, kStages > 1 ? smem_pipe_read : 0),
  553. tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K, load_V_next
  554. );
  555. smem_pipe_write = smem_pipe_write < kStages - 1 ? smem_pipe_write + 1 : 0;
  556. scoremod_premask_fn(tSrS);
  557. // Faster to load_K before gemm if we only have 1 stage
  558. if constexpr (kStages == 1) { sync(); load_K_next(); }
  559. mask_fn(tSrS, n_block);
  560. Tensor scores_scale = softmax.template max_get_scale</*Is_first=*/Is_first_iter, Check_inf>(tSrS);
  561. softmax.template online_softmax</*Is_first=*/Is_first_iter, Check_inf>(tSrS);
  562. if constexpr (Is_FP8) { flash::permute_Cregs_fp8(tSrS); }
  563. Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs<TiledMma>(tSrS.layout()));
  564. Tensor tOrP = make_tensor_like<Element>(tOrP_acc);
  565. convert_type_out(tOrP_acc, tOrP);
  566. if constexpr (!Is_first_iter) { softmax.rescale_o(tOrO, scores_scale); }
  567. if constexpr (kStages > 1) { sync(); }
  568. Tensor tOrV = thr_mma.partition_fragment_B(sVt(_, _, _0{}));
  569. flash::gemm_rs_sm80(tOrO, tOrP, tOrV, tOsVt(_, _, _, kStages > 1 ? smem_pipe_read : 0), tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
  570. if constexpr (kStages > 1) { load_K_next(); }
  571. smem_pipe_read = smem_pipe_read < kStages - 1 ? smem_pipe_read + 1 : 0;
  572. };
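// Pipeline sketch for kStages = 2 (illustration only): after the prologue, K[n_block] and K[n_block - 1] sit
// in smem stages 0 and 1 and V[n_block] sits in stage 0. Each fwd_step then overlaps the load of
// V[n_block - 1] with the Q*K^T gemm, consumes K and V from the smem_pipe_read stage, and issues the load of
// K[n_block - 2] after the P*V gemm, so the cp.async pipeline keeps one K load and one V load in flight.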
  573. auto first_iter_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply<true /*Seqlenk_mask*/, Is_causal, Is_local>(tSrS, m_block, n_block); };
  574. fwd_step(n_block, first_iter_mask_fn, cute::true_type{} /*is_first_iter*/, cute::true_type{} /*check_inf*/);
  575. --n_block;
  576. if constexpr (Is_causal || Is_local) { // Separate iterations with causal or local masking
  577. auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply<false /*Seqlenk_mask*/, Is_causal, Is_local>(tSrS, m_block, n_block); };
  578. int const m_idx_min = !PackGQA ? m_block * kBlockM : params.qhead_per_khead_divmod.divide(m_block * kBlockM);
  579. int const n_block_min_causal_local_mask =
  580. std::max(n_block_min, (m_idx_min + seqlen_k - seqlen_q + params.window_size_right) / kBlockN);
  581. #pragma unroll 1
  582. for (; n_block >= n_block_min_causal_local_mask; --n_block) {
  583. fwd_step(n_block, mask_fn, cute::false_type{} /*is_first_iter*/, cute::true_type{} /*check_inf*/);
  584. }
  585. }
  586. int const m_idx_max = !PackGQA ? (m_block + 1) * kBlockM : params.qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1;
  587. int const n_block_min_before_local_mask = !Is_local
  588. ? n_block_min
  589. : std::max(n_block_min,
  590. cute::ceil_div(m_idx_max + seqlen_k - seqlen_q - params.window_size_left, kBlockN));
  591. auto no_mask_fn = [](auto& tSrS, int n_block) { };
  592. #pragma unroll 1
  593. for (; n_block >= n_block_min_before_local_mask; --n_block) {
  594. fwd_step(n_block, no_mask_fn, cute::false_type{} /*is_first_iter*/, cute::false_type{} /*check_inf*/);
  595. }
  596. // Separate masking iterations on the left for local attention
  597. if constexpr (Is_local) {
  598. auto local_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply<false /*Seqlenk_mask*/, false /*Causal_mask*/, Is_local>(tSrS, m_block, n_block); };
  599. #pragma unroll 1
  600. for (; n_block >= n_block_min; --n_block) {
  601. fwd_step(n_block, local_mask_fn, cute::false_type{} /*is_first_iter*/, cute::bool_constant<Is_local>{} /*check_inf*/);
  602. }
  603. // Disable sink token code for now
  604. // int n_block_sink_max = cute::ceil_div(params.sink_token_length, kBlockN);
  605. // #pragma unroll 1
  606. // for (n_block = std::min(n_block, n_block_sink_max - 1); n_block >= 0; --n_block) {
  607. // fwd_step(n_block, local_mask_fn, cute::false_type{} /*is_first_iter*/, cute::bool_constant<Is_local>{} /*check_inf*/);
  608. // }
  609. }
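// Loop structure in summary (causal, non-local case): the first iteration applies the seqlen-k mask plus the
// causal mask, iterations down to n_block_min_causal_local_mask apply only the causal mask, and the remaining
// iterations down to n_block_min need no mask at all, which is why they can also skip the check_inf handling
// in the online softmax.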
  610. float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)];
  611. Tensor scores_scale = softmax.finalize(v_descale);
  612. softmax.rescale_o(tOrO, scores_scale);
  613. if constexpr (Is_FP8) { flash::permute_output_fp8(tOrO); }
  614. return true;
  615. }
  616. CUTLASS_DEVICE
  617. cute::tuple<int, int> get_n_block_k_new_min_max(Params const& params, SeqlenInfo_t const& seqlen_info,
  618. int m_block, int bidb, int split_idx=0, int num_splits=1) {
  619. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  620. auto [n_block_min, n_block_max] = get_n_block_min_max(params, seqlen_info, m_block, bidb, split_idx, num_splits);
  621. int const idx_k_new_min = std::max(n_block_min * kBlockN - seqlen_info.seqlen_k_og, 0);
  622. int const idx_k_new_max = std::min(n_block_max * kBlockN - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new);
  623. int const n_block_new_min = idx_k_new_min / kBlockN;
  624. int const n_block_new_max = idx_k_new_max > idx_k_new_min ? cute::ceil_div(idx_k_new_max, kBlockN) : n_block_new_min;
  625. // if (threadIdx.x == 128 && m_block == 0) { printf("bidb = %d, seqlen_k_new = %d, seqlen_k_og = %d, n_block_min = %d, n_block_max = %d, idx_k_new_min = %d, idx_k_new_max = %d, n_block_new_min = %d, n_block_new_max = %d\n", bidb, seqlen_k_new, seqlen_k_og, n_block_min, n_block_max, idx_k_new_min, idx_k_new_max, n_block_new_min, n_block_new_max);}
  626. return {n_block_new_min, n_block_new_max};
  627. }
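// Worked example (hypothetical sizes, assuming seqlen_info.seqlen_k = seqlen_k_og + seqlen_k_new): with
// kBlockN = 64, seqlen_k_og = 1000 and seqlen_k_new = 128, and no causal/local/split restriction,
// get_n_block_min_max returns (0, 18), so idx_k_new_min = 0, idx_k_new_max = 128 and this returns (0, 2):
// both 64-token blocks of the appended K/V must be written to the cache before the main loop reads them.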
  628. template <typename SharedStorage>
  629. CUTLASS_DEVICE bool
  630. store_kv_new(Params const& params,
  631. int const thread_idx,
  632. SharedStorage &shared_storage,
  633. SeqlenInfo_t const& seqlen_info,
  634. cute::tuple<int32_t, int32_t, int32_t, int32_t> block_coord
  635. ) {
  636. auto [m_block, bidh, bidb, split_idx] = block_coord;
  637. auto n_block_new_min_max = get_n_block_k_new_min_max(params, seqlen_info, m_block, bidb, split_idx, params.num_splits);
  638. int const n_block_new_min = get<0>(n_block_new_min_max);
  639. int const n_block_new_max = get<1>(n_block_new_min_max);
  640. if (n_block_new_max <= n_block_new_min) { return false; }
  641. Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{});
  642. Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{});
  643. int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh;
  644. int const bidb_kv = params.kv_batch_idx == nullptr ? bidb : params.kv_batch_idx[bidb];
  645. bool const is_varlen_k_new = Varlen && params.cu_seqlens_k_new;
  646. Tensor mKnew = make_tensor(make_gmem_ptr(params.ptr_K_new), params.shape_K_new, params.stride_K_new)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0);
  647. Tensor mVnew = make_tensor(make_gmem_ptr(params.ptr_V_new), params.shape_K_new, params.stride_V_new)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0);
  648. bool const is_varlen_k = Varlen && params.cu_seqlens_k;
  649. Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0);
  650. Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), params.shape_K, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0);
  651. Tensor gKnew = local_tile(domain_offset(make_coord(seqlen_info.offset_k_new, _0{}), mKnew), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
  652. Tensor gVnew = local_tile(domain_offset(make_coord(seqlen_info.offset_k_new, _0{}), mVnew), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
  653. int const offset_k = seqlen_info.offset_k + seqlen_info.seqlen_k_og;
  654. Tensor gK = local_tile(domain_offset(make_coord(offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
  655. Tensor gV = local_tile(domain_offset(make_coord(offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
  656. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  657. static constexpr int kHeadDim = get<2>(TileShape_MNK{});
  658. int const offset_rotary = seqlen_info.seqlen_k_og + seqlen_info.leftpad_k;
  659. int const seqlen_k_new = seqlen_info.seqlen_k_new;
  660. using Rotary_t = Rotary<kBlockN, kHeadDim, NumMmaThreads, Element>;
  661. Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos,
  662. params.ptr_rotary_sin, params.stride_rotary_sin,
  663. params.is_rotary_interleaved, thread_idx, seqlen_k_new, offset_rotary);
  664. using PagedKVManager_t = PagedKVManager<get<1>(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>;
  665. PagedKVManager_t paged_kv_manager(
  666. params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable,
  667. params.ptr_K, params.shape_K, params.stride_K,
  668. params.ptr_V, params.stride_V,
  669. params.page_size_divmod, bidb_kv, bidh_kv, thread_idx, seqlen_k_new, offset_k
  670. // passing offset_k instead of leftpad_k will move the PageTable pointer to the right position
  671. );
  672. static_assert(std::is_same_v<GmemLayoutAtomAppend, typename Rotary_t::LayoutAtom>);
  673. static_assert(!PagedKV || std::is_same_v<GmemLayoutAtomAppend, typename PagedKVManager_t::GmemLayoutAtomKVCpAsync>);
  674. GmemTiledCopyQKV gmem_tiled_copy_kv_g2s;
  675. auto gmem_thr_copy_kv_g2s = gmem_tiled_copy_kv_g2s.get_thread_slice(thread_idx);
  676. auto gmem_thr0_copy_kv_g2s = gmem_tiled_copy_kv_g2s.get_thread_slice(_0{}); // Only for index calculation
  677. GmemTiledCopyAppendKV gmem_tiled_copy_kv_s2g;
  678. auto gmem_thr_copy_kv_s2g = gmem_tiled_copy_kv_s2g.get_thread_slice(thread_idx);
  679. auto gmem_thr0_copy_kv_s2g = gmem_tiled_copy_kv_s2g.get_thread_slice(_0{}); // Only for index calculation
  680. Tensor tKgKnew = gmem_thr_copy_kv_g2s.partition_S(gKnew);
  681. Tensor tKsKg2s = gmem_thr_copy_kv_g2s.partition_S(sK);
  682. Tensor tKsKs2g = gmem_thr_copy_kv_s2g.partition_S(sK); // ((Atom,AtomNum),ATOM_M,ATOM_N)
  683. Tensor tKgK = gmem_thr_copy_kv_s2g.partition_D(gK);
  684. Tensor tVgVnew = gmem_thr_copy_kv_g2s.partition_S(gVnew); // ((Atom,AtomNum),ATOM_M,ATOM_N)
  685. Tensor tVsVg2s = gmem_thr_copy_kv_g2s.partition_S(sV); // ((Atom,AtomNum),ATOM_M,ATOM_N)
  686. Tensor tVsVs2g = gmem_thr_copy_kv_s2g.partition_S(sV); // ((Atom,AtomNum),ATOM_M,ATOM_N)
  687. Tensor tVgV = gmem_thr_copy_kv_s2g.partition_D(gV);
  688. Tensor cK = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_N,BLK_K) -> (blk_n,blk_k)
  689. Tensor tKcKg2s = gmem_thr_copy_kv_g2s.partition_D(cK);
  690. Tensor t0KcKg2s = gmem_thr0_copy_kv_g2s.partition_D(cK);
  691. Tensor tKpKg2s = make_tensor<bool>(make_shape(size<2>(tKsKg2s)));
  692. Tensor tKcKs2g = gmem_thr_copy_kv_s2g.partition_D(cK);
  693. Tensor t0KcKs2g = gmem_thr0_copy_kv_s2g.partition_D(cK);
  694. Tensor tKpKs2g = make_tensor<bool>(make_shape(size<2>(tKsKs2g)));
  695. #pragma unroll
  696. for (int k = 0; k < size(tKpKg2s); ++k) { tKpKg2s(k) = get<1>(tKcKg2s(_0{}, _0{}, k)) < get<1>(params.shape_K); }
  697. #pragma unroll
  698. for (int k = 0; k < size(tKpKs2g); ++k) { tKpKs2g(k) = get<1>(tKcKs2g(_0{}, _0{}, k)) < get<1>(params.shape_K); }
  699. auto load_K_new = [&] (int const n_block, int const smem_pipe_write, auto need_seqlenk_masking_type) {
  700. static constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value;
  701. static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0;
  702. Tensor tKsK_cur = tKsKg2s(_, _, _, smem_pipe_write);
  703. int const seqlenk_row_limit = -int(get<0>(tKcKg2s(_0{}, _0{}, _0{}))) + (EvenN
  704. ? seqlen_k_new - n_block * kBlockN
  705. : (!Seqlenk_mask ? kBlockN : std::min(seqlen_k_new - n_block * kBlockN, kBlockN)));
  706. // We don't need to clear the sK smem tiles since we won't write them out
  707. flash::copy</*Is_even_MN=*/!Seqlenk_mask && EvenN, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/true>(
  708. gmem_tiled_copy_kv_g2s, tKgKnew(_, _, _, n_block), tKsK_cur, t0KcKg2s, tKpKg2s, seqlenk_row_limit);
  709. };
  710. auto load_V_new = [&] (int const n_block, int const smem_pipe_write, auto need_seqlenk_masking_type) {
  711. static constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value;
  712. static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0;
  713. Tensor tVsV_cur = tVsVg2s(_, _, _, smem_pipe_write);
  714. int const seqlenk_row_limit = -int(get<0>(tKcKg2s(_0{}, _0{}, _0{}))) + (EvenN
  715. ? seqlen_k_new - n_block * kBlockN
  716. : (!Seqlenk_mask ? kBlockN : std::min(seqlen_k_new - n_block * kBlockN, kBlockN)));
  717. // We don't need to clear the sV smem tiles since we won't write them out
  718. flash::copy</*Is_even_MN=*/!Seqlenk_mask && EvenN, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/true>(
  719. gmem_tiled_copy_kv_g2s, tVgVnew(_, _, _, n_block), tVsV_cur, t0KcKg2s, tKpKg2s, seqlenk_row_limit);
  720. };
  721. auto store_K = [&] (int const n_block, int const smem_pipe_read) {
  722. int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN);
  723. if (get<1>(params.shape_rotary) <= 0) {
  724. Tensor tKsK_cur = tKsKs2g(_, _, _, smem_pipe_read);
  725. if constexpr (!PagedKV) {
  726. Tensor tKgK_cur = tKgK(_, _, _, n_block);
  727. // Clear_OOB_K must be false since we don't want to write zeros to gmem
  728. flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_kv_s2g, tKsK_cur, tKgK_cur, tKcKs2g, tKpKs2g, n_limit
  730. );
  731. } else {
  732. paged_kv_manager.store_K(n_block, tKsK_cur);
  733. }
  734. } else {
  735. Tensor gK_cur = gK(_, _, n_block);
  736. auto tPrKPtr = cute::conditional_return<PagedKV>(paged_kv_manager.compute_K_ptr(), nullptr);
  737. if (params.is_rotary_interleaved) {
  738. auto [tRrCos, tRrSin] = rotary.template load_cos_sin<true /*kInterleaved*/>(n_block);
  739. rotary.template apply_K_interleaved<PagedKV>(sK(_, _, smem_pipe_read), gK_cur, tKpKs2g, tRrCos, tRrSin, tPrKPtr, n_block);
  740. } else {
  741. auto [tRrCosCont, tRrSinCont] = rotary.template load_cos_sin<false /*kInterleaved*/>(n_block);
  742. rotary.template apply_K_contiguous<PagedKV>(sK(_, _, smem_pipe_read), gK_cur, tKpKs2g, tRrCosCont, tRrSinCont, tPrKPtr, n_block, get<1>(params.shape_K));
  743. }
  744. }
  745. };
  746. auto store_V = [&] (int const n_block, int const smem_pipe_read) {
  747. int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN);
  748. Tensor tVsV_cur = tVsVs2g(_, _, _, smem_pipe_read);
  749. if constexpr (!PagedKV) {
  750. Tensor tVgV_cur = tVgV(_, _, _, n_block);
  751. // Clear_OOB_K must be false since we don't want to write zeros to gmem
  752. flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
  753. gmem_tiled_copy_kv_s2g, tVsV_cur, tVgV_cur, tKcKs2g, tKpKs2g, n_limit);
  754. } else {
  755. paged_kv_manager.store_V(n_block, tVsV_cur);
  756. }
  757. };
  758. int n_block = n_block_new_max - 1;
  759. // Note, using the for_each() function here to ensure `stage` is of type Int<x>.
  760. for_each(make_int_sequence<kStages>{}, [&] (auto stage) {
  761. static constexpr bool Is_first_stage = CUTE_STATIC_V(stage) == 0;
  762. static constexpr bool Is_last_stage = CUTE_STATIC_V(stage) == kStages - 1;
  763. if (Is_first_stage || n_block - stage >= n_block_new_min) {
  764. load_K_new(n_block - stage, stage, cute::bool_constant<Is_first_stage>{} /*Seqlenk_mask*/);
  765. }
  766. cute::cp_async_fence();
  767. // If persistent, need to sync to make sure all threads have finished with smem_o before writing to smem_v
  768. if constexpr (Is_first_stage) { __syncthreads(); }
  769. if constexpr (!Is_last_stage) {
  770. if (Is_first_stage || n_block - stage >= n_block_new_min) {
  771. load_V_new(n_block - stage, stage, cute::bool_constant<Is_first_stage>{} /*Seqlenk_mask*/);
  772. }
  773. cute::cp_async_fence();
  774. }
  775. });
  776. int smem_pipe_read = 0, smem_pipe_write = kStages - 1;
  777. #pragma unroll 1
  778. for (; n_block >= n_block_new_min; --n_block) {
  779. if constexpr (PagedKV) { paged_kv_manager.template load_page_table<true /*Seqlenk_mask*/>(n_block); }
  780. flash::cp_async_wait<kStages * 2 - 2>();
  781. __syncthreads();
  782. store_K(n_block, kStages > 1 ? smem_pipe_read : 0);
  783. if (n_block - kStages + 1 >= n_block_new_min) {
  784. load_V_new(n_block - kStages + 1, kStages > 1 ? smem_pipe_write : 0, cute::bool_constant<kStages == 1>{} /*Seqlenk_mask*/);
  785. }
  786. cute::cp_async_fence();
  787. smem_pipe_write = smem_pipe_write < kStages - 1 ? smem_pipe_write + 1 : 0;
  788. flash::cp_async_wait<kStages * 2 - 2>();
  789. __syncthreads();
  790. store_V(n_block, kStages > 1 ? smem_pipe_read : 0);
  791. smem_pipe_read = smem_pipe_read < kStages - 1 ? smem_pipe_read + 1 : 0;
  792. if (n_block - kStages >= n_block_new_min) {
  793. load_K_new(n_block - kStages, kStages > 1 ? smem_pipe_write : 0, cute::false_type{} /*Seqlenk_mask*/);
  794. }
  795. cute::cp_async_fence();
  796. }
  797. return true;
  798. }
  799. };
  800. } // namespace flash