flash_fwd_kernel.h

/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include "philox_unpack.cuh"  // For at::cuda::philox::unpack (assumed header, matching the upstream flash-attention source layout)

#include <cute/tensor.hpp>

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>

#include "block_info.h"
#include "kernel_traits.h"
#include "utils.h"
#include "softmax.h"
#include "mask.h"
#include "dropout.h"
#include "rotary.h"
namespace flash {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN>
__forceinline__ __device__ auto get_lse_tile(const Params &params, const int bidb, const int bidh, const int m_block, const BlockInfo</*Varlen=*/!Is_even_MN> &binfo) {
    // When params.unpadded_lse is false, LSE is written as (b, h, seqlen_q) - this is non-variable seqlen path.
    // Otherwise, when params.seqlenq_ngroups_swapped is true, it is written as (h, seqlen_q, b) to account for seqlen_q <-> h swapping trick.
    // Otherwise, it's written as (h, b, seqlen_q).
    const bool varlen_q = params.unpadded_lse && !params.seqlenq_ngroups_swapped;
    auto lse_offset = varlen_q ? binfo.q_offset(params.seqlen_q, 1, bidb) : 0;
    auto gmem_ptr_lse = make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr) + lse_offset);

    auto lse_shape = varlen_q ? make_shape(1, params.h, params.total_q) : make_shape(params.b, params.h, params.seqlen_q);
    auto lse_stride = params.seqlenq_ngroups_swapped ? make_stride(1, params.seqlen_q * params.b, params.b) : (
        params.unpadded_lse ? make_stride(params.h * params.total_q, params.total_q, 1) : make_stride(params.h * params.seqlen_q, params.seqlen_q, 1)
    );

    auto lse_layout = make_layout(lse_shape, lse_stride);
    Tensor mLSE = make_tensor(gmem_ptr_lse, lse_layout);
    auto mLSE_slice = varlen_q ? mLSE(0, bidh, _) : mLSE(bidb, bidh, _);
    return local_tile(mLSE_slice, Shape<Int<kBlockM>>{}, make_coord(m_block));
}
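
// compute_attn_1rowblock: one thread block computes the attention output for a single
// (batch = bidb, head = bidh, query-row block = m_block) tile of kBlockM rows. K/V are streamed
// through shared memory one kBlockN-row block at a time (in reverse order), the scores are folded
// into an online softmax, and the normalized output tile plus its log-sum-exp are written out at the end.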
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {

    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    // Shared memory.
    extern __shared__ char smem_[];

    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;
    constexpr int kNWarps = Kernel_traits::kNWarps;

    auto seed_offset = at::cuda::philox::unpack(params.philox_args);
    flash::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t,
                           bidb, bidh, tidx, params.h);

    // Save seed and offset for backward, before any early exiting. Otherwise the 0-th thread block might
    // exit early and no one saves the rng states.
    if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
        params.rng_state[0] = std::get<0>(seed_offset);
        params.rng_state[1] = std::get<1>(seed_offset);
    }

    const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;

    const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
    int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
    if (Is_causal || Is_local) {
        n_block_max = std::min(n_block_max,
                               cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
        // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
        //     printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
        // }
    }
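
    // With local (sliding-window) attention, query row (m_block * kBlockM + i) only attends to key
    // columns in [i + seqlen_k - seqlen_q - window_size_left, i + seqlen_k - seqlen_q + window_size_right];
    // n_block_min / n_block_max are those column bounds rounded outward to whole kBlockN blocks
    // (the purely causal case uses the same upper bound with window_size_right == 0).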
    // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0.
    // Otherwise we might read OOB elements from gK and gV.
    if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) {
        Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr)
                                              + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),
                                make_shape(binfo.actual_seqlen_q, params.h, params.d),
                                make_stride(params.o_row_stride, params.o_head_stride, _1{}));
        Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
                               make_coord(m_block, 0));  // (kBlockM, kHeadDim)

        Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);

        typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
        Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
        Tensor tOrO = make_tensor<Element>(shape(tOgO));
        clear(tOrO);
        // Construct identity layout for sO
        Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO)));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
        // Repeat the partitioning with identity layouts
        Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
        Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
        if (!Is_even_K) {
            #pragma unroll
            for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
        }
        // Clear_OOB_K must be false since we don't want to write zeros to gmem
        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
            gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
        );
        #pragma unroll
        for (int m = 0; m < size<1>(tOgO); ++m) {
            const int row = get<0>(tOcO(0, m, 0));
            if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; }
        }
        return;
    }
    // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); }

    // We iterate over the blocks in reverse order. This is because the last block is the only one
    // that needs masking when we read K and V from global memory. Moreover, iterating in reverse
    // might save us 1 register (we just need n_block instead of both n_block and n_block_max).

    const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded
        + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;

    Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr)
                                          + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
                            make_shape(binfo.actual_seqlen_q, params.h, params.d),
                            make_stride(params.q_row_stride, params.q_head_stride, _1{}));
    Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
                           make_coord(m_block, 0));  // (kBlockM, kHeadDim)
    Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.k_ptr)
                                          + binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)),
                            make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
                            make_stride(params.k_row_stride, params.k_head_stride, _1{}));
    Tensor gK = local_tile(mK(_, bidh / params.h_h_k_ratio, _), Shape<Int<kBlockN>, Int<kHeadDim>>{},
                           make_coord(_, 0));  // (kBlockN, kHeadDim, nblocksN)
    Tensor mV = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.v_ptr)
                                          + binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)),
                            make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
                            make_stride(params.v_row_stride, params.v_head_stride, _1{}));
    Tensor gV = local_tile(mV(_, bidh / params.h_h_k_ratio, _), Shape<Int<kBlockN>, Int<kHeadDim>>{},
                           make_coord(_, 0));  // (kBlockN, kHeadDim, nblocksN)
    Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.p_ptr) + row_offset_p),
                            Shape<Int<kBlockM>, Int<kBlockN>>{},
                            make_stride(params.seqlen_k_rounded, _1{}));

    Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                            typename Kernel_traits::SmemLayoutQ{});
    // Careful: we're using the same smem for sQ and sK | sV if Share_Q_K_smem.
    Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)),
                            typename Kernel_traits::SmemLayoutKV{});
    Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
    Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
    Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
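
    // Shared-memory staging: sK is placed right after sQ (or on top of sQ when Share_Q_K_smem) and
    // sV right after sK. sVt / sVtNoSwizzle are transposed views of the same sV storage, used as
    // the B operand of the second GEMM (P @ V).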
    typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
    auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);

    Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
    Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
    Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K, nblocksN)
    Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
    Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K, nblocksN)
    Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);

    typename Kernel_traits::TiledMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tidx);
    Tensor tSrQ = thr_mma.partition_fragment_A(sQ);             // (MMA,MMA_M,MMA_K)
    Tensor tSrK = thr_mma.partition_fragment_B(sK);             // (MMA,MMA_N,MMA_K)
    Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle);  // (MMA, MMA_K,MMA_N)

    Tensor tSgS = thr_mma.partition_C(gP);

    Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // MMA, MMA_M, MMA_K

    //
    // Copy Atom retiling
    //

    auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
    auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
    // if (cute::thread0()) {smem_thr_copy_Q.print_all();}
    Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
    // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}

    auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
    auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
    Tensor tSsK = smem_thr_copy_K.partition_S(sK);

    auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
    auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
    Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);

    //
    // PREDICATES
    //

    // // Allocate predicate tensors for m and n
    // Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{});
    // Tensor tKVpKV = make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{});

    // Construct identity layout for sQ and sK
    Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));   // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));  // (BLK_N,BLK_K) -> (blk_n,blk_k)
    // Tensor tScQ = thr_mma.partition_A(cQ);  // (MMA,MMA_M,MMA_K)
    // if (cute::thread0()) {
    //     print(tScQ.layout()); printf("\n");
    //     for (int i = 0; i < size(tScQ); ++i) {
    //         printf("%d ", get<0>(tScQ(i)));
    //     }
    //     printf("\n");
    //     for (int i = 0; i < size(tScQ); ++i) {
    //         printf("%d ", get<1>(tScQ(i)));
    //     }
    //     printf("\n");
    // }

    // Repeat the partitioning with identity layouts
    Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ);     // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
    Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV);  // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)

    // Allocate predicate tensors for k
    Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
    Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));

    // Set predicates for k bounds
    if (!Is_even_K) {
        #pragma unroll
        for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }
        #pragma unroll
        for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
    }
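
    // These predicates only guard the head-dimension (K) direction, for the case where kHeadDim is
    // not a multiple of the vectorized copy width; out-of-bounds rows in the M/N direction are
    // handled inside flash::copy via the residue length passed as its last argument.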
    // Prologue

    // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
                                       binfo.actual_seqlen_q - m_block * kBlockM);
    if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }

    // // if (cute::thread(1, 0)) { print(tQsQ); }
    // // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{});
    // // if (cute::thread0()) { print(sQNoSwizzle); }

    if (Kernel_traits::Share_Q_K_smem) {
        flash::cp_async_wait<0>();
        __syncthreads();
        Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));  // M
        cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
        __syncthreads();
    }

    int n_block = n_block_max - 1;
    // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV,
                                       binfo.actual_seqlen_k - n_block * kBlockN);
    cute::cp_async_fence();
    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
    // __syncthreads();

    if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) {
        flash::cp_async_wait<1>();
        __syncthreads();
        Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));  // M
        cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
    }

    clear(acc_o);

    flash::Softmax<2 * size<1>(acc_o)> softmax;

    const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
    flash::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);

    // For performance reasons, we separate out two kinds of iterations:
    // those that need masking on S, and those that don't.
    // We need masking on S for the very last block when K and V have length not a multiple of kBlockN.
    // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
    // We will have at least 1 "masking" iteration.

    // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
    // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
    constexpr int n_masking_steps = (!Is_causal && !Is_local)
        ? 1
        : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
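    // n_masking_steps example: with kBlockM == kBlockN == 64, an even causal tile needs
    // ceil_div(64, 64) == 1 masking iteration; if seqlen_k is not a multiple of kBlockN
    // (!Is_even_MN), one extra block is masked, giving 2.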
    #pragma unroll
    for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
        Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
        clear(acc_s);
        flash::cp_async_wait<0>();
        __syncthreads();

        // Advance gV
        if (masking_step > 0) {
            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
        } else {
            // Clear the smem tiles to account for predicated off loads
            flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
                gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
            );
        }
        cute::cp_async_fence();

        flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
            smem_thr_copy_Q, smem_thr_copy_K
        );
        // if (cute::thread0()) { print(acc_s); }
        if constexpr (Is_softcap) {
            flash::apply_softcap(acc_s, params.softcap);
        }

        mask.template apply_mask<Is_causal, Is_even_MN>(
            acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
        );

        flash::cp_async_wait<0>();
        __syncthreads();
        if (n_block > n_block_min) {
            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);
            // This cp_async_fence needs to be in the if block, otherwise the synchronization
            // isn't right and we get race conditions.
            cute::cp_async_fence();
        }

        // TODO: when we have key_padding_mask we'll need to Check_inf
        masking_step == 0
            ? softmax.template softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2)
            : softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2);

        // Convert acc_s from fp32 to fp16/bf16
        Tensor rP = flash::convert_type<Element>(acc_s);
        int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
        int block_col_idx = n_block * (kBlockN / 32);
        if (Return_softmax) {
            Tensor rP_drop = make_fragment_like(rP);
            cute::copy(rP, rP_drop);
            dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
                rP_drop, block_row_idx, block_col_idx, kNWarps
            );
            cute::copy(rP_drop, tSgS);
            tSgS.data() = tSgS.data() + (-kBlockN);
        }
        if (Is_dropout) {
            dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
        }

        // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
        // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
        // if (cute::thread0()) { print(tOrP); }
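        // Second GEMM of this tile: the probabilities tOrP (already in registers) are multiplied by
        // V^T read from shared memory ("rs": A operand in registers, B from smem), accumulating into acc_o.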
        flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
        // if (cute::thread0()) { print(scores); }

        // This check is at the end of the loop since we always have at least 1 iteration
        if (n_masking_steps > 1 && n_block <= n_block_min) {
            --n_block;
            break;
        }
    }

    // These are the iterations where we don't need masking on S
    for (; n_block >= n_block_min; --n_block) {
        Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
        clear(acc_s);
        flash::cp_async_wait<0>();
        __syncthreads();
        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
        cute::cp_async_fence();

        flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
            smem_thr_copy_Q, smem_thr_copy_K
        );
        if constexpr (Is_softcap) {
            flash::apply_softcap(acc_s, params.softcap);
        }

        flash::cp_async_wait<0>();
        __syncthreads();
        if (n_block > n_block_min) {
            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);
            // This cp_async_fence needs to be in the if block, otherwise the synchronization
            // isn't right and we get race conditions.
            cute::cp_async_fence();
        }

        mask.template apply_mask</*Causal_mask=*/false>(
            acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
        );

        softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2);

        Tensor rP = flash::convert_type<Element>(acc_s);
        int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
        int block_col_idx = n_block * (kBlockN / 32);
        if (Return_softmax) {
            Tensor rP_drop = make_fragment_like(rP);
            cute::copy(rP, rP_drop);
            dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
                rP_drop, block_row_idx, block_col_idx, kNWarps
            );
            cute::copy(rP_drop, tSgS);
            tSgS.data() = tSgS.data() + (-kBlockN);
        }
        if (Is_dropout) {
            dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
        }

        // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
        // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
        flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
    }
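
    // At this point acc_o holds the output accumulated against the running row max but not yet
    // divided by the softmax row sum; normalize_softmax_lse below applies that normalization and
    // produces the per-row log-sum-exp that is written alongside the output.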
    // Epilogue

    Tensor lse = softmax.template normalize_softmax_lse<Is_dropout>(acc_o, params.scale_softmax, params.rp_dropout);

    // Convert acc_o from fp32 to fp16/bf16
    Tensor rO = flash::convert_type<Element>(acc_o);
    Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{});  // (SMEM_M,SMEM_N)
    // Partition sO to match the accumulator partitioning
    auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
    auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
    Tensor taccOrO = smem_thr_copy_O.retile_S(rO);     // ((Atom,AtomNum), MMA_M, MMA_N)
    Tensor taccOsO = smem_thr_copy_O.partition_D(sO);  // ((Atom,AtomNum),PIPE_M,PIPE_N)

    // sO has the same size as sQ, so we don't need to sync here.
    if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); }

    cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);

    Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr)
                                          + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),
                            make_shape(binfo.actual_seqlen_q, params.h, params.d),
                            make_stride(params.o_row_stride, params.o_head_stride, _1{}));
    Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
                           make_coord(m_block, 0));  // (kBlockM, kHeadDim)
    Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);

    typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
    auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
    Tensor tOsO = gmem_thr_copy_O.partition_S(sO);  // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tOgO = gmem_thr_copy_O.partition_D(gO);

    __syncthreads();

    Tensor tOrO = make_tensor<Element>(shape(tOgO));
    cute::copy(gmem_tiled_copy_O, tOsO, tOrO);

    Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor taccOcO = thr_mma.partition_C(caccO);                                // (MMA,MMA_M,MMA_K)
    static_assert(decltype(size<0>(taccOcO))::value == 4);
    // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
    Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);
    CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));  // MMA_M
    if (get<1>(taccOcO_row(0)) == 0) {
        #pragma unroll
        for (int mi = 0; mi < size(lse); ++mi) {
            const int row = get<0>(taccOcO_row(mi));
            if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); }
        }
    }

    // Construct identity layout for sO
    Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO)));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
    // Repeat the partitioning with identity layouts
    Tensor tOcO = gmem_thr_copy_O.partition_D(cO);  // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
    Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
    if (!Is_even_K) {
        #pragma unroll
        for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
    }
    // Clear_OOB_K must be false since we don't want to write zeros to gmem
    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
    );
}

////////////////////////////////////////////////////////////////////////////////////////////////////
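
// compute_attn_1rowblock_splitkv: the same one-row-block computation as above, but (a) the key/value
// blocks can be partitioned across num_n_splits thread blocks, each writing a partial result to
// Oaccum / LSEaccum (for a later combine step) when Split is true, (b) the KV cache may be paged via
// params.block_table, and (c) when Append_KV is set the new K/V (and optional rotary embedding) are
// written into the cache on the fly.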
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, typename Params>
inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {

    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    // Shared memory.
    extern __shared__ char smem_[];

    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;
    constexpr int kNWarps = Kernel_traits::kNWarps;

    using GmemTiledCopyO = std::conditional_t<
        !Split,
        typename Kernel_traits::GmemTiledCopyO,
        typename Kernel_traits::GmemTiledCopyOaccum
    >;
    using ElementO = std::conditional_t<!Split, Element, ElementAccum>;

    const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); }
    // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); }
    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;

    const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;
    const int n_block_min = !Is_local
        ? n_split_idx * n_blocks_per_split
        : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
    int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split);
    if (Is_causal || Is_local) {
        n_block_max = std::min(n_block_max,
                               cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
    }
    if (n_block_min >= n_block_max) {  // This also covers the case where n_block_max <= 0
        // We exit early and write 0 to gOaccum and -inf to gLSEaccum.
        // Otherwise we might read OOB elements from gK and gV,
        // or get wrong results when we combine gOaccum from different blocks.
        const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
            + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
        const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
            + m_block * kBlockM) * params.d_rounded;
        const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
        Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
                                     Shape<Int<kBlockM>, Int<kHeadDim>>{},
                                     make_stride(Split ? kHeadDim : params.o_row_stride, _1{}));
        Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum),
                                       Shape<Int<kBlockM>>{}, Stride<_1>{});

        GmemTiledCopyO gmem_tiled_copy_Oaccum;
        auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
        Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
        Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
        clear(tOrOaccum);
        // Construct identity layout for sO
        Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum)));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
        // Repeat the partitioning with identity layouts
        Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO);
        Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
        if (!Is_even_K) {
            #pragma unroll
            for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
        }
        // Clear_OOB_K must be false since we don't want to write zeros to gmem
        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
            gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
        );
        #pragma unroll
        for (int m = 0; m < size<1>(tOgOaccum); ++m) {
            const int row = get<0>(tOcO(0, m, 0));
            if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSEaccum(row) = Split ? -INFINITY : INFINITY; }
        }
        return;
    }

    // We iterate over the blocks in reverse order. This is because the last block is the only one
    // that needs masking when we read K and V from global memory. Moreover, iterating in reverse
    // might save us 1 register (we just need n_block instead of both n_block and n_block_max).

    // We move K and V to the last block.
    const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
    const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride;
    const int block_table_idx = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN / params.page_block_size;
    const int block_table_offset = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size;
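    // Paged KV cache addressing: logical key row r lives in page block_table[r / page_block_size]
    // at row offset r % page_block_size. block_table_idx / block_table_offset above decompose the
    // first row of the last key block, (n_block_max - 1) * kBlockN, and row_offset_k / row_offset_v
    // below are built from that page.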
    const index_t row_offset_k = block_table == nullptr
        ? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache)
          + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride
        : block_table[block_table_idx] * params.k_batch_stride + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
    const index_t row_offset_v = block_table == nullptr
        ? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
          + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
        : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;

    Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
                            make_shape(binfo.actual_seqlen_q, params.h, params.d),
                            make_stride(params.q_row_stride, params.q_head_stride, _1{}));
    Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
                           make_coord(m_block, 0));  // (kBlockM, kHeadDim)
    Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
                            Shape<Int<kBlockN>, Int<kHeadDim>>{},
                            make_stride(params.k_row_stride, _1{}));
    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); }
    Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
                            Shape<Int<kBlockN>, Int<kHeadDim>>{},
                            make_stride(params.v_row_stride, _1{}));

    Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                            typename Kernel_traits::SmemLayoutQ{});
    Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{});
    Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
    Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
    Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});

    typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
    auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);

    Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
    Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
    Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K)
    Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
    Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)
    Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);

    typename Kernel_traits::TiledMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tidx);
    Tensor tSrQ = thr_mma.partition_fragment_A(sQ);             // (MMA,MMA_M,MMA_K)
    Tensor tSrK = thr_mma.partition_fragment_B(sK);             // (MMA,MMA_N,MMA_K)
    Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle);  // (MMA, MMA_K,MMA_N)

    Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // MMA, MMA_M, MMA_K

    //
    // Copy Atom retiling
    //

    auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
    auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
    Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);

    auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
    auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
    Tensor tSsK = smem_thr_copy_K.partition_S(sK);

    auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
    auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
    Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);

    // PREDICATES
    //

    // // Allocate predicate tensors for m and n
    // Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{});
    // Tensor tKVpKV = make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{});

    // Construct identity layout for sQ and sK
    Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));   // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));  // (BLK_N,BLK_K) -> (blk_n,blk_k)

    // Repeat the partitioning with identity layouts
    Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ);     // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
    Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV);  // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)

    // Allocate predicate tensors for k
    Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
    Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));

    // Set predicates for k bounds
    if (!Is_even_K) {
        #pragma unroll
        for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }
        #pragma unroll
        for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
    }

    // Prologue

    // Copy from Knew to K, optionally apply rotary embedding.
    typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary;
    auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
    typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont;
    auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
    if constexpr (Append_KV) {
        // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
        // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
        // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
        const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])) * (params.rotary_dim / 2);
        Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
                                  Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
                                  make_stride(params.rotary_dim / 2, _1{}));
        Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
                                  Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
                                  make_stride(params.rotary_dim / 2, _1{}));
        Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
                                      Shape<Int<kBlockN>, Int<kHeadDim>>{},
                                      make_stride(params.rotary_dim / 2, _1{}));
        Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
                                      Shape<Int<kBlockN>, Int<kHeadDim>>{},
                                      make_stride(params.rotary_dim / 2, _1{}));
        Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
        Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
        Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
        Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
        // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); }
        // if (cute::thread(8, 0)) { print_tensor(gCos); }
        // if (cute::thread(0, 0)) { print_tensor(tRgCos); }

        // const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
        const index_t row_offset_knew = bidb * params.knew_batch_stride
            + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
        // const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
        const index_t row_offset_vnew = bidb * params.vnew_batch_stride
            + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
        // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
        // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
        // This maps to accessing the first 64 rows of knew_ptr.
        Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.knew_ptr)
                                                 + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride),
                                   Shape<Int<kBlockN>, Int<kHeadDim>>{},
                                   make_stride(params.knew_row_stride, _1{}));
        // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); }
        Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.vnew_ptr)
                                                 + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
                                   Shape<Int<kBlockN>, Int<kHeadDim>>{},
                                   make_stride(params.vnew_row_stride, _1{}));
        Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew);  // (KCPY, KCPY_N, KCPY_K)
        Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew);  // (VCPY, VCPY_N, VCPY_K)

        const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);
        auto tKgK_data = tKgK.data();
        auto tVgV_data = tVgV.data();
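
        // Walk the key blocks from n_block_max - 1 down to the block containing seqlen_k_cache and
        // copy the freshly provided Knew / Vnew rows into the KV cache (applying rotary to the new
        // keys when rotary_dim > 0). copy_w_min_idx only writes rows whose global index falls in
        // [seqlen_k_cache, actual_seqlen_k), so existing cache entries are left untouched.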
  620. for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) {
  621. flash::copy_w_min_idx<Is_even_K>(
  622. tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
  623. );
  624. tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
  625. if (params.rotary_dim == 0) {
  626. flash::copy_w_min_idx<Is_even_K>(
  627. tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
  628. );
  629. } else {
  630. if (params.is_rotary_interleaved) {
  631. // Don't clear OOB_K because we're writing to global memory
  632. flash::copy_rotary_interleaved<Is_even_K, /*Clear_OOB_K=*/false>(
  633. tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,
  634. binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim
  635. );
  636. tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2));
  637. tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2));
  638. } else {
  639. // Don't clear OOB_K because we're writing to global memory
  640. flash::copy_rotary_contiguous<Is_even_K, /*Clear_OOB_K=*/false>(
  641. tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,
  642. binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim
  643. );
  644. tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2));
  645. tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2));
  646. }
  647. }
  648. tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
  649. if (block_table == nullptr) {
  650. tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
  651. tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
  652. } else {
  653. if (n_block > n_block_copy_min) {
  654. const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
  655. const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
  656. const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
  657. const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
  658. const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur];
  659. const int offset_diff = block_table_offset_next - block_table_offset_cur;
  660. tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride;
  661. tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride;
  662. }
  663. }
  664. }
  665. // Need this before we can read in K again, so that we'll see the updated K values.
  666. __syncthreads();
  667. tKgK.data() = tKgK_data;
  668. tVgV.data() = tVgV_data;
  669. }
  670. // Read Q from gmem to smem, optionally apply rotary embedding.
  671. if (!Append_KV || params.rotary_dim == 0) {
  672. // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
  673. flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
  674. binfo.actual_seqlen_q - m_block * kBlockM);
  675. } else {
  676. const index_t row_offset_cossin = (binfo.seqlen_k_cache + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
  677. // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache.
  678. // We do this by setting the row stride of gCos / gSin to 0.
  679. Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
  680. Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
  681. make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
  682. Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
  683. Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
  684. make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
  685. Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
  686. Shape<Int<kBlockM>, Int<kHeadDim>>{},
  687. make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
  688. Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
  689. Shape<Int<kBlockM>, Int<kHeadDim>>{},
  690. make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
  691. Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
  692. Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
  693. Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
  694. Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
  695. if (params.is_rotary_interleaved) {
  696. flash::copy_rotary_interleaved<Is_even_K>(
  697. tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
  698. 0, params.d, params.rotary_dim
  699. );
  700. } else {
  701. flash::copy_rotary_contiguous<Is_even_K>(
  702. tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
  703. 0, params.d, params.rotary_dim
  704. );
  705. }
  706. }
  707. int n_block = n_block_max - 1;
  708. // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
  709. flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
  710. binfo.actual_seqlen_k - n_block * kBlockN);
  711. cute::cp_async_fence();
  712. // flash::cp_async_wait<0>();
  713. // __syncthreads();
  714. // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); }
  715. // __syncthreads();
  716. clear(acc_o);
  717. flash::Softmax<2 * size<1>(acc_o)> softmax;
  718. const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
  719. flash::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);
  720. // For performance reason, we separate out two kinds of iterations:
  721. // those that need masking on S, and those that don't.
  722. // We need masking on S for the very last block when K and V has length not multiple of kBlockN.
  723. // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
  724. // We will have at least 1 "masking" iteration.
  725. // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
  726. // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
  727. constexpr int n_masking_steps = (!Is_causal && !Is_local)
  728. ? 1
  729. : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
  730. #pragma unroll
  731. for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
  732. Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
  733. clear(acc_s);
  734. flash::cp_async_wait<0>();
  735. __syncthreads();
  736. // Advance gV
  737. if (masking_step > 0) {
  738. if (block_table == nullptr) {
  739. tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
  740. } else {
  741. const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;
  742. const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;
  743. const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
  744. const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
  745. tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
  746. }
  747. flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
  748. } else {
  749. // Clear the smem tiles to account for predicated off loads
  750. flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
  751. gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
  752. );
  753. }
  754. cute::cp_async_fence();
  755. flash::gemm(
  756. acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
  757. smem_thr_copy_Q, smem_thr_copy_K
  758. );
  759. // if (cute::thread0()) { print(acc_s); }
  760. if constexpr (Is_softcap){
  761. flash::apply_softcap(acc_s, params.softcap);
  762. }
  763. mask.template apply_mask<Is_causal, Is_even_MN>(
  764. acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
  765. );
  766. flash::cp_async_wait<0>();
  767. __syncthreads();
  768. // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); }
  769. // __syncthreads();
  770. if (n_block > n_block_min) {
  771. // Advance gK
  772. if (block_table == nullptr) {
  773. tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
  774. } else {
  775. const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
  776. const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
  777. const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
  778. const int block_table_offset_next =(n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
  779. tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
  780. }
  781. flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
  782. // This cp_async_fence needs to be in the if block, otherwise the synchronization
  783. // isn't right and we get race conditions.
  784. cute::cp_async_fence();
  785. }
  786. // We have key_padding_mask so we'll need to Check_inf
  787. masking_step == 0
  788. ? softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, acc_o, params.scale_softmax_log2)
  789. : softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, acc_o, params.scale_softmax_log2);
  790. // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }
  791. // Convert acc_s from fp32 to fp16/bf16
  792. Tensor rP = flash::convert_type<Element>(acc_s);
  793. // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
  794. // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
  795. Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
        flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
        // This check is at the end of the loop since we always have at least 1 iteration
        if (n_masking_steps > 1 && n_block <= n_block_min) {
            --n_block;
            break;
        }
    }

    // These are the iterations where we don't need masking on S
    for (; n_block >= n_block_min; --n_block) {
        Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
        clear(acc_s);
        flash::cp_async_wait<0>();
        __syncthreads();
        // Advance gV
        if (block_table == nullptr) {
            tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
        } else {
            const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;
            const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;
            const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
            const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
            tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
        }
        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
        cute::cp_async_fence();
        flash::gemm(
            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
            smem_thr_copy_Q, smem_thr_copy_K
        );
        if constexpr (Is_softcap) {
            flash::apply_softcap(acc_s, params.softcap);
        }
        flash::cp_async_wait<0>();
        __syncthreads();
        if (n_block > n_block_min) {
            // Advance gK
            if (block_table == nullptr) {
                tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
            } else {
                const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
                const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
                const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
                const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
                tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
            }
            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
            // This cp_async_fence needs to be in the if block, otherwise the synchronization
            // isn't right and we get race conditions.
            cute::cp_async_fence();
        }
        mask.template apply_mask</*Causal_mask=*/false>(
            acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
        );
        softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2);
        Tensor rP = flash::convert_type<Element>(acc_s);
        // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
        // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
        flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
    }

    // Epilogue
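    // normalize_softmax_lse finishes the online softmax: it scales each row of acc_o by the
    // inverse of its accumulated row sum and returns the per-row log-sum-exp. When Split is set,
    // this LSE (together with the partial output tile written below) is what
    // combine_attn_seqk_parallel uses to merge the splits.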
    Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(acc_o, params.scale_softmax);
    // if (cute::thread0()) { print(lse); }
    Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(smem_)), typename Kernel_traits::SmemLayoutO{});  // (SMEM_M,SMEM_N)
    // Partition sO to match the accumulator partitioning
    using SmemTiledCopyO = std::conditional_t<
        !Split,
        typename Kernel_traits::SmemCopyAtomO,
        typename Kernel_traits::SmemCopyAtomOaccum
    >;
    auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma);
    auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
    Tensor rO = flash::convert_type<ElementO>(acc_o);
    Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO);          // ((Atom,AtomNum), MMA_M, MMA_N)
    Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
    // sOaccum is larger than sQ, so we need to syncthreads here
    // TODO: allocate enough smem for sOaccum
    if constexpr (Split) { __syncthreads(); }
    cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);
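    // With Split, this block writes a partial result for its chunk of the keys: the output tile
    // goes to oaccum and the per-row LSE to softmax_lseaccum, both indexed by
    // (n_split_idx, batch, head, row) via row_offset_oaccum / row_offset_lseaccum below, and
    // combine_attn_seqk_parallel later merges the splits. Without Split, the tile is written
    // directly to o / softmax_lse.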
    const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
        + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
    const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
        + m_block * kBlockM) * params.d_rounded;
    const index_t row_offset_lseaccum = (Split || !params.unpadded_lse ?
            ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q : bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb)
        ) + m_block * kBlockM;
    Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
                                 Shape<Int<kBlockM>, Int<kHeadDim>>{},
                                 make_stride(Split ? kHeadDim : params.o_row_stride, _1{}));
    Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum),
                                   Shape<Int<kBlockM>>{}, Stride<_1>{});
    // if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n", row_offset_o, bidh, gOaccum.data()); }
    GmemTiledCopyO gmem_tiled_copy_Oaccum;
    auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
    Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum);  // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
    __syncthreads();
    Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
    cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);
    Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor taccOcO = thr_mma.partition_C(caccO);  // (MMA,MMA_M,MMA_K)
    static_assert(decltype(size<0>(taccOcO))::value == 4);
    // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
    Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);
    CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));  // MMA_M
    if (get<1>(taccOcO_row(0)) == 0) {
        #pragma unroll
        for (int mi = 0; mi < size(lse); ++mi) {
            const int row = get<0>(taccOcO_row(mi));
            if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
        }
    }
    // Construct identity layout for sO
    Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum)));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
    // Repeat the partitioning with identity layouts
    Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO);  // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
    Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
    if (!Is_even_K) {
        #pragma unroll
        for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
    }
    // Clear_OOB_K must be false since we don't want to write zeros to gmem
    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
    );
}
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
inline __device__ void compute_attn(const Params &params) {
    const int m_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;

    // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting
    // them to have the same number of threads or have to traverse the attention matrix
    // in the same order.
    // In the Philox RNG, we use the offset to store the batch, head, and the lane id
    // (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within
    // the attention matrix. This way, as long as we have the batch, head, and the location of
    // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.

    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params, bidb, bidh, m_block);
}

////////////////////////////////////////////////////////////////////////////////////////////////////
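// Split-KV entry point. When Split is set, the launch grid is expected to be
// (num_m_blocks, num_n_splits, batch * num_heads), so the batch and head indices are decoded from
// blockIdx.z and the split index from blockIdx.y; otherwise the grid is (num_m_blocks, batch, heads)
// as in compute_attn.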
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, typename Params>
inline __device__ void compute_attn_splitkv(const Params &params) {
    const int m_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = Split ? blockIdx.z / params.h : blockIdx.y;
    // The block index for the head.
    const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
    const int n_split_idx = Split ? blockIdx.y : 0;
    const int num_n_splits = Split ? gridDim.y : 1;
    flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
}

////////////////////////////////////////////////////////////////////////////////////////////////////
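// combine_attn_seqk_parallel merges the per-split partial results produced by the split-KV kernel.
// In exact arithmetic, for each output row:
//   lse = log(sum_s exp(lse_s))   and   O = sum_s exp(lse_s - lse) * O_s,
// and the code below computes the max-shifted version of that reduction for numerical stability.
// Each block handles kBlockM consecutive rows of the flattened (batch, head, seqlen_q) index space,
// across up to kMaxSplits splits.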
template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K, typename Params>
inline __device__ void combine_attn_seqk_parallel(const Params &params) {
    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;
    constexpr int kMaxSplits = 1 << Log_max_splits;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;
    constexpr int kNThreads = Kernel_traits::kNThreads;

    static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128");
    static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32");
    static_assert(kNThreads == 128, "We assume that each block has 128 threads");

    // Shared memory.
    // kBlockM + 1 instead of kBlockM to reduce bank conflicts.
    __shared__ ElementAccum sLSE[kMaxSplits][kBlockM + 1];

    // The thread and block index.
    const int tidx = threadIdx.x;
    const int bidx = blockIdx.x;

    const index_t lse_size = params.b * params.h * params.seqlen_q;
    const index_t row_offset_lse = bidx * kBlockM;
    Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lse),
                                   Shape<Int<kMaxSplits>, Int<kBlockM>>{},
                                   make_stride(lse_size, _1{}));

    // LSE format is different depending on params.unpadded_lse and params.seqlenq_ngroups_swapped, see comment in get_lse_tile.
    // This tensor's layout maps row_offset_lse to {bidb, bidh, q_offset}.
    Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
                              Shape<Int<kBlockM>>{}, Stride<_1>{});

    // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb, q_offset}.
    Layout flat_layout = make_layout(lse_size);
    Layout orig_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b));
    auto transposed_stride = params.seqlenq_ngroups_swapped ? make_stride(params.b, params.seqlen_q * params.b, 1) : make_stride(1, params.seqlen_q * params.b, params.seqlen_q);
    Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride);
    Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout));

    Tensor gLSE_unpadded = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr)), final_layout);

    constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads;

    // Read the LSE values from gmem and store them in shared memory, then transpose them.
    constexpr int kRowsPerLoadLSE = kNThreads / kBlockM;
    #pragma unroll
    for (int l = 0; l < kNLsePerThread; ++l) {
        const int row = l * kRowsPerLoadLSE + tidx / kBlockM;
        const int col = tidx % kBlockM;
        ElementAccum lse = (row < params.num_splits && col < lse_size - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
        if (row < kMaxSplits) { sLSE[row][col] = lse; }
        // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); }
    }
    // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); }
    __syncthreads();
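    // Re-read the LSEs from shared memory in transposed order so that each group of
    // kRowsPerLoadTranspose consecutive lanes holds all the splits of a single output row; the
    // warp-level Allreduce below then reduces the max and the sum across splits within that group.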
    Tensor lse_accum = make_tensor<ElementAccum>(Shape<Int<kNLsePerThread>>{});
    constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits);
    // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits
    // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads,
    // kBlockM rows, so each time we load we can load 128 / kBlockM rows).
    // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose;
    // static_assert(kThreadsPerSplit <= 32);
    static_assert(kRowsPerLoadTranspose <= 32);
    static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits);
    #pragma unroll
    for (int l = 0; l < kNLsePerThread; ++l) {
        const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
        const int col = tidx / kRowsPerLoadTranspose;
        lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE[row][col] : -INFINITY;
        // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); }
    }

    // Compute the logsumexp of the LSE along the split dimension.
    ElementAccum lse_max = lse_accum(0);
    #pragma unroll
    for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); }
    MaxOp<float> max_op;
    lse_max = Allreduce<kRowsPerLoadTranspose>::run(lse_max, max_op);
    lse_max = lse_max == -INFINITY ? 0.0f : lse_max;  // In case all local LSEs are -inf
    float lse_sum = expf(lse_accum(0) - lse_max);
    #pragma unroll
    for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); }
    SumOp<float> sum_op;
    lse_sum = Allreduce<kRowsPerLoadTranspose>::run(lse_sum, sum_op);
    // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise
    // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum.
    ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max;
    // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); }
    if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) {
        if (params.unpadded_lse) {
            const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose;
            if (lse_offset < lse_size) {
                gLSE_unpadded(lse_offset) = lse_logsum;
            }
        } else {
            gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum;
        }
    }
    // Store the scales exp(lse - lse_logsum) in shared memory.
    #pragma unroll
    for (int l = 0; l < kNLsePerThread; ++l) {
        const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
        const int col = tidx / kRowsPerLoadTranspose;
        if (row < params.num_splits && col < kBlockM) { sLSE[row][col] = expf(lse_accum(l) - lse_logsum); }
    }
    __syncthreads();
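    // Accumulate the rescaled partial outputs: O = sum over splits of exp(lse_s - lse_logsum) * Oaccum_s.
    // sLSE now holds those per-(split, row) scales, and each iteration below streams one split's
    // Oaccum tile from gmem into registers before scaling and adding it to tOrO.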
    const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded;
    Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
                                 Shape<Int<kBlockM>, Int<kHeadDim>>{},
                                 Stride<Int<kHeadDim>, _1>{});
    constexpr int kBlockN = kNThreads / kBlockM;
    using GmemLayoutAtomOaccum = Layout<Shape<Int<kBlockM>, Int<kBlockN>>, Stride<Int<kBlockN>, _1>>;
    using GmemTiledCopyOaccum = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        GmemLayoutAtomOaccum{},
                        Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store
    GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
    auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
    Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);
    Tensor tOrO = make_tensor<ElementAccum>(shape(tOgOaccum));
    Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
    clear(tOrO);

    // Predicates
    Tensor cOaccum = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
    // Repeat the partitioning with identity layouts
    Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum);
    Tensor tOpOaccum = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
    if (!Is_even_K) {
        #pragma unroll
        for (int k = 0; k < size(tOpOaccum); ++k) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; }
    }
    // Load Oaccum in then scale and accumulate to O
    for (int split = 0; split < params.num_splits; ++split) {
        flash::copy</*Is_even_MN=*/false, Is_even_K>(
            gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM
        );
        #pragma unroll
        for (int m = 0; m < size<1>(tOrOaccum); ++m) {
            int row = get<0>(tOcOaccum(0, m, 0));
            ElementAccum lse_scale = sLSE[split][row];
            #pragma unroll
            for (int k = 0; k < size<2>(tOrOaccum); ++k) {
                #pragma unroll
                for (int i = 0; i < size<0>(tOrOaccum); ++i) {
                    tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k);
                }
            }
            // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); }
        }
        tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded;
    }
    // if (cute::thread0()) { print_tensor(tOrO); }
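    // Convert the combined accumulator to the output element type and write it out row by row,
    // decoding each flat index back into (batch, head, query row) so the right o_ptr strides apply.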
    Tensor rO = flash::convert_type<Element>(tOrO);
    // Write to gO
    #pragma unroll
    for (int m = 0; m < size<1>(rO); ++m) {
        const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0));
        if (idx < params.b * params.h * params.seqlen_q) {
            const int batch_idx = idx / (params.h * params.seqlen_q);
            const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q;
            // The index to the rows of Q
            const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q;
            auto o_ptr = reinterpret_cast<Element *>(params.o_ptr) + batch_idx * params.o_batch_stride
                + head_idx * params.o_head_stride + row * params.o_row_stride;
            #pragma unroll
            for (int k = 0; k < size<2>(rO); ++k) {
                if (Is_even_K || tOpOaccum(k)) {
                    const int col = get<1>(tOcOaccum(0, m, k));
                    Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col),
                                            Shape<Int<decltype(size<0>(rO))::value>>{}, Stride<_1>{});
                    // TODO: Should check if this is using vectorized store, but it seems pretty fast
                    copy(rO(_, m, k), gO);
                    // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); }
                    // reinterpret_cast<uint64_t *>(o_ptr)[col / 4] = recast<uint64_t>(rO)(0, m, k);
                }
            }
        }
    }
}

}  // namespace flash