// pack_gqa.h
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cute/tensor.hpp>

#include "cutlass/fast_math.h"  // For cutlass::FastDivmod

#include "utils.h"

namespace flash {

using namespace cute;
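// PackGQAManager: gmem <-> smem/register copy helpers for the PackGQA layout, where the
// qhead_per_khead query heads that share one KV head are packed along the row (M) dimension.
// A packed row index idx in [0, qhead_per_khead * seqlen) is decomposed with FastDivmod as
// h_idx = idx % qhead_per_khead (which query head) and m_idx = idx / qhead_per_khead (which
// sequence position), matching the ((qhead_per_khead, seqlen), headdim) view of Q and O below.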
template <int kBlockM, int kHeadDim, int NumThreads, typename Element>
struct PackGQAManager {
    // We use CpAsync for Q, since TMA doesn't work there
    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static constexpr int kGmemElemsPerStore = kGmemElemsPerLoad;
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
    // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each
    // thread to have 4 loads in the M direction and 2 vectorized loads in the K direction.
    // In the case of PackGQA, this reduces the number of times we need to call divmod.
    static constexpr int kBytePerRow = kHeadDim * sizeof(Element);
    static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element);
    static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;
    static_assert(NumThreads % kGmemThreadsPerRow == 0, "NumThreads must be a multiple of kGmemThreadsPerRow");
    // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKV where
    // these threads share the same page table entry and share the work of computing pointers to paged K and paged V.
    static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, "kGmemThreadsPerRow must divide NumThreadsPerWarp");
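    // Illustrative numbers: with Element = fp16 and kHeadDim = 128, kGmemElemsPerLoad = 8,
    // kBytePerRow = 256, kBlockKGmem = 64 elements (128 bytes = 1 cache line), and
    // kGmemThreadsPerRow = 8. With NumThreads = 128, each copy iteration then covers
    // 16 rows x 64 elements, and each thread issues one 16-byte access per iteration.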
    using GmemCopyAtomCpAsync = cute::Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<uint128_t>, Element>;
    using GmemLayoutAtom = Layout<Shape <Int<NumThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
    using GmemTiledCopyQCpAsync = decltype(
        make_tiled_copy(GmemCopyAtomCpAsync{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 8 or 16 vals per load

    // Was trying to have each WG load Q into the rows of sQ that only that WG needs, so that we only
    // need to sync within each WG, but it didn't seem to be any faster.
    // using GmemLayoutAtomWG = Layout<Shape <Int<128 / kGmemThreadsPerRow>, Int<NumThreads / 128>, Int<kGmemThreadsPerRow>>,
    //                                 Stride<Int<kGmemThreadsPerRow>, _128, _1>>;
    // using GmemTiledCopyQCpAsyncWG = decltype(
    //     make_tiled_copy(GmemCopyAtomCpAsync{},
    //                     GmemLayoutAtomWG{},
    //                     Layout<Shape<_1, _1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 8 or 16 vals per load

    using GmemTiledCopyO = decltype(
        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, Int<kGmemElemsPerStore>>>{}));  // Val layout, 8 or 16 vals per store
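    // With the illustrative numbers above, GmemTiledCopyO has each thread store 8 contiguous elements
    // (16 bytes) per row per iteration; the tOpO predicate built in store_O below masks whole 8-element
    // column chunks that fall beyond size<1>(mO).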
    template <int NumThreadsPerRow=kGmemThreadsPerRow, typename Engine, typename Layout, typename TensorC>
    CUTLASS_DEVICE
    static auto
    compute_ptr(Tensor<Engine, Layout> &tensor, TensorC const &tRows,
                cutlass::FastDivmod const &qhead_per_khead_divmod, int const thread_idx, int const m_block) {
        // tensor of shape ((qhead_per_khead, seqlen_q))
        static constexpr int NumPtrPerThread = cute::ceil_div(CUTE_STATIC_V(cute::size(tRows)), NumThreadsPerRow);
        using TensorType = typename Engine::value_type;
        Tensor tPrPtr = make_tensor<TensorType const*>(Shape<Int<NumPtrPerThread>>{});
        #pragma unroll
        for (int i = 0; i < NumPtrPerThread; ++i) {
            int const row = i * NumThreads + get<0>(tRows(thread_idx % NumThreadsPerRow));
            int const idx = m_block * kBlockM + row;
            int m_idx, h_idx;
            m_idx = qhead_per_khead_divmod.divmod(h_idx, idx);
            tPrPtr[i] = &tensor(make_coord(make_coord(h_idx, m_idx)));
        }
        return tPrPtr;
    }
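    // The divmod above is relatively expensive, so the NumThreadsPerRow threads that share a row split
    // the work: lane r (= thread_idx % NumThreadsPerRow) computes the pointers for copy iterations
    // r, r + NumThreadsPerRow, r + 2*NumThreadsPerRow, ... The callers below reconstruct the pointer
    // for iteration m by reading slot m / NumThreadsPerRow from lane m % NumThreadsPerRow with
    // __shfl_sync (width = NumThreadsPerRow), which relies on those threads being in the same warp
    // (asserted above).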
    template <typename TensormQ, typename TensorsQ>
    CUTLASS_DEVICE
    static void
    load_Q(TensormQ const &mQ,  // ((qhead_per_khead, seqlen_q), headdim)
           TensorsQ &sQ,  // (kBlockM, kHeadDim)
           cutlass::FastDivmod const &qhead_per_khead_divmod,
           int const thread_idx, int const seqlen_q, int const m_block
           )
    {
        GmemTiledCopyQCpAsync gmem_tiled_copy_Q_cp_async;
        // GmemTiledCopyQCpAsyncWG gmem_tiled_copy_Q_cp_async;
        auto gmem_thr_copy_Q_cp_async = gmem_tiled_copy_Q_cp_async.get_thread_slice(thread_idx);
        Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
        Tensor tQcQ = gmem_thr_copy_Q_cp_async.partition_S(cQ);  // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
        Tensor tQsQ = gmem_thr_copy_Q_cp_async.partition_D(sQ);  // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
        // Tensor tQcQ_ = gmem_thr_copy_Q_cp_async.partition_S(cute::flat_divide(cQ, _64{}));  // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
        // Tensor tQsQ_ = gmem_thr_copy_Q_cp_async.partition_D(cute::flat_divide(sQ, _64{}));  // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
        // Tensor tQcQ = group_modes<1, rank(tQcQ_) - 1>(tQcQ_);
        // Tensor tQsQ = group_modes<1, rank(tQsQ_) - 1>(tQsQ_);
        Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
        #pragma unroll
        for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(_0{}, _0{}, k)) < size<1>(mQ); }
        // Similar to loading K and V when PagedKV, it's expensive to compute the pointers for Q.
        // We split the work among threads loading the same row of Q, then __shfl_sync the pointers.
        Tensor mQ_0 = mQ(_, _0{});
        Tensor tQcQ_row = tQcQ(_0{}, _, _0{});
        Tensor tPrQPtr = compute_ptr(mQ_0, tQcQ_row, qhead_per_khead_divmod, thread_idx, m_block);
        int const qhead_per_khead = qhead_per_khead_divmod.divisor;
        #pragma unroll
        for (int m = 0; m < size<1>(tQsQ); ++m) {
            int idx = m_block * kBlockM + get<0>(tQcQ(_0{}, m, _0{}));
            Element const* q_ptr = reinterpret_cast<Element const*>(__shfl_sync(0xffffffff, reinterpret_cast<uint64_t>(tPrQPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow));
            if (idx < seqlen_q * qhead_per_khead) {
                // if (thread_idx == 0) { printf("m: %d, m_idx: %d, h_idx: %d, q_ptr = %p, q_ptr_og = %p\n", m, m_idx, h_idx, q_ptr, &mQ_copy(0, make_coord(h_idx, m_idx), 0)); }
                Tensor mQ_cur = make_tensor(make_gmem_ptr(q_ptr), Shape<Int<kHeadDim>>{});
                Tensor mQ_cur_copy = cute::tiled_divide(mQ_cur, Shape<Int<kGmemElemsPerLoad>>{});
                #pragma unroll
                for (int k = 0; k < size<2>(tQsQ); ++k) {
                    int ki = get<1>(tQcQ(_0{}, _0{}, k)) / kGmemElemsPerLoad;
                    // The "tiled_copy.with(tQpQ(k))" will fill in zero for columns where tQpQ(k) is false
                    // TODO: check this
                    cute::copy(gmem_tiled_copy_Q_cp_async.with(tQpQ(k)), mQ_cur_copy(_, ki), tQsQ(_, m, k));
                }
            }  // Don't need to fill in 0s for sQ since we're not gonna write the output to gmem for those rows
        }
    };
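    // Illustrative call site for load_Q above (the `params` name and alias are examples, not from this file):
    //   using PackGQA = PackGQAManager<kBlockM, kHeadDim, NumMmaThreads, Element>;
    //   PackGQA::load_Q(mQ, sQ, params.qhead_per_khead_divmod, thread_idx, seqlen_q, m_block);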
    template <typename TensormLSE, typename TensorsLSE, typename TiledMma>
    CUTLASS_DEVICE
    static void
    store_LSE(TensormLSE &mLSE,  // ((qhead_per_khead, seqlen_q))
              TensorsLSE const &tLSErLSE,  // (kBlockM) split across threads according to tiled_mma
              TiledMma tiled_mma,
              cutlass::FastDivmod const &qhead_per_khead_divmod,
              int const thread_idx, int const seqlen_o, int const m_block
              )
    {
        Tensor caccO = cute::make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
        auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
        Tensor taccOcO = thread_mma.partition_C(caccO);  // (MMA,MMA_M,MMA_K)
        // taccOcO has shape ((2, 2, V), MMA_M, MMA_K), we only take the row indices.
        Tensor taccOcO_row = taccOcO(make_coord(_0{}, _, _0{}), _, _0{});
        CUTE_STATIC_ASSERT_V(size(tLSErLSE) == size(taccOcO_row));  // MMA_M
        // If PackGQA, we split the work of computing divmod among threads in the same row
        static constexpr int kMmaThreadsPerRow = size<0, 0>(typename TiledMma::AtomLayoutC_TV{});
        static_assert(cutlass::NumThreadsPerWarp % kMmaThreadsPerRow == 0);
        static_assert(CUTE_STATIC_V(size(tLSErLSE)) <= kMmaThreadsPerRow);
        static_assert(CUTE_STATIC_V(size(taccOcO_row)) <= kMmaThreadsPerRow);
        Tensor tPrLSEPtr = compute_ptr<kMmaThreadsPerRow>(mLSE, taccOcO_row, qhead_per_khead_divmod, thread_idx, m_block);
        static_assert(CUTE_STATIC_V(size(tPrLSEPtr)) == 1);
        int const qhead_per_khead = qhead_per_khead_divmod.divisor;
        #pragma unroll
        for (int mi = 0; mi < size(tLSErLSE); ++mi) {
            int const row = m_block * kBlockM + get<0>(taccOcO_row(mi));
            float* ptr_LSE_cur = reinterpret_cast<float*>(__shfl_sync(0xffffffff, reinterpret_cast<uint64_t>(tPrLSEPtr[0]), mi % kMmaThreadsPerRow, kMmaThreadsPerRow));
            if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o * qhead_per_khead) {
                *ptr_LSE_cur = tLSErLSE(mi);
            }
        }
    };
    template <typename TensormO, typename TensorrO>
    CUTLASS_DEVICE
    static void
    store_O(TensormO &mO,  // ((qhead_per_khead, seqlen_o), headdim)
            TensorrO const &tOrO,  // (kBlockM, kHeadDim) split across threads according to gmem_tiled_copy_O
            cutlass::FastDivmod const &qhead_per_khead_divmod,
            int const thread_idx, int const seqlen_o, int const m_block
            )
    {
        GmemTiledCopyO gmem_tiled_copy_O;
        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
        Tensor cO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
        Tensor tOcO = gmem_thr_copy_O.partition_D(cO);  // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
        Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOcO)));
        #pragma unroll
        for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < size<1>(mO); }
        // Similar to loading K and V when PagedKV, it's expensive to compute the pointers for O.
        // We split the work among threads storing the same row of O, then __shfl_sync the pointers.
        Tensor mO_0 = mO(_, _0{});
        Tensor tOcO_row = tOcO(_0{}, _, _0{});
        Tensor tPrOPtr = compute_ptr(mO_0, tOcO_row, qhead_per_khead_divmod, thread_idx, m_block);
        int const qhead_per_khead = qhead_per_khead_divmod.divisor;
        #pragma unroll
        for (int m = 0; m < size<1>(tOrO); ++m) {
            int idx = m_block * kBlockM + get<0>(tOcO(_0{}, m, _0{}));
            Element* o_ptr = reinterpret_cast<Element*>(__shfl_sync(0xffffffff, reinterpret_cast<uint64_t>(tPrOPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow));
            if (idx < seqlen_o * qhead_per_khead) {
                Tensor mO_cur = make_tensor(make_gmem_ptr(o_ptr), Shape<Int<kHeadDim>>{});
                Tensor mO_cur_copy = cute::tiled_divide(mO_cur, Shape<Int<kGmemElemsPerStore>>{});
                #pragma unroll
                for (int k = 0; k < size<2>(tOrO); ++k) {
                    int ki = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerStore;
                    if (tOpO(k)) {
                        cute::copy(gmem_tiled_copy_O, tOrO(_, m, k), mO_cur_copy(_, ki));
                    }
                }
            }
        }
    };
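    // store_O_direct below writes O from the accumulator fragments in their MMA layout: the fragment is
    // reshaped to (row, col) form and written 2 elements at a time, with the row pointers again computed
    // by compute_ptr and shared via __shfl_sync, and each column bounds-checked against size<1>(mO).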
    template <typename TensormO, typename TensorrO, typename TiledMma>
    CUTLASS_DEVICE
    static void
    store_O_direct(TensormO &mO,  // ((qhead_per_khead, seqlen_o), headdim)
                   TensorrO const &tOrO,  // (kBlockM, kHeadDim) split across threads according to tiled_mma
                   TiledMma tiled_mma,
                   cutlass::FastDivmod const &qhead_per_khead_divmod,
                   int const thread_idx, int const seqlen_o, int const m_block
                   )
    {
        static constexpr int kGmemElemsPerStoreDirect = 2;
        cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element> gmem_copy_direct;
        // Reshape acc from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
        Tensor tOrO_rowcol = make_tensor(tOrO.data(), flash::convert_layout_acc_rowcol(tOrO.layout()));
        Tensor tOrO_copy = cute::tiled_divide(tOrO_rowcol, Shape<_1, Int<kGmemElemsPerStoreDirect>>{});
        Tensor caccO = cute::make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
        auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
        Tensor taccOcO = thread_mma.partition_C(caccO);  // (MMA,MMA_M,MMA_K)
        // taccOcO has shape ((2, 2, V), MMA_M, MMA_K), we only take the row indices.
        Tensor taccOcO_row = taccOcO(make_coord(_0{}, _, _0{}), _, _0{});
        Tensor taccOcO_col = taccOcO(make_coord(_, _0{}, _), _0{}, _);
        // If PackGQA, we split the work of computing divmod among threads in the same row
        static constexpr int kMmaThreadsPerRow = size<0, 0>(typename TiledMma::AtomLayoutC_TV{});
        static_assert(cutlass::NumThreadsPerWarp % kMmaThreadsPerRow == 0);
        static_assert(CUTE_STATIC_V(size(taccOcO_row)) <= kMmaThreadsPerRow);
        // Similar to loading K and V when PagedKV, it's expensive to compute the pointers for O.
        // We split the work among threads storing the same row of O, then __shfl_sync the pointers.
        Tensor mO_0 = mO(_, _0{});
        Tensor tPrOPtr = compute_ptr<kMmaThreadsPerRow>(mO_0, taccOcO_row, qhead_per_khead_divmod, thread_idx, m_block);
        static_assert(CUTE_STATIC_V(size(tPrOPtr)) == 1);
        int const qhead_per_khead = qhead_per_khead_divmod.divisor;
        #pragma unroll
        for (int m = 0; m < size<1>(tOrO_copy); ++m) {
            int row = m_block * kBlockM + get<0>(taccOcO_row(m));
            Element* o_ptr = reinterpret_cast<Element*>(__shfl_sync(0xffffffff, reinterpret_cast<uint64_t>(tPrOPtr[0]), m % kMmaThreadsPerRow, kMmaThreadsPerRow));
            if (row < seqlen_o * qhead_per_khead) {
                Tensor mO_cur = make_tensor(make_gmem_ptr(o_ptr), Shape<Int<kHeadDim>>{});
                Tensor mO_cur_copy = cute::tiled_divide(mO_cur, Shape<Int<kGmemElemsPerStoreDirect>>{});
                #pragma unroll
                for (int k = 0; k < size<2>(tOrO_copy); ++k) {
                    int col = get<1>(taccOcO_col(k * kGmemElemsPerStoreDirect));
                    if (col < size<1>(mO)) {
                        cute::copy(gmem_copy_direct, tOrO_copy(_, m, k), mO_cur_copy(_, col / kGmemElemsPerStoreDirect));
                    }
                }
            }
        }
    };
};

}  // namespace flash
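// Illustrative instantiation (the concrete template arguments are examples, not fixed by this header):
//   using PackGQA = flash::PackGQAManager<128 /*kBlockM*/, 128 /*kHeadDim*/, 384 /*NumThreads*/, cutlass::half_t>;
// A kernel would call PackGQA::load_Q to fill sQ before the GEMMs, and PackGQA::store_LSE together with
// PackGQA::store_O (or store_O_direct) in the epilogue, all indexed through the same qhead_per_khead_divmod.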