1
0

rotary.h 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include <cute/tensor.hpp>
  6. #include "utils.h"
  7. namespace flash {
  8. using namespace cute;
  9. ////////////////////////////////////////////////////////////////////////////////////////////////////
  10. template <typename Engine1, typename Layout1, typename Engine2, typename Layout2>
  11. CUTLASS_DEVICE void
  12. apply_rotary_interleaved(Tensor<Engine1, Layout1> &rK,
  13. Tensor<Engine2, Layout2> const &rCos,
  14. Tensor<Engine2, Layout2> const &rSin) {
  15. CUTE_STATIC_ASSERT_V(rank(rK) == _1{});
  16. CUTE_STATIC_ASSERT_V(rank(rCos) == _1{});
  17. CUTE_STATIC_ASSERT_V(rank(rSin) == _1{});
  18. CUTE_STATIC_ASSERT_V(size<0>(rCos) == size<0>(rSin));
  19. static_assert(decltype(size<0>(rK))::value == decltype(size<0>(rCos))::value * 2);
  20. static_assert(decltype(size<0>(rCos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
  21. Tensor K_fp32 = make_tensor_like<float>(rK);
  22. convert_type_out(rK, K_fp32);
  23. Tensor cos_fp32 = make_tensor_like<float>(rCos);
  24. convert_type_out(rCos, cos_fp32);
  25. Tensor sin_fp32 = make_tensor_like<float>(rSin);
  26. convert_type_out(rSin, sin_fp32);
  27. #pragma unroll
  28. for (int i = 0; i < size<0>(K_fp32) / 2; ++i) {
  29. float real = K_fp32[2 * i] * cos_fp32[i] - K_fp32[2 * i + 1] * sin_fp32[i];
  30. float imag = K_fp32[2 * i] * sin_fp32[i] + K_fp32[2 * i + 1] * cos_fp32[i];
  31. K_fp32[2 * i] = real;
  32. K_fp32[2 * i + 1] = imag;
  33. }
  34. convert_type_out(K_fp32, rK);
  35. }
  36. ////////////////////////////////////////////////////////////////////////////////////////////////////
  37. template <typename Engine1, typename Layout1, typename Engine2, typename Layout2>
  38. CUTLASS_DEVICE void
  39. apply_rotary_contiguous(Tensor<Engine1, Layout1> &rK_left,
  40. Tensor<Engine1, Layout1> &rK_right,
  41. Tensor<Engine2, Layout2> const &rCos,
  42. Tensor<Engine2, Layout2> const &rSin) {
  43. CUTE_STATIC_ASSERT_V(rank(rK_left) == _1{});
  44. CUTE_STATIC_ASSERT_V(rank(rK_right) == _1{});
  45. CUTE_STATIC_ASSERT_V(rank(rCos) == _1{});
  46. CUTE_STATIC_ASSERT_V(rank(rSin) == _1{});
  47. CUTE_STATIC_ASSERT_V(size<0>(rK_left) == size<0>(rK_right));
  48. CUTE_STATIC_ASSERT_V(size<0>(rK_left) == size<0>(rCos));
  49. CUTE_STATIC_ASSERT_V(size<0>(rCos) == size<0>(rSin));
  50. static_assert(decltype(size<0>(rCos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
  51. Tensor K_left_fp32 = make_tensor_like<float>(rK_left);
  52. convert_type_out(rK_left, K_left_fp32);
  53. Tensor K_right_fp32 = make_tensor_like<float>(rK_right);
  54. convert_type_out(rK_right, K_right_fp32);
  55. Tensor cos_fp32 = make_tensor_like<float>(rCos);
  56. convert_type_out(rCos, cos_fp32);
  57. Tensor sin_fp32 = make_tensor_like<float>(rSin);
  58. convert_type_out(rSin, sin_fp32);
  59. #pragma unroll
  60. for (int i = 0; i < size<0>(K_left_fp32); ++i) {
  61. float real = K_left_fp32[i] * cos_fp32[i] - K_right_fp32[i] * sin_fp32[i];
  62. float imag = K_left_fp32[i] * sin_fp32[i] + K_right_fp32[i] * cos_fp32[i];
  63. K_left_fp32[i] = real;
  64. K_right_fp32[i] = imag;
  65. }
  66. convert_type_out(K_left_fp32, rK_left);
  67. convert_type_out(K_right_fp32, rK_right);
  68. }
  69. ////////////////////////////////////////////////////////////////////////////////////////////////////
  70. template <int kBlockMN, int kHeadDim, int NumThreads, typename Element, bool FixedPosition=false>
  71. struct Rotary {
  72. static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
  73. static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
  74. // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each
  75. // thread to have 4 loads in the M direction and 2 vectorized load in the K direction.
  76. // We want each thread to have at least 2 loads in the K direction since in the case of non-interleaved
  77. // rotary (combining elements at indices 0 and rotary_dim/2, 1 and rotary_dim/2+1, etc), each thread will
  78. // load twice from the same row.
  79. static constexpr int kBytePerHalfRow = kHeadDim / 2 * sizeof(Element);
  80. static constexpr int kBlockKGmem = (kBytePerHalfRow % 128 == 0 ? 128 : (kBytePerHalfRow % 64 == 0 ? 64 : 32)) / sizeof(Element);
  81. static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;
  82. static_assert(NumThreads % kGmemThreadsPerRow == 0, "NumThreads must be a multiple of kGmemThreadsPerRow");
  83. // We assume threads loading the same row are in the same warp.
  84. static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, "kGmemThreadsPerRow must divide NumThreadsPerWarp");
  85. using LayoutAtom = Layout<Shape <Int<NumThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
  86. Stride<Int<kGmemThreadsPerRow>, _1>>;
  87. using TiledCopyQK = decltype(
  88. make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
  89. LayoutAtom{},
  90. Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per store
  91. using GmemTiledCopyRotary = decltype(
  92. make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<64>, Element>{},
  93. LayoutAtom{},
  94. Layout<Shape<_1, Int<kGmemElemsPerLoad / 2>>>{})); // Val layout, 4 or 8 vals per store
  95. using GmemTiledCopyRotaryCont = decltype(
  96. make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
  97. LayoutAtom{},
  98. Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per store
  99. using ShapeRotary = cute::Shape<int32_t, int32_t>; // (seqlen_ro, rotary_dim // 2)
  100. using StrideRotary = cute::Stride<int64_t, _1>;
  101. using GmemThrCopyRotary = decltype(GmemTiledCopyRotary{}.get_thread_slice(int(0)));
  102. using GmemThrCopyRotaryCont = decltype(GmemTiledCopyRotaryCont{}.get_thread_slice(int(0)));
  103. using TensortRcR = decltype(GmemTiledCopyRotary{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape<Int<kBlockMN>, Int<kHeadDim / 2>>{})));
  104. using TensortRpR = decltype(make_tensor<bool>(make_shape(size<2>(TensortRcR{}))));
  105. using TensortRcRCont = decltype(GmemTiledCopyRotaryCont{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape<Int<kBlockMN>, Int<kHeadDim / 2>>{})));
  106. using TensortRpRCont = decltype(make_tensor<bool>(make_shape(size<2>(TensortRcRCont{}))));
  107. using TensormR = decltype(make_tensor(
  108. make_gmem_ptr((Element const*)nullptr),
  109. ShapeRotary{},
  110. make_stride(cute::conditional_return<FixedPosition>(_0{}, int64_t(0)), _1{})));
  111. using TensortRgR = decltype(
  112. GmemTiledCopyRotary{}.get_thread_slice(int(0)).partition_S(make_tensor(
  113. make_gmem_ptr((Element const*)nullptr),
  114. make_shape(Int<kBlockMN>{}, Int<kHeadDim / 2>{}, int(0)),
  115. make_stride(cute::conditional_return<FixedPosition>(_0{}, int64_t(0)), _1{}, cute::conditional_return<FixedPosition>(_0{}, int64_t(0))))));
  116. using TensortRgRCont = decltype(
  117. GmemTiledCopyRotaryCont{}.get_thread_slice(int(0)).partition_S(make_tensor(
  118. make_gmem_ptr((Element const*)nullptr),
  119. make_shape(Int<kBlockMN>{}, Int<kHeadDim / 2>{}, int(0)),
  120. make_stride(cute::conditional_return<FixedPosition>(_0{}, int64_t(0)), _1{}, cute::conditional_return<FixedPosition>(_0{}, int64_t(0))))));
  121. GmemTiledCopyRotary gmem_tiled_copy_rotary;
  122. GmemTiledCopyRotaryCont gmem_tiled_copy_rotary_cont;
  123. bool const is_rotary_interleaved;
  124. int const rotary_dim;
  125. int const thread_idx;
  126. int const max_seqlen;
  127. GmemThrCopyRotary const gmem_thr_copy_rotary;
  128. GmemThrCopyRotaryCont const gmem_thr_copy_rotary_cont;
  129. TensortRpR tRpR;
  130. TensortRpRCont tRpRCont;
  131. TensormR mCos, mSin;
  132. TensortRgR tRgCos, tRgSin;
  133. TensortRgRCont tRgCosCont, tRgSinCont;
  134. CUTLASS_DEVICE
  135. Rotary(Element const* const ptr_rotary_cos, ShapeRotary const &shape_rotary, StrideRotary const &stride_rotary_cos_,
  136. Element const* const ptr_rotary_sin, StrideRotary const &stride_rotary_sin_,
  137. bool const is_rotary_interleaved, int const thread_idx, int const max_seqlen, int const start_idx)
  138. : is_rotary_interleaved(is_rotary_interleaved)
  139. , rotary_dim(get<1>(shape_rotary) * 2)
  140. , thread_idx(thread_idx)
  141. , max_seqlen(max_seqlen)
  142. , gmem_thr_copy_rotary(gmem_tiled_copy_rotary.get_thread_slice(thread_idx))
  143. , gmem_thr_copy_rotary_cont(gmem_tiled_copy_rotary_cont.get_thread_slice(thread_idx))
  144. {
  145. auto stride_rotary_cos = make_stride(cute::conditional_return<!FixedPosition>(get<0>(stride_rotary_cos_), _0{}), get<1>(stride_rotary_cos_));
  146. auto stride_rotary_sin = make_stride(cute::conditional_return<!FixedPosition>(get<0>(stride_rotary_sin_), _0{}), get<1>(stride_rotary_sin_));
  147. mCos = make_tensor(make_gmem_ptr(ptr_rotary_cos + start_idx * get<0>(stride_rotary_cos_)), shape_rotary, stride_rotary_cos);
  148. mSin = make_tensor(make_gmem_ptr(ptr_rotary_sin + start_idx * get<0>(stride_rotary_sin_)), shape_rotary, stride_rotary_sin);
  149. Tensor gCos = local_tile(mCos, Shape<Int<kBlockMN>, Int<kHeadDim / 2>>{}, make_coord(_, _0{})); // (MN, K / 2, _)
  150. Tensor gSin = local_tile(mSin, Shape<Int<kBlockMN>, Int<kHeadDim / 2>>{}, make_coord(_, _0{})); // (MN, K / 2, _)
  151. tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
  152. tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
  153. tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCos);
  154. tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSin);
  155. Tensor cR = cute::make_identity_tensor(Shape<Int<kBlockMN>, Int<kHeadDim / 2>>{}); // (BLK_N,BLK_K / 2)
  156. Tensor tRcR = gmem_thr_copy_rotary.partition_D(cR);
  157. tRpR = make_tensor<bool>(make_shape(size<2>(tRcR)));
  158. #pragma unroll
  159. for (int k = 0; k < size(tRpR); ++k) { tRpR(k) = get<1>(tRcR(_0{}, _0{}, k)) < get<1>(shape_rotary); }
  160. Tensor tRcRCont = gmem_thr_copy_rotary_cont.partition_D(cR);
  161. tRpRCont = make_tensor<bool>(make_shape(size<2>(tRcRCont)));
  162. #pragma unroll
  163. for (int k = 0; k < size(tRpRCont); ++k) { tRpRCont(k) = get<1>(tRcRCont(_0{}, _0{}, k)) < get<1>(shape_rotary); }
  164. };
  165. template <bool kInterleaved=true>
  166. CUTLASS_DEVICE
  167. auto load_cos_sin(int const block) {
  168. using GmemTiledCopyRo = std::conditional_t<kInterleaved, GmemTiledCopyRotary, GmemTiledCopyRotaryCont>;
  169. auto gmem_thr_copy_ro = cute::conditional_return<kInterleaved>(gmem_thr_copy_rotary, gmem_thr_copy_rotary_cont);
  170. Tensor tRpRCur = cute::conditional_return<kInterleaved>(tRpR, tRpRCont);
  171. Tensor tRgCosCur = cute::conditional_return<kInterleaved>(tRgCos, tRgCosCont)(_, _, _, block);
  172. Tensor tRgSinCur = cute::conditional_return<kInterleaved>(tRgSin, tRgSinCont)(_, _, _, block);
  173. // make_tensor_like, not make_fragment_like. If the row_stride is _0{} we want to keep it that way
  174. Tensor tRrCos = make_tensor_like(tRgCosCur);
  175. Tensor tRrSin = make_tensor_like(tRgSinCur);
  176. Tensor cR = cute::make_identity_tensor(Shape<Int<kBlockMN>, Int<kHeadDim / 2>>{}); // (BLK_N,BLK_K / 2)
  177. Tensor tRcR = gmem_thr_copy_ro.partition_D(cR);
  178. // If FixedPosition, only copy the first row as we only need the cos/sin for position cache_seqlens
  179. #pragma unroll
  180. for (int m = 0; m < (!FixedPosition ? size<1>(tRrCos) : 1); ++m) {
  181. if (get<0>(tRcR(_0{}, m, _0{})) < std::min(max_seqlen - block * kBlockMN, kBlockMN)) {
  182. #pragma unroll
  183. for (int k = 0; k < size<2>(tRrCos); ++k) {
  184. if (tRpRCur(k)) {
  185. cute::copy(GmemTiledCopyRo{}, tRgCosCur(_, m, k), tRrCos(_, m, k));
  186. cute::copy(GmemTiledCopyRo{}, tRgSinCur(_, m, k), tRrSin(_, m, k));
  187. }
  188. }
  189. }
  190. }
  191. return cute::make_tuple(tRrCos, tRrSin);;
  192. }
  193. template <bool kInterleaved=true>
  194. CUTLASS_DEVICE
  195. auto load_cos_sin_packgqa(int const block, cutlass::FastDivmod const &qhead_per_khead_divmod) {
  196. static constexpr int kGmemElemsPerLoadCur = kInterleaved ? kGmemElemsPerLoad / 2 : kGmemElemsPerLoad;
  197. using GmemTiledCopyRo = std::conditional_t<kInterleaved, GmemTiledCopyRotary, GmemTiledCopyRotaryCont>;
  198. auto gmem_thr_copy_ro = cute::conditional_return<kInterleaved>(gmem_thr_copy_rotary, gmem_thr_copy_rotary_cont);
  199. Tensor tRpRCur = cute::conditional_return<kInterleaved>(tRpR, tRpRCont);
  200. // make_tensor_like, not make_fragment_like. If the row_stride is _0{} we want to keep it that way
  201. Tensor tRrCos = make_tensor_like(cute::conditional_return<kInterleaved>(tRgCos, tRgCosCont)(_, _, _, _0{}));
  202. Tensor tRrSin = make_tensor_like(cute::conditional_return<kInterleaved>(tRgSin, tRgSinCont)(_, _, _, _0{}));
  203. int const qhead_per_khead = qhead_per_khead_divmod.divisor;
  204. Tensor cR = cute::make_identity_tensor(Shape<Int<kBlockMN>, Int<kHeadDim / 2>>{}); // (BLK_N,BLK_K / 2)
  205. Tensor tRcR = gmem_thr_copy_ro.partition_D(cR);
  206. // The main bottleneck here is actually instruction cache misses.
  207. // Similar to PagedKV, it's expensive to compute the pointers.
  208. // We split the work among threads loading the same row, then __shfl_sync the pointers.
  209. static constexpr int NumPtrPerThread = cute::ceil_div(CUTE_STATIC_V(cute::size<1>(tRrCos)), kGmemThreadsPerRow);
  210. Tensor tPrCosPtr = make_tensor<Element const*>(Shape<Int<NumPtrPerThread>>{});
  211. Tensor tPrSinPtr = make_tensor<Element const*>(Shape<Int<NumPtrPerThread>>{});
  212. #pragma unroll
  213. for (int i = 0; i < NumPtrPerThread; ++i) {
  214. int const row = i * NumThreads + get<0>(tRcR(_0{}, thread_idx % kGmemThreadsPerRow, _0{}));
  215. int const idx = block * kBlockMN + row;
  216. int row_actual = qhead_per_khead_divmod.divide(idx);
  217. tPrCosPtr[i] = &mCos(row_actual, _0{});
  218. tPrSinPtr[i] = &mSin(row_actual, _0{});
  219. }
  220. #pragma unroll
  221. for (int m = 0; m < (!FixedPosition ? size<1>(tRgCos) : 1); ++m) {
  222. int const idx = block * kBlockMN + get<0>(tRcR(_0{}, m, _0{}));
  223. Element const* cos_ptr = reinterpret_cast<Element const*>(__shfl_sync(0xffffffff, reinterpret_cast<uint64_t>(tPrCosPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow));
  224. Element const* sin_ptr = reinterpret_cast<Element const*>(__shfl_sync(0xffffffff, reinterpret_cast<uint64_t>(tPrSinPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow));
  225. if (idx < max_seqlen * qhead_per_khead) {
  226. Tensor mCos_copy = cute::tiled_divide(make_tensor(make_gmem_ptr(cos_ptr), Shape<Int<kHeadDim / 2>>{}),
  227. Shape<Int<kGmemElemsPerLoadCur>>{});
  228. Tensor mSin_copy = cute::tiled_divide(make_tensor(make_gmem_ptr(sin_ptr), Shape<Int<kHeadDim / 2>>{}),
  229. Shape<Int<kGmemElemsPerLoadCur>>{});
  230. #pragma unroll
  231. for (int k = 0; k < size<2>(tRgCos); ++k) {
  232. int const ki = get<1>(tRcR(_0{}, _0{}, k)) / (kGmemElemsPerLoadCur);
  233. if (tRpRCur(k)) {
  234. cute::copy(GmemTiledCopyRo{}, mCos_copy(_, ki), tRrCos(_, m, k));
  235. cute::copy(GmemTiledCopyRo{}, mSin_copy(_, ki), tRrSin(_, m, k));
  236. }
  237. }
  238. }
  239. }
  240. return cute::make_tuple(tRrCos, tRrSin);
  241. }
  242. template <typename TensorsQ, typename TensortRrR>
  243. CUTLASS_DEVICE
  244. void
  245. apply_Q_interleaved(TensorsQ &sQ, // (kBlockM, kHeadDim)
  246. TensortRrR const &tRrCos, // (kBlockM, kHeadDim / 2) split according to GmemThrCopyRotary
  247. TensortRrR const &tRrSin, // (kBlockM, kHeadDim / 2) split according to GmemThrCopyRotary
  248. int const m_block, int const qhead_per_khead=1)
  249. {
  250. TiledCopyQK tiled_copy_q;
  251. auto gmem_thr_copy_q = tiled_copy_q.get_thread_slice(thread_idx);
  252. Tensor tQsQ = gmem_thr_copy_q.partition_S(sQ);
  253. Tensor tQcQ = gmem_thr_copy_q.partition_S(cute::make_identity_tensor(Shape<Int<kBlockMN>, Int<kHeadDim>>{}));
  254. CUTE_STATIC_ASSERT_V(rank(tQsQ) == _3{});
  255. CUTE_STATIC_ASSERT_V(rank(tRrCos) == _3{});
  256. CUTE_STATIC_ASSERT_V(rank(tRrSin) == _3{});
  257. CUTE_STATIC_ASSERT_V(size<1>(tQsQ) == size<1>(tRrCos));
  258. CUTE_STATIC_ASSERT_V(size<2>(tQsQ) == size<2>(tRrCos));
  259. CUTE_STATIC_ASSERT_V(size<1>(tQsQ) == size<1>(tRrSin));
  260. CUTE_STATIC_ASSERT_V(size<2>(tQsQ) == size<2>(tRrSin));
  261. CUTE_STATIC_ASSERT_V(size<0>(tRrCos) == size<0>(tRrSin));
  262. static_assert(decltype(size<0>(tQsQ))::value == decltype(size<0>(tRrCos))::value * 2);
  263. static_assert(decltype(size<0>(tRrCos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
  264. #pragma unroll
  265. for (int m = 0; m < size<1>(tQsQ); ++m) {
  266. if (get<0>(tQcQ(_0{}, m, _0{})) < std::min(max_seqlen * qhead_per_khead - m_block * kBlockMN, kBlockMN)) {
  267. #pragma unroll
  268. for (int k = 0; k < size<2>(tQsQ); ++k) {
  269. if (tRpR(k)) {
  270. Tensor rQ = make_fragment_like(tQsQ(_, m, k));
  271. cute::copy(tiled_copy_q, tQsQ(_, m, k), rQ);
  272. apply_rotary_interleaved(rQ, tRrCos(_, m, k), tRrSin(_, m, k));
  273. cute::copy(tiled_copy_q, rQ, tQsQ(_, m, k));
  274. }
  275. }
  276. }
  277. }
  278. };
  279. template <typename TensorsQ, typename TensortRrR>
  280. CUTLASS_DEVICE
  281. void
  282. apply_Q_contiguous(TensorsQ &sQ, // (kBlockM, kHeadDim)
  283. TensortRrR const &tRrCosCont, // (kBlockM, kHeadDim / 2) split according to GmemThrCopyRotaryCont
  284. TensortRrR const &tRrSinCont, // (kBlockM, kHeadDim / 2) split according to GmemThrCopyRotaryCont
  285. int const m_block, int const qhead_per_khead=1)
  286. {
  287. TiledCopyQK tiled_copy_q;
  288. auto gmem_thr_copy_q = tiled_copy_q.get_thread_slice(thread_idx);
  289. Tensor sQ_copy = cute::tiled_divide(sQ, Shape<_1, Int<kGmemElemsPerLoad>>{});
  290. Tensor tQcQ = gmem_thr_copy_q.partition_S(cute::make_identity_tensor(Shape<Int<kBlockMN>, Int<kHeadDim / 2>>{}));
  291. CUTE_STATIC_ASSERT_V(rank(tQcQ) == _3{});
  292. CUTE_STATIC_ASSERT_V(rank(tRrCosCont) == _3{});
  293. CUTE_STATIC_ASSERT_V(rank(tRrSinCont) == _3{});
  294. CUTE_STATIC_ASSERT_V(size<1>(tQcQ) == size<1>(tRrCosCont));
  295. CUTE_STATIC_ASSERT_V(size<2>(tQcQ) == size<2>(tRrCosCont));
  296. CUTE_STATIC_ASSERT_V(size<1>(tQcQ) == size<1>(tRrSinCont));
  297. CUTE_STATIC_ASSERT_V(size<2>(tQcQ) == size<2>(tRrSinCont));
  298. CUTE_STATIC_ASSERT_V(size<0>(tRrCosCont) == size<0>(tRrSinCont));
  299. CUTE_STATIC_ASSERT_V(size<0>(tQcQ) == size<0>(tRrCosCont));
  300. static_assert(decltype(size<0>(tRrCosCont))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
  301. #pragma unroll
  302. for (int m = 0; m < size<1>(tQcQ); ++m) {
  303. int const row = get<0>(tQcQ(_0{}, m, _0{}));
  304. if (row < std::min(max_seqlen * qhead_per_khead - m_block * kBlockMN, kBlockMN)) {
  305. #pragma unroll
  306. for (int k = 0; k < size<2>(tQcQ); ++k) {
  307. int const col = get<1>(tQcQ(_0{}, _0{}, k));
  308. if (col < rotary_dim / 2) {
  309. int const col_idx_left = col / kGmemElemsPerLoad;
  310. int const col_idx_right = col / kGmemElemsPerLoad + rotary_dim / (2 * kGmemElemsPerLoad);
  311. Tensor rQ_left = make_fragment_like(sQ_copy(_, row, col_idx_left));
  312. cute::copy(tiled_copy_q, sQ_copy(_, row, col_idx_left), rQ_left);
  313. Tensor rQ_right = make_fragment_like(rQ_left);
  314. cute::copy(tiled_copy_q, sQ_copy(_, row, col_idx_right), rQ_right);
  315. apply_rotary_contiguous(rQ_left, rQ_right, tRrCosCont(_, m, k), tRrSinCont(_, m, k));
  316. cute::copy(tiled_copy_q, rQ_left, sQ_copy(_, row, col_idx_left));
  317. cute::copy(tiled_copy_q, rQ_right, sQ_copy(_, row, col_idx_right));
  318. }
  319. }
  320. }
  321. }
  322. };
  323. template <bool PagedKV=false, typename TensorsK, typename TensorgK, typename TensorpK, typename TensortRrR, typename TensorKPtr>
  324. CUTLASS_DEVICE
  325. void
  326. apply_K_interleaved(TensorsK const &sK, // (kBlockN, kHeadDim)
  327. TensorgK &gK, // (kBlockN, kHeadDim)
  328. TensorpK const &tKpK, // (kBlockN, kHeadDim) split according to ThrCopyKV
  329. TensortRrR const &tRrCos, // (kBlockN, kHeadDim/2) split according to GmemThrCopyRotary
  330. TensortRrR const &tRrSin, // (kBlockN, kHeadDim/2) split according to GmemThrCopyRotary
  331. TensorKPtr const &tPrKPtr,
  332. int const n_block)
  333. {
  334. TiledCopyQK tiled_copy_k;
  335. auto gmem_thr_copy_q = tiled_copy_k.get_thread_slice(thread_idx);
  336. Tensor tKsK = gmem_thr_copy_q.partition_S(sK);
  337. Tensor tKgK = gmem_thr_copy_q.partition_S(gK);
  338. Tensor tKcK = gmem_thr_copy_q.partition_S(cute::make_identity_tensor(Shape<Int<kBlockMN>, Int<kHeadDim>>{}));
  339. CUTE_STATIC_ASSERT_V(rank(tKsK) == _3{});
  340. CUTE_STATIC_ASSERT_V(rank(tRrCos) == _3{});
  341. CUTE_STATIC_ASSERT_V(rank(tRrSin) == _3{});
  342. CUTE_STATIC_ASSERT_V(size<1>(tKsK) == size<1>(tRrCos));
  343. CUTE_STATIC_ASSERT_V(size<2>(tKsK) == size<2>(tRrCos));
  344. CUTE_STATIC_ASSERT_V(size<1>(tKsK) == size<1>(tRrSin));
  345. CUTE_STATIC_ASSERT_V(size<2>(tKsK) == size<2>(tRrSin));
  346. CUTE_STATIC_ASSERT_V(size<0>(tRrCos) == size<0>(tRrSin));
  347. static_assert(decltype(size<0>(tKsK))::value == decltype(size<0>(tRrCos))::value * 2);
  348. static_assert(decltype(size<0>(tRrCos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
  349. if constexpr (PagedKV) {
  350. static_assert(decltype(size(tPrKPtr))::value == cute::ceil_div(size<1>(tKcK), kGmemThreadsPerRow));
  351. }
  352. #pragma unroll
  353. for (int m = 0; m < size<1>(tKsK); ++m) {
  354. int const row = get<0>(tKcK(_0{}, m, _0{}));
  355. auto mK_cur_copy = [&] {
  356. if constexpr (PagedKV) {
  357. Element* k_ptr = reinterpret_cast<Element*>(__shfl_sync(0xffffffff, reinterpret_cast<uint64_t>(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow));
  358. Tensor mK_cur = make_tensor(make_gmem_ptr(k_ptr), Shape<Int<kHeadDim>>{});
  359. return cute::tiled_divide(mK_cur, Shape<Int<kGmemElemsPerLoad>>{});
  360. } else {
  361. return nullptr;
  362. }
  363. }();
  364. if (row < std::min(max_seqlen - n_block * kBlockMN, kBlockMN)) {
  365. #pragma unroll
  366. for (int k = 0; k < size<2>(tKsK); ++k) {
  367. if (tKpK(k)) {
  368. Tensor rK = make_fragment_like(tKsK(_, m, k));
  369. cute::copy(tiled_copy_k, tKsK(_, m, k), rK);
  370. if (tRpR(k)) { apply_rotary_interleaved(rK, tRrCos(_, m, k), tRrSin(_, m, k)); }
  371. if constexpr (!PagedKV) {
  372. cute::copy(tiled_copy_k, rK, tKgK(_, m, k));
  373. } else {
  374. int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad;
  375. cute::copy(tiled_copy_k, rK, mK_cur_copy(_, ki));
  376. }
  377. }
  378. }
  379. }
  380. }
  381. };
  382. template <bool PagedKV=false, typename TensorsK, typename TensorgK, typename TensorpK, typename TensortRrR, typename TensorKPtr>
  383. CUTLASS_DEVICE
  384. void
  385. apply_K_contiguous(TensorsK const &sK, // (kBlockN, kHeadDim)
  386. TensorgK &gK, // (kBlockN, kHeadDim)
  387. TensorpK const &tKpK, // (kBlockN, kHeadDim) split according to ThrCopyKV
  388. TensortRrR const &tRrCosCont, // (kBlockN, kHeadDim/2) split according to GmemThrCopyRotaryCont
  389. TensortRrR const &tRrSinCont, // (kBlockN, kHeadDim/2) split according to GmemThrCopyRotaryCont
  390. TensorKPtr const &tPrKPtr,
  391. int const n_block, int const max_k)
  392. {
  393. TiledCopyQK tiled_copy_k;
  394. auto gmem_thr_copy_q = tiled_copy_k.get_thread_slice(thread_idx);
  395. Tensor sK_copy = cute::tiled_divide(sK, Shape<_1, Int<kGmemElemsPerLoad>>{});
  396. Tensor gK_copy = cute::tiled_divide(gK, Shape<_1, Int<kGmemElemsPerLoad>>{});
  397. Tensor tKcK = gmem_thr_copy_q.partition_S(cute::make_identity_tensor(Shape<Int<kBlockMN>, Int<kHeadDim / 2>>{}));
  398. CUTE_STATIC_ASSERT_V(rank(tKcK) == _3{});
  399. CUTE_STATIC_ASSERT_V(rank(tRrCosCont) == _3{});
  400. CUTE_STATIC_ASSERT_V(rank(tRrSinCont) == _3{});
  401. CUTE_STATIC_ASSERT_V(size<1>(tKcK) == size<1>(tRrCosCont));
  402. CUTE_STATIC_ASSERT_V(size<2>(tKcK) == size<2>(tRrCosCont));
  403. CUTE_STATIC_ASSERT_V(size<1>(tKcK) == size<1>(tRrSinCont));
  404. CUTE_STATIC_ASSERT_V(size<2>(tKcK) == size<2>(tRrSinCont));
  405. CUTE_STATIC_ASSERT_V(size<0>(tRrCosCont) == size<0>(tRrSinCont));
  406. CUTE_STATIC_ASSERT_V(size<0>(tKcK) == size<0>(tRrCosCont));
  407. static_assert(decltype(size<0>(tRrCosCont))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
  408. if constexpr (PagedKV) {
  409. static_assert(decltype(size(tPrKPtr))::value == cute::ceil_div(size<1>(tKcK), kGmemThreadsPerRow));
  410. }
  411. const int ro_dim_vec = rotary_dim / kGmemElemsPerLoad;
  412. const int non_ro_dim_vec = (max_k - rotary_dim) / kGmemElemsPerLoad;
  413. #pragma unroll
  414. for (int m = 0; m < size<1>(tKcK); ++m) {
  415. int const row = get<0>(tKcK(_0{}, m, _0{}));
  416. Tensor gK_cur_copy = [&] {
  417. if constexpr (PagedKV) {
  418. Element* k_ptr = reinterpret_cast<Element*>(__shfl_sync(0xffffffff, reinterpret_cast<uint64_t>(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow));
  419. Tensor mK_cur = make_tensor(make_gmem_ptr(k_ptr), Shape<Int<kHeadDim>>{});
  420. return cute::tiled_divide(mK_cur, Shape<Int<kGmemElemsPerLoad>>{});
  421. } else {
  422. return gK_copy(_, row, _);
  423. }
  424. }();
  425. if (row < std::min(max_seqlen - n_block * kBlockMN, kBlockMN)) {
  426. #pragma unroll
  427. for (int k = 0; k < size<2>(tKcK); ++k) {
  428. if (tKpK(k)) {
  429. int const col = get<1>(tKcK(_0{}, _0{}, k));
  430. bool rotate = col < rotary_dim / 2;
  431. int const col_idx_left = rotate ? col / kGmemElemsPerLoad : (col + rotary_dim / 2) / kGmemElemsPerLoad;
  432. int const col_idx_right = col_idx_left + (rotate ? ro_dim_vec / 2 : non_ro_dim_vec / 2);
  433. Tensor rK_left = make_fragment_like(sK_copy(_, row, col_idx_left));
  434. cute::copy(tiled_copy_k, sK_copy(_, row, col_idx_left), rK_left);
  435. Tensor rK_right = make_fragment_like(rK_left);
  436. cute::copy(tiled_copy_k, sK_copy(_, row, col_idx_right), rK_right);
  437. if (rotate) {
  438. apply_rotary_contiguous(rK_left, rK_right, tRrCosCont(_, m, k), tRrSinCont(_, m, k));
  439. }
  440. cute::copy(tiled_copy_k, rK_left, gK_cur_copy(_, col_idx_left));
  441. if (col_idx_right * kGmemElemsPerLoad < max_k) {
  442. cute::copy(tiled_copy_k, rK_right, gK_cur_copy(_, col_idx_right));
  443. }
  444. }
  445. }
  446. }
  447. }
  448. };
  449. };
  450. ////////////////////////////////////////////////////////////////////////////////////////////////////
  451. } // namespace flash