paged_kv.h 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include <cute/tensor.hpp>
  6. #include "cutlass/fast_math.h" // For cutlass::FastDivmod
  7. #include "utils.h"
  8. namespace flash {
  9. using namespace cute;
  10. template <int kBlockN, int kHeadDim, int NumThreads, typename Element, bool KV_Same_Iter=false, int LoadsPerRow_LB=1>
  11. struct PagedKVManager {
  12. // If KV_Same_Iter=false, then we do load_page_table(0), load_K(0), load_page_table(1), load_K(1), load_V(0),
  13. // load_page_table(2), load_K(2), load_V(1), etc.
  14. // So we need to compute the V pointers for the previous iteration.
  15. // LoadsPerRow_LB is the lower bound on number of loads per row in the K direction. This is useful for
  16. // rotary where we want each thread to have at least 2 loads per row.
  17. // We use CpAsync for K and V if PagedKV, since TMA doesn't work there
  18. static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
  19. static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
  20. // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each
  21. // thread to have 4 loads in the M direction and 2 vectorized load in the K direction.
  22. // In the case of PackGQA, this reduces the number of times we need to call divmod.
  23. static_assert(kHeadDim % LoadsPerRow_LB == 0, "Headdim must be a multiple of LoadsPerRow_LB");
  24. static constexpr int kBytePerRow = kHeadDim / LoadsPerRow_LB * sizeof(Element);
  25. static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element);
  26. static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;
  27. static_assert(NumThreads % kGmemThreadsPerRow == 0, "NumThreads must be a multiple of kGmemThreadsPerRow");
  28. // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKV where
  29. // these threads share the same page table entry and share the work of computing pointers to paged K and paged V.
  30. static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, "kGmemThreadsPerRow must divide NumThreadsPerWarp");
  31. using GmemCopyAtomCpAsync = cute::Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL_ZFILL<uint128_t>, Element>;
  32. using GmemLayoutAtomKVCpAsync = Layout<Shape <Int<NumThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
  33. Stride<Int<kGmemThreadsPerRow>, _1>>;
  34. using GmemTiledCopyKVCpAsync = decltype(
  35. make_tiled_copy(GmemCopyAtomCpAsync{},
  36. GmemLayoutAtomKVCpAsync{},
  37. Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per load
  38. using GmemTiledCopyKVStore = decltype(
  39. make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
  40. GmemLayoutAtomKVCpAsync{},
  41. Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per load
  42. using ShapeKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen, d, head, batch)
  43. using StrideKV = cute::Stride<int64_t, _1, int64_t, int64_t>;
  44. using ShapePageTable = cute::Shape<int32_t, int32_t>; // (batch, max_num_pages_per_seq)
  45. using StridePageTable = cute::Stride<int64_t, _1>;
  46. using TensorPageTable = decltype(make_tensor(make_gmem_ptr(static_cast<int const*>(nullptr)), ShapePageTable{}, StridePageTable{})(int(0), _));
  47. using TensorKV = decltype(make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapeKV{}, StrideKV{})(_, _, int(0), _));
  48. using GmemThrCopyKVCpAsync = decltype(GmemTiledCopyKVCpAsync{}.get_thread_slice(int(0)));
  49. using TensortKcK = decltype(GmemTiledCopyKVCpAsync{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDim>>{})));
  50. using TensortKpK = decltype(make_tensor<bool>(make_shape(size<1>(TensortKcK{}), size<2>(TensortKcK{})), Stride<_0, _1>{}));
  51. // For PagedKV, it's expensive the calculate the pointers to K and V for each page table entry,
  52. // since those require int64_t arithmetic. We optimize by having threads split this work.
  53. // Typically there are 8 threads loading per row (e.g. hdim 64 and 128), and there are 11 rows
  54. // that each thread needs to load for the case of hdim 128 and kBlockN = 176.
  55. // So each of those 8 threads will calculate the K_ptr and V_ptr for 11 / 8 = 2 rows.
  56. // We then use __shfl_sync to broadcast the pointers to the other threads in the warp.
  57. static constexpr int kPageEntryPerThread = cute::ceil_div(size<1>(TensortKcK{}), kGmemThreadsPerRow);
  58. using TensorPageOffset = decltype(make_tensor<cute::tuple<int, int>>(Shape<Int<kPageEntryPerThread>>{}));
  59. using TensorKVPtr = decltype(make_tensor<Element*>(Shape<Int<kPageEntryPerThread>>{}));
  60. GmemTiledCopyKVCpAsync gmem_tiled_copy_kv;
  61. cutlass::FastDivmod const &page_size_divmod;
  62. int const thread_idx;
  63. int const seqlen_k;
  64. int const leftpad_k;
  65. GmemThrCopyKVCpAsync const gmem_thr_copy_kv;
  66. TensorPageTable mPageTable;
  67. TensorKV mK_paged, mV_paged;
  68. TensortKpK tKpK;
  69. TensorPageOffset tPrPageOffset;
  70. TensorKVPtr tPrVPtr;
  71. CUTLASS_DEVICE
  72. PagedKVManager(int const* const ptr_page_table,
  73. ShapePageTable const &shape_pagetable, StridePageTable const &stride_pagetable,
  74. Element* const ptr_K, ShapeKV const &shape_K, StrideKV const &stride_K,
  75. Element* const ptr_V, StrideKV const &stride_V,
  76. cutlass::FastDivmod const &page_size_divmod,
  77. int const bidb, int const bidh, int const thread_idx, int const seqlen_k, int const leftpad_k
  78. )
  79. : page_size_divmod(page_size_divmod)
  80. , thread_idx(thread_idx)
  81. , seqlen_k(seqlen_k)
  82. , leftpad_k(leftpad_k)
  83. , gmem_thr_copy_kv(gmem_tiled_copy_kv.get_thread_slice(thread_idx))
  84. {
  85. mPageTable = make_tensor(make_gmem_ptr(ptr_page_table), shape_pagetable, stride_pagetable)(bidb, _);
  86. mK_paged = make_tensor(make_gmem_ptr(ptr_K), shape_K, stride_K)(_, _, bidh, _);
  87. mV_paged = make_tensor(make_gmem_ptr(ptr_V), shape_K, stride_V)(_, _, bidh, _);
  88. tKpK = make_tensor<bool>(make_shape(size<1>(TensortKcK{}), size<2>(TensortKcK{})), Stride<_0, _1>{});
  89. Tensor cK = cute::make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDim>>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k)
  90. Tensor tKcK = gmem_thr_copy_kv.partition_S(cK);
  91. #pragma unroll
  92. for (int k = 0; k < size<1>(tKpK); ++k) { tKpK(_0{}, k) = get<1>(tKcK(_0{}, _0{}, k)) < get<1>(shape_K); }
  93. };
  94. template <bool Seqlenk_mask=false, bool First_iter=false>
  95. CUTLASS_DEVICE
  96. void load_page_table(const int n_block) {
  97. // The uncoalesced gmem load is intentional. This is so that each thread only loads the page table entries
  98. // it needs, and we don't need any sync between warps.
  99. // Assuming 8 threads per row, and 176 rows, then the rows from 0 to 175 are loaded by
  100. // threads 0, 8, 16, ..., 120, 1, 9, ..., 121, 2, 10, ..., 122, etc.
  101. #pragma unroll
  102. for (int i = 0; i < kPageEntryPerThread; ++i) {
  103. int const row = i * NumThreads + (thread_idx % kGmemThreadsPerRow) * (NumThreads / kGmemThreadsPerRow) + (thread_idx / kGmemThreadsPerRow);
  104. int const row_idx = n_block * kBlockN + row;
  105. int page_idx, page_offset;
  106. page_idx = page_size_divmod.divmod(page_offset, row_idx + leftpad_k);
  107. // Add the condition (i + 1) * NumThreads <= kBlockN since that is an upper bound of row
  108. // and is known at compile time. It avoids branching when e.g., kBlockN = 176 and i = 0.
  109. int const page = ((i + 1) * NumThreads <= kBlockN || row < kBlockN) && (!Seqlenk_mask || row_idx < seqlen_k) ? mPageTable[page_idx] : 0;
  110. tPrPageOffset[i] = {page, page_offset};
  111. // if (cute::thread0()) { printf("row = %d, page_idx = %d, page_offset = %d, page = %d, leftpad_k = %d, seqlen_k = %d\n", row, page_idx, page_offset, page, leftpad_k, seqlen_k); }
  112. }
  113. if constexpr (First_iter && !KV_Same_Iter) { compute_V_ptr(); }
  114. };
  115. CUTLASS_DEVICE
  116. TensorKVPtr compute_K_ptr() {
  117. Tensor tPrKPtr = make_tensor<Element*>(Shape<Int<kPageEntryPerThread>>{});
  118. #pragma unroll
  119. for (int i = 0; i < kPageEntryPerThread; ++i) {
  120. auto [page, page_offset] = tPrPageOffset[i];
  121. tPrKPtr[i] = &mK_paged(page_offset, _0{}, page);
  122. }
  123. return tPrKPtr;
  124. };
  125. CUTLASS_DEVICE
  126. void compute_V_ptr() {
  127. #pragma unroll
  128. for (int i = 0; i < kPageEntryPerThread; ++i) {
  129. auto [page, page_offset] = tPrPageOffset[i];
  130. tPrVPtr[i] = &mV_paged(page_offset, _0{}, page);
  131. }
  132. };
  133. template <bool Seqlenk_mask=false, typename TensorK>
  134. CUTLASS_DEVICE
  135. void load_K(const int n_block, TensorK &&sK) {
  136. // Do we need bound check to make sure the row doesn't go above kBlockN
  137. static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtomKVCpAsync{})) == 0;
  138. Tensor tPrKPtr = compute_K_ptr();
  139. // Only for index calculation, since all the indices of thread 0 are known at compile time
  140. auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{});
  141. Tensor tKsK = gmem_thr_copy_kv.partition_D(sK);
  142. Tensor cK = cute::make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDim>>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k)
  143. // Repeat the partitioning with identity layouts
  144. Tensor tKcK = gmem_thr_copy_kv.partition_S(cK);
  145. Tensor t0KcK = gmem_thr0_copy_kv.partition_S(cK);
  146. // We want to use the row indices of thread0 to compare, since that is known at compile time.
  147. // So we subtract the limit by the first row index of this thread (get<0>(tKcK(_0{}, _0{}, _0{})))
  148. int const seqlenk_row_limit = -int(get<0>(tKcK(_0{}, _0{}, _0{}))) + (EvenN
  149. ? seqlen_k - n_block * kBlockN
  150. : (!Seqlenk_mask ? kBlockN : std::min(seqlen_k - n_block * kBlockN, kBlockN)));
  151. #pragma unroll
  152. for (int m = 0; m < size<1>(tKsK); ++m) {
  153. bool const should_load = EvenN
  154. ? (!Seqlenk_mask || get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit)
  155. : get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit;
  156. Element const* k_ptr = reinterpret_cast<Element const*>(__shfl_sync(0xffffffff, reinterpret_cast<uint64_t>(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow));
  157. Tensor mK_paged_cur = make_tensor(make_gmem_ptr(k_ptr), Shape<Int<kHeadDim>>{});
  158. Tensor mK_paged_cur_copy = cute::tiled_divide(mK_paged_cur, Shape<Int<kGmemElemsPerLoad>>{});
  159. if (should_load) {
  160. #pragma unroll
  161. for (int k = 0; k < size<2>(tKsK); ++k) {
  162. int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad;
  163. cute::copy(gmem_tiled_copy_kv.with(tKpK(_0{}, k)), mK_paged_cur_copy(_, ki), tKsK(_, m, k));
  164. }
  165. } // Don't need to clear out the rest of the smem since we'll mask out the scores anyway
  166. }
  167. };
  168. template <bool Seqlenk_mask=false, typename TensorV>
  169. CUTLASS_DEVICE
  170. void load_V(const int n_block, TensorV &&sV) {
  171. // Do we need bound check to make sure the row doesn't go above kBlockN
  172. static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtomKVCpAsync{})) == 0;
  173. if constexpr (KV_Same_Iter) { compute_V_ptr(); }
  174. // Only for index calculation, since all the indices of thread 0 are known at compile time
  175. auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{});
  176. Tensor tVsV = gmem_thr_copy_kv.partition_D(sV);
  177. Tensor cK = cute::make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDim>>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k)
  178. // Repeat the partitioning with identity layouts
  179. Tensor tKcK = gmem_thr_copy_kv.partition_S(cK);
  180. Tensor t0KcK = gmem_thr0_copy_kv.partition_S(cK);
  181. int const seqlenk_row_limit = seqlen_k - n_block * kBlockN - get<0>(tKcK(_0{}, _0{}, _0{}));
  182. #pragma unroll
  183. for (int m = 0; m < size<1>(tVsV); ++m) {
  184. // Faster to rely on the cp.async to clear smem that are out of bound,
  185. // rather than calling cute::clear directly.
  186. // We have to be careful not to write to smem past `kBlockN` if !EvenN.
  187. // If kBlockN doesn't evenly divide the tiled copy, only the last `m` needs to checked
  188. if (EvenN || m < size<1>(tVsV) - 1 || get<0>(tKcK(_0{}, m, _0{})) < kBlockN) {
  189. bool const should_load = !Seqlenk_mask || get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit;
  190. Element const* v_ptr = reinterpret_cast<Element const*>(__shfl_sync(0xffffffff, reinterpret_cast<uint64_t>(tPrVPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow));
  191. Tensor mV_paged_cur = make_tensor(make_gmem_ptr(v_ptr), Shape<Int<kHeadDim>>{});
  192. Tensor mV_paged_cur_copy = cute::tiled_divide(mV_paged_cur, Shape<Int<kGmemElemsPerLoad>>{});
  193. #pragma unroll
  194. for (int k = 0; k < size<2>(tVsV); ++k) {
  195. int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad;
  196. cute::copy(gmem_tiled_copy_kv.with(tKpK(_0{}, k) && should_load), mV_paged_cur_copy(_, ki), tVsV(_, m, k));
  197. }
  198. }
  199. }
  200. if constexpr (!KV_Same_Iter) { compute_V_ptr(); }
  201. };
  202. template <typename TensorK>
  203. CUTLASS_DEVICE
  204. void store_K(const int n_block, TensorK &&tKrK) {
  205. Tensor tPrKPtr = compute_K_ptr();
  206. // We're using the same partitioning as GmemTiledCopyKVCpAsync (used for loading)
  207. // Only for index calculation, since all the indices of thread 0 are known at compile time
  208. auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{});
  209. Tensor cK = cute::make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDim>>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k)
  210. // Repeat the partitioning with identity layouts
  211. Tensor tKcK = gmem_thr_copy_kv.partition_S(cK);
  212. Tensor t0KcK = gmem_thr0_copy_kv.partition_S(cK);
  213. GmemTiledCopyKVStore gmem_tiled_copy_kv_store;
  214. // We want to use the row indices of thread0 to compare, since that is known at compile time.
  215. // So we subtract the limit by the first row index of this thread (get<0>(tKcK(_0{}, _0{}, _0{})))
  216. // int const seqlenk_row_limit = seqlen_k - n_block * kBlockN - get<0>(tKcK(_0{}, _0{}, _0{}));
  217. int const seqlenk_row_limit = std::min(seqlen_k - n_block * kBlockN, kBlockN) - get<0>(tKcK(_0{}, _0{}, _0{}));
  218. // if (threadIdx.x == 128) { printf("bidx = %d, bidy = %d, bidz = %d, seqlen_k = %d, seqlenk_row_limit = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, seqlen_k, seqlenk_row_limit); }
  219. #pragma unroll
  220. for (int m = 0; m < size<1>(tKrK); ++m) {
  221. bool const should_load = get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit;
  222. Element* k_ptr = reinterpret_cast<Element*>(__shfl_sync(0xffffffff, reinterpret_cast<uint64_t>(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow));
  223. Tensor mK_paged_cur = make_tensor(make_gmem_ptr(k_ptr), Shape<Int<kHeadDim>>{});
  224. Tensor mK_paged_cur_copy = cute::tiled_divide(mK_paged_cur, Shape<Int<kGmemElemsPerLoad>>{});
  225. if (should_load) {
  226. #pragma unroll
  227. for (int k = 0; k < size<2>(tKrK); ++k) {
  228. int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad;
  229. if (tKpK(_0{}, k)) {
  230. cute::copy(gmem_tiled_copy_kv_store, tKrK(_, m, k), mK_paged_cur_copy(_, ki));
  231. }
  232. }
  233. }
  234. }
  235. };
  236. template <typename TensorV>
  237. CUTLASS_DEVICE
  238. void store_V(const int n_block, TensorV &&tVrV) {
  239. if constexpr (KV_Same_Iter) { compute_V_ptr(); }
  240. // Only for index calculation, since all the indices of thread 0 are known at compile time
  241. auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{});
  242. Tensor cK = cute::make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDim>>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k)
  243. // Repeat the partitioning with identity layouts
  244. Tensor tKcK = gmem_thr_copy_kv.partition_S(cK);
  245. Tensor t0KcK = gmem_thr0_copy_kv.partition_S(cK);
  246. GmemTiledCopyKVStore gmem_tiled_copy_kv_store;
  247. int const seqlenk_row_limit = std::min(seqlen_k - n_block * kBlockN, kBlockN) - get<0>(tKcK(_0{}, _0{}, _0{}));
  248. #pragma unroll
  249. for (int m = 0; m < size<1>(tVrV); ++m) {
  250. bool const should_load = get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit;
  251. Element* v_ptr = reinterpret_cast<Element*>(__shfl_sync(0xffffffff, reinterpret_cast<uint64_t>(tPrVPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow));
  252. Tensor mV_paged_cur = make_tensor(make_gmem_ptr(v_ptr), Shape<Int<kHeadDim>>{});
  253. Tensor mV_paged_cur_copy = cute::tiled_divide(mV_paged_cur, Shape<Int<kGmemElemsPerLoad>>{});
  254. if (should_load) {
  255. #pragma unroll
  256. for (int k = 0; k < size<2>(tVrV); ++k) {
  257. int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad;
  258. if (tKpK(_0{}, k)) {
  259. cute::copy(gmem_tiled_copy_kv_store, tVrV(_, m, k), mV_paged_cur_copy(_, ki));
  260. }
  261. }
  262. }
  263. }
  264. if constexpr (!KV_Same_Iter) { compute_V_ptr(); }
  265. };
  266. };
  267. } // namespace flash