flash_fwd_kernel.h
  1. /******************************************************************************
  2. * Copyright (c) 2024, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include <cute/tensor.hpp>
  6. #include <cutlass/cutlass.h>
  7. #include <cutlass/array.h>
  8. #include <cutlass/numeric_types.h>
  9. #include "block_info.h"
  10. #include "kernel_traits.h"
  11. #include "utils.h"
  12. #include "softmax.h"
  13. #include "mask.h"
  14. #include "dropout.h"
  15. #include "rotary.h"
  16. namespace flash {
  17. using namespace cute;
  18. template <typename Engine, typename Layout>
  19. __forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout>& tensor,
  20. const float softcap) {
  21. #pragma unroll
  22. for (int i = 0; i < size(tensor); ++i) {
  23. tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
  24. }
  25. }
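// Softcapping squashes each raw score s to tanh(s * softcap) before softmax.
// In the full "cap * tanh(s / cap)" form, the outer multiplication by the cap
// is presumably folded into the softmax scaling by the launcher (an assumption;
// that folding is not visible in this file). For example, with softcap = 1/50
// a raw score of 100 maps to tanh(2) ~= 0.964.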
  26. ////////////////////////////////////////////////////////////////////////////////////////////////////
  27. template <typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN>
  28. __forceinline__ __device__ auto get_lse_tile(
  29. const Params& params, const int bidb, const int bidh, const int m_block,
  30. const BlockInfo</*Varlen=*/!Is_even_MN>& binfo) {
  31. // When params.unpadded_lse is false, LSE is written as (b, h, seqlen_q) -
  32. // this is the non-varlen path. Otherwise, when
  33. // params.seqlenq_ngroups_swapped is true, it is written as (h, seqlen_q, b)
  34. // to account for the seqlen_q <-> h swapping trick. Otherwise, it's written
  35. // as (h, b, seqlen_q).
  36. const bool varlen_q = params.unpadded_lse && !params.seqlenq_ngroups_swapped;
  37. auto lse_offset = varlen_q ? binfo.q_offset(params.seqlen_q, 1, bidb) : 0;
  38. auto gmem_ptr_lse = make_gmem_ptr(
  39. reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr) + lse_offset);
  40. auto lse_shape = varlen_q ? make_shape(1, params.h, params.total_q)
  41. : make_shape(params.b, params.h, params.seqlen_q);
  42. auto lse_stride =
  43. params.seqlenq_ngroups_swapped
  44. ? make_stride(1, params.seqlen_q * params.b, params.b)
  45. : (params.unpadded_lse
  46. ? make_stride(params.h * params.total_q, params.total_q, 1)
  47. : make_stride(params.h * params.seqlen_q, params.seqlen_q, 1));
  48. auto lse_layout = make_layout(lse_shape, lse_stride);
  49. Tensor mLSE = make_tensor(gmem_ptr_lse, lse_layout);
  50. auto mLSE_slice = varlen_q ? mLSE(0, bidh, _) : mLSE(bidb, bidh, _);
  51. return local_tile(mLSE_slice, Shape<Int<kBlockM>>{}, make_coord(m_block));
  52. }
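// A concrete (assumed) example of the non-varlen, non-swapped layout above:
// with b = 2, h = 4, seqlen_q = 128, the stride (h * seqlen_q, seqlen_q, 1)
// places LSE element (bidb, bidh, row) at offset
//   bidb * 4 * 128 + bidh * 128 + row,
// and the returned tile is the kBlockM-row slice starting at m_block * kBlockM
// within that (bidb, bidh) row vector.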
  53. template <typename Kernel_traits, bool Is_dropout, bool Is_causal,
  54. bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K,
  55. bool Is_softcap, bool Return_softmax, typename Params>
  56. inline __device__ void compute_attn_1rowblock(const Params& params,
  57. const int bidb, const int bidh,
  58. const int m_block) {
  59. using Element = typename Kernel_traits::Element;
  60. using ElementAccum = typename Kernel_traits::ElementAccum;
  61. using index_t = typename Kernel_traits::index_t;
  62. // Shared memory.
  63. extern __shared__ char smem_[];
  64. // The thread index.
  65. const int tidx = threadIdx.x;
  66. constexpr int kBlockM = Kernel_traits::kBlockM;
  67. constexpr int kBlockN = Kernel_traits::kBlockN;
  68. constexpr int kHeadDim = Kernel_traits::kHeadDim;
  69. constexpr int kNWarps = Kernel_traits::kNWarps;
  70. auto seed_offset = at::cuda::philox::unpack(params.philox_args);
  71. flash::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset),
  72. params.p_dropout_in_uint8_t, bidb, bidh, tidx,
  73. params.h);
  74. // Save seed and offset for backward, before any early exiting. Otherwise the
  75. // 0-th thread block might exit early and no one saves the rng states.
  76. if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 &&
  77. tidx == 0) {
  78. params.rng_state[0] = std::get<0>(seed_offset);
  79. params.rng_state[1] = std::get<1>(seed_offset);
  80. }
  81. const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
  82. if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
  83. const int n_block_min =
  84. !Is_local
  85. ? 0
  86. : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k -
  87. binfo.actual_seqlen_q - params.window_size_left) /
  88. kBlockN);
  89. int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
  90. if (Is_causal || Is_local) {
  91. n_block_max = std::min(
  92. n_block_max,
  93. cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k -
  94. binfo.actual_seqlen_q + params.window_size_right,
  95. kBlockN));
  96. // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
  97. // printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
  98. // }
  99. }
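// Sketch of the visited K/V block range (values assumed for illustration):
// with kBlockM = kBlockN = 64, seqlen_q = seqlen_k = 512, causal,
// window_size_right = 0 and m_block = 2:
//   n_block_max = min(ceil_div(512, 64),
//                     ceil_div((2 + 1) * 64 + 512 - 512 + 0, 64)) = min(8, 3) = 3,
// so only K/V blocks 0..2 can contain unmasked keys for this row block;
// n_block_min stays 0 unless Is_local also trims blocks on the left.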
  100. // We exit early and write 0 to gO and gLSE. This also covers the case where
  101. // actual_seqlen_k == 0. Otherwise we might read OOB elements from gK and gV.
  102. if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) {
  103. Tensor mO = make_tensor(
  104. make_gmem_ptr(
  105. reinterpret_cast<Element*>(params.o_ptr) +
  106. binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),
  107. make_shape(binfo.actual_seqlen_q, params.h, params.d),
  108. make_stride(params.o_row_stride, params.o_head_stride, _1{}));
  109. Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
  110. make_coord(m_block, 0)); // (kBlockM, kHeadDim)
  111. Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(
  112. params, bidb, bidh, m_block, binfo);
  113. typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
  114. auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
  115. Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
  116. Tensor tOrO = make_tensor<Element>(shape(tOgO));
  117. clear(tOrO);
  118. // Construct identity layout for sO
  119. Tensor cO = make_identity_tensor(make_shape(
  120. size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
  121. // Repeat the partitioning with identity layouts
  122. Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
  123. Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
  124. if (!Is_even_K) {
  125. #pragma unroll
  126. for (int k = 0; k < size(tOpO); ++k) {
  127. tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d;
  128. }
  129. }
  130. // Clear_OOB_K must be false since we don't want to write zeros to gmem
  131. flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false,
  132. /*Clear_OOB_K=*/false>(
  133. gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO,
  134. binfo.actual_seqlen_q - m_block * kBlockM);
  135. #pragma unroll
  136. for (int m = 0; m < size<1>(tOgO); ++m) {
  137. const int row = get<0>(tOcO(0, m, 0));
  138. if (row < binfo.actual_seqlen_q - m_block * kBlockM &&
  139. get<1>(tOcO(0, m, 0)) == 0) {
  140. gLSE(row) = INFINITY;
  141. }
  142. }
  143. return;
  144. }
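// For these fully-skipped row blocks the output rows are written as zeros and
// their LSE entries as INFINITY; the split-KV path later in this file uses
// -INFINITY for an empty split instead, since those accumulators are combined
// across splits afterwards. (Intent inferred from the two code paths here.)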
  145. // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max =
  146. // %d\n", m_block, n_block_min, n_block_max); }
  147. // We iterate over the blocks in reverse order. This is because the last block
  148. // is the only one that needs masking when we read K and V from global memory.
  149. // Moreover, iterating in reverse might save us 1 register (we just need
  150. // n_block instead of both n_block and n_block_max).
  151. const index_t row_offset_p =
  152. ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) *
  153. params.seqlen_k_rounded +
  154. (n_block_max - 1) * kBlockN;
  155. Tensor mQ =
  156. make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) +
  157. binfo.q_offset(params.q_batch_stride,
  158. params.q_row_stride, bidb)),
  159. make_shape(binfo.actual_seqlen_q, params.h, params.d),
  160. make_stride(params.q_row_stride, params.q_head_stride, _1{}));
  161. Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
  162. make_coord(m_block, 0)); // (kBlockM, kHeadDim)
  163. Tensor mK =
  164. make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.k_ptr) +
  165. binfo.k_offset(params.k_batch_stride,
  166. params.k_row_stride, bidb)),
  167. make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
  168. make_stride(params.k_row_stride, params.k_head_stride, _1{}));
  169. Tensor gK = local_tile(mK(_, bidh / params.h_h_k_ratio, _),
  170. Shape<Int<kBlockN>, Int<kHeadDim>>{},
  171. make_coord(_, 0)); // (kBlockN, kHeadDim, nblocksN)
  172. Tensor mV =
  173. make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.v_ptr) +
  174. binfo.k_offset(params.v_batch_stride,
  175. params.v_row_stride, bidb)),
  176. make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
  177. make_stride(params.v_row_stride, params.v_head_stride, _1{}));
  178. Tensor gV = local_tile(mV(_, bidh / params.h_h_k_ratio, _),
  179. Shape<Int<kBlockN>, Int<kHeadDim>>{},
  180. make_coord(_, 0)); // (kBlockN, kHeadDim, nblocksN)
  181. Tensor gP = make_tensor(
  182. make_gmem_ptr(reinterpret_cast<Element*>(params.p_ptr) + row_offset_p),
  183. Shape<Int<kBlockM>, Int<kBlockN>>{},
  184. make_stride(params.seqlen_k_rounded, _1{}));
  185. Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element*>(smem_)),
  186. typename Kernel_traits::SmemLayoutQ{});
  187. // Careful: we're using the same smem for sQ and for sK | sV if Share_Q_K_smem.
  188. Tensor sK =
  189. make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)),
  190. typename Kernel_traits::SmemLayoutKV{});
  191. Tensor sV =
  192. make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
  193. Tensor sVt =
  194. make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
  195. Tensor sVtNoSwizzle =
  196. make_tensor(sV.data().get(),
  197. typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
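// Shared-memory arrangement set up above (derived from the pointer arithmetic;
// the layouts themselves live in kernel_traits.h):
//   [ sQ | sK | sV ]    when !Share_Q_K_smem
//   [ sQ == sK | sV ]   when  Share_Q_K_smem (Q is staged into registers
//                       before K overwrites the shared buffer, see below)
// sVt and sVtNoSwizzle are just transposed views over sV's storage, used for
// the second GEMM (P @ V).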
  198. typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
  199. auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
  200. Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
  201. Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
  202. Tensor tKgK =
  203. gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN)
  204. Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
  205. Tensor tVgV =
  206. gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN)
  207. Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
  208. typename Kernel_traits::TiledMma tiled_mma;
  209. auto thr_mma = tiled_mma.get_thread_slice(tidx);
  210. Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
  211. Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
  212. Tensor tOrVt =
  213. thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
  214. Tensor tSgS = thr_mma.partition_C(gP);
  215. Tensor acc_o = partition_fragment_C(
  216. tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_M, MMA_K
  217. //
  218. // Copy Atom retiling
  219. //
  220. auto smem_tiled_copy_Q =
  221. make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
  222. auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
  223. // if (cute::thread0()) {smem_thr_copy_Q.print_all();}
  224. Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
  225. // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}
  226. auto smem_tiled_copy_K =
  227. make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
  228. auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
  229. Tensor tSsK = smem_thr_copy_K.partition_S(sK);
  230. auto smem_tiled_copy_V = make_tiled_copy_B(
  231. typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
  232. auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
  233. Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
  234. //
  235. // PREDICATES
  236. //
  237. // // Allocate predicate tensors for m and n
  238. // Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)),
  239. // Stride<_1,_0>{}); Tensor tKVpKV =
  240. // make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)),
  241. // Stride<_1,_0>{});
  242. // Construct identity layout for sQ and sK
  243. Tensor cQ = make_identity_tensor(
  244. make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
  245. Tensor cKV = make_identity_tensor(
  246. make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
  247. // Tensor tScQ = thr_mma.partition_A(cQ); //
  248. // (MMA,MMA_M,MMA_K) if (cute::thread0()) {
  249. // print(tScQ.layout()); printf("\n");
  250. // for (int i = 0; i < size(tScQ); ++i) {
  251. // printf("%d ", get<0>(tScQ(i)));
  252. // }
  253. // printf("\n");
  254. // for (int i = 0; i < size(tScQ); ++i) {
  255. // printf("%d ", get<1>(tScQ(i)));
  256. // }
  257. // printf("\n");
  258. // }
  259. // Repeat the partitioning with identity layouts
  260. Tensor tQcQ = gmem_thr_copy_QKV.partition_S(
  261. cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
  262. Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(
  263. cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
  264. // Allocate predicate tensors for k
  265. Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
  266. Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
  267. // Set predicates for k bounds
  268. if (!Is_even_K) {
  269. #pragma unroll
  270. for (int k = 0; k < size(tQpQ); ++k) {
  271. tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d;
  272. }
  273. #pragma unroll
  274. for (int k = 0; k < size(tKVpKV); ++k) {
  275. tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d;
  276. }
  277. }
  278. // Prologue
  279. // We don't need to clear the sQ smem tiles since we'll only write out the
  280. // valid outputs
  281. flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ,
  282. tQpQ,
  283. binfo.actual_seqlen_q - m_block * kBlockM);
  284. if (Kernel_traits::Is_Q_in_regs) {
  285. cute::cp_async_fence();
  286. }
  287. // // if (cute::thread(1, 0)) { print(tQsQ); }
  288. // // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast<Element
  289. // *>(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{});
  290. // // if (cute::thread0()) { print(sQNoSwizzle); }
  291. if (Kernel_traits::Share_Q_K_smem) {
  292. flash::cp_async_wait<0>();
  293. __syncthreads();
  294. Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
  295. CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
  296. cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
  297. __syncthreads();
  298. }
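// When Q and K share the same smem buffer, Q has to be staged into registers
// (tSrQ) before the K tiles are allowed to overwrite that buffer: the first
// __syncthreads() waits for everyone's Q load to land in smem, and the second
// one makes sure every thread has finished copying sQ into registers before
// the K copy below starts writing over it.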
  299. int n_block = n_block_max - 1;
  300. // We don't need to clear the sK smem tiles since we'll mask out the scores
  301. // anyway.
  302. flash::copy<Is_even_MN, Is_even_K>(
  303. gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV,
  304. binfo.actual_seqlen_k - n_block * kBlockN);
  305. cute::cp_async_fence();
  306. // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
  307. // __syncthreads();
  308. if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) {
  309. flash::cp_async_wait<1>();
  310. __syncthreads();
  311. Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
  312. CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
  313. cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
  314. }
  315. clear(acc_o);
  316. flash::Softmax<2 * size<1>(acc_o)> softmax;
  317. const float alibi_slope =
  318. !Has_alibi || params.alibi_slopes_ptr == nullptr
  319. ? 0.0f
  320. : reinterpret_cast<float*>(params.alibi_slopes_ptr)
  321. [bidb * params.alibi_slopes_batch_stride + bidh] /
  322. params.scale_softmax;
  323. flash::Mask<Is_causal, Is_local, Has_alibi> mask(
  324. binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left,
  325. params.window_size_right, alibi_slope);
  326. // For performance reasons, we separate out two kinds of iterations:
  327. // those that need masking on S, and those that don't.
  328. // We need masking on S for the very last block when the K/V length is not a
  329. // multiple of kBlockN. We also need masking on S if it's causal, for the last
  330. // ceil_div(kBlockM, kBlockN) blocks. We will have at least 1 "masking"
  331. // iteration.
  332. // If not even_N, then seqlen_k might end in the middle of a block. In that
  333. // case we need to mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
  334. constexpr int n_masking_steps =
  335. (!Is_causal && !Is_local)
  336. ? 1
  337. : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN)
  338. : cute::ceil_div(kBlockM, kBlockN) + 1);
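// Worked examples for n_masking_steps (block sizes assumed for illustration):
//   non-causal, non-local:                       1
//   causal, even seqlens, kBlockM = kBlockN:     ceil_div(64, 64)     = 1
//   causal, even seqlens, kBlockM = 2 * kBlockN: ceil_div(128, 64)    = 2
//   causal, uneven seqlen_k, kBlockM = kBlockN:  ceil_div(64, 64) + 1 = 2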
  339. #pragma unroll
  340. for (int masking_step = 0; masking_step < n_masking_steps;
  341. ++masking_step, --n_block) {
  342. Tensor acc_s = partition_fragment_C(
  343. tiled_mma,
  344. Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
  345. clear(acc_s);
  346. flash::cp_async_wait<0>();
  347. __syncthreads();
  348. // Advance gV
  349. if (masking_step > 0) {
  350. flash::copy</*Is_even_MN=*/true, Is_even_K>(
  351. gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
  352. } else {
  353. // Clear the smem tiles to account for predicated off loads
  354. flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
  355. gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV,
  356. binfo.actual_seqlen_k - n_block * kBlockN);
  357. }
  358. cute::cp_async_fence();
  359. flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
  360. acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q,
  361. smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K);
  362. // if (cute::thread0()) { print(acc_s); }
  363. if constexpr (Is_softcap) {
  364. apply_softcap(acc_s, params.softcap);
  365. }
  366. mask.template apply_mask<Is_causal, Is_even_MN>(
  367. acc_s, n_block * kBlockN,
  368. m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
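// The row coordinate passed to the mask, m_block * kBlockM + (tidx / 32) * 16
// + (tidx % 32) / 4, assumes the SM80 m16n8k16 accumulator layout: each warp
// owns a 16-row slab and lane l holds rows (l / 4) and (l / 4) + 8 of it, with
// consecutive warps' slabs kNWarps * 16 rows apart. (Layout assumption; the
// authoritative mapping is in mask.h and the MMA atom definition.)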
  369. flash::cp_async_wait<0>();
  370. __syncthreads();
  371. if (n_block > n_block_min) {
  372. flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV,
  373. tKgK(_, _, _, n_block - 1),
  374. tKsK, tKVcKV, tKVpKV);
  375. // This cp_async_fence needs to be in the if block, otherwise the
  376. // synchronization isn't right and we get race conditions.
  377. cute::cp_async_fence();
  378. }
  379. // TODO: when we have key_padding_mask we'll need to Check_inf
  380. masking_step == 0
  381. ? softmax.template softmax_rescale_o<
  382. /*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local>(
  383. acc_s, acc_o, params.scale_softmax_log2)
  384. : softmax.template softmax_rescale_o<
  385. /*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(
  386. acc_s, acc_o, params.scale_softmax_log2);
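// softmax_rescale_o implements the standard online-softmax recurrence (the
// actual bookkeeping is in softmax.h): per row, keep a running max m and
// running sum l; for a new score block S,
//   m_new = max(m, rowmax(S)),  p = exp2((S - m_new) * scale_softmax_log2),
//   acc_o *= exp2((m - m_new) * scale_softmax_log2),
//   l = l * exp2((m - m_new) * scale_softmax_log2) + rowsum(p).
// Check_inf guards rows whose scores are all -inf (fully masked) so the
// rescaling factor doesn't turn into NaN.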
  387. // Convert acc_s from fp32 to fp16/bf16
  388. Tensor rP = flash::convert_type<Element>(acc_s);
  389. int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
  390. int block_col_idx = n_block * (kBlockN / 32);
  391. if (Return_softmax) {
  392. Tensor rP_drop = make_fragment_like(rP);
  393. cute::copy(rP, rP_drop);
  394. dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
  395. rP_drop, block_row_idx, block_col_idx, kNWarps);
  396. cute::copy(rP_drop, tSgS);
  397. tSgS.data() = tSgS.data() + (-kBlockN);
  398. }
  399. if (Is_dropout) {
  400. dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
  401. }
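// With Return_softmax, a copy of P is dropout-masked with the dropped entries
// flagged in the sign bit and streamed out through tSgS, so callers can recover
// both the probabilities and the mask; the rP actually fed into the P @ V GEMM
// below is dropout-masked separately in the Is_dropout branch just above. The
// 1 / (1 - p_dropout) rescaling is deferred to the epilogue via
// params.rp_dropout (as far as this file shows).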
  402. // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
  403. // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
  404. Tensor tOrP = make_tensor(
  405. rP.data(),
  406. flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
  407. // if (cute::thread0()) { print(tOrP); }
  408. flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V,
  409. smem_thr_copy_V);
  410. // if (cute::thread0()) { print(scores); }
  411. // This check is at the end of the loop since we always have at least 1
  412. // iteration
  413. if (n_masking_steps > 1 && n_block <= n_block_min) {
  414. --n_block;
  415. break;
  416. }
  417. }
  418. // These are the iterations where we don't need masking on S
  419. for (; n_block >= n_block_min; --n_block) {
  420. Tensor acc_s = partition_fragment_C(
  421. tiled_mma,
  422. Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
  423. clear(acc_s);
  424. flash::cp_async_wait<0>();
  425. __syncthreads();
  426. flash::copy</*Is_even_MN=*/true, Is_even_K>(
  427. gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
  428. cute::cp_async_fence();
  429. flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
  430. acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q,
  431. smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K);
  432. if constexpr (Is_softcap) {
  433. apply_softcap(acc_s, params.softcap);
  434. }
  435. flash::cp_async_wait<0>();
  436. __syncthreads();
  437. if (n_block > n_block_min) {
  438. flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV,
  439. tKgK(_, _, _, n_block - 1),
  440. tKsK, tKVcKV, tKVpKV);
  441. // This cp_async_fence needs to be in the if block, otherwise the
  442. // synchronization isn't right and we get race conditions.
  443. cute::cp_async_fence();
  444. }
  445. mask.template apply_mask</*Causal_mask=*/false>(
  446. acc_s, n_block * kBlockN,
  447. m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
  448. softmax
  449. .template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(
  450. acc_s, acc_o, params.scale_softmax_log2);
  451. Tensor rP = flash::convert_type<Element>(acc_s);
  452. int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
  453. int block_col_idx = n_block * (kBlockN / 32);
  454. if (Return_softmax) {
  455. Tensor rP_drop = make_fragment_like(rP);
  456. cute::copy(rP, rP_drop);
  457. dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
  458. rP_drop, block_row_idx, block_col_idx, kNWarps);
  459. cute::copy(rP_drop, tSgS);
  460. tSgS.data() = tSgS.data() + (-kBlockN);
  461. }
  462. if (Is_dropout) {
  463. dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
  464. }
  465. // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
  466. // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
  467. Tensor tOrP = make_tensor(
  468. rP.data(),
  469. flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
  470. flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V,
  471. smem_thr_copy_V);
  472. }
  473. // Epilogue
  474. Tensor lse = softmax.template normalize_softmax_lse<Is_dropout>(
  475. acc_o, params.scale_softmax, params.rp_dropout);
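// At this point acc_o holds sum_j P(i, j) * V(j) with unnormalized P;
// normalize_softmax_lse divides each row by its softmax denominator (and by
// 1 - p_dropout via rp_dropout when dropout is enabled) and returns the
// per-row LSE that gets written to gLSE below. The exact formula lives in
// softmax.h; this comment only describes the intent.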
  476. // Convert acc_o from fp32 to fp16/bf16
  477. Tensor rO = flash::convert_type<Element>(acc_o);
  478. Tensor sO = make_tensor(
  479. sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
  480. // Partition sO to match the accumulator partitioning
  481. auto smem_tiled_copy_O =
  482. make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
  483. auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
  484. Tensor taccOrO =
  485. smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
  486. Tensor taccOsO =
  487. smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
  488. // sO has the same size as sQ, so we don't need to sync here.
  489. if (Kernel_traits::Share_Q_K_smem) {
  490. __syncthreads();
  491. }
  492. cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
  493. Tensor mO =
  494. make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr) +
  495. binfo.q_offset(params.o_batch_stride,
  496. params.o_row_stride, bidb)),
  497. make_shape(binfo.actual_seqlen_q, params.h, params.d),
  498. make_stride(params.o_row_stride, params.o_head_stride, _1{}));
  499. Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
  500. make_coord(m_block, 0)); // (kBlockM, kHeadDim)
  501. Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(
  502. params, bidb, bidh, m_block, binfo);
  503. typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
  504. auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
  505. Tensor tOsO =
  506. gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N)
  507. Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
  508. __syncthreads();
  509. Tensor tOrO = make_tensor<Element>(shape(tOgO));
  510. cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
  511. Tensor caccO = make_identity_tensor(
  512. Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
  513. Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
  514. static_assert(decltype(size<0>(taccOcO))::value == 4);
  515. // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
  516. Tensor taccOcO_row =
  517. logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);
  518. CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
  519. if (get<1>(taccOcO_row(0)) == 0) {
  520. #pragma unroll
  521. for (int mi = 0; mi < size(lse); ++mi) {
  522. const int row = get<0>(taccOcO_row(mi));
  523. if (row < binfo.actual_seqlen_q - m_block * kBlockM) {
  524. gLSE(row) = lse(mi);
  525. }
  526. }
  527. }
  528. // Construct identity layout for sO
  529. Tensor cO = make_identity_tensor(
  530. make_shape(size<0>(sO), size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
  531. // Repeat the partitioning with identity layouts
  532. Tensor tOcO =
  533. gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
  534. Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
  535. if (!Is_even_K) {
  536. #pragma unroll
  537. for (int k = 0; k < size(tOpO); ++k) {
  538. tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d;
  539. }
  540. }
  541. // Clear_OOB_K must be false since we don't want to write zeros to gmem
  542. flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false,
  543. /*Clear_OOB_K=*/false>(gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO,
  544. binfo.actual_seqlen_q - m_block * kBlockM);
  545. }
  546. ////////////////////////////////////////////////////////////////////////////////////////////////////
  547. template <typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi,
  548. bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split,
  549. bool Append_KV, typename Params>
  550. inline __device__ void compute_attn_1rowblock_splitkv(
  551. const Params& params, const int bidb, const int bidh, const int m_block,
  552. const int n_split_idx, const int num_n_splits) {
  553. using Element = typename Kernel_traits::Element;
  554. using ElementAccum = typename Kernel_traits::ElementAccum;
  555. using index_t = typename Kernel_traits::index_t;
  556. // Shared memory.
  557. extern __shared__ char smem_[];
  558. // The thread index.
  559. const int tidx = threadIdx.x;
  560. constexpr int kBlockM = Kernel_traits::kBlockM;
  561. constexpr int kBlockN = Kernel_traits::kBlockN;
  562. constexpr int kHeadDim = Kernel_traits::kHeadDim;
  563. constexpr int kNWarps = Kernel_traits::kNWarps;
  564. using GmemTiledCopyO =
  565. std::conditional_t<!Split, typename Kernel_traits::GmemTiledCopyO,
  566. typename Kernel_traits::GmemTiledCopyOaccum>;
  567. using ElementO = std::conditional_t<!Split, Element, ElementAccum>;
  568. const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
  569. // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
  570. // printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d,
  571. // actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative,
  572. // binfo.seqlen_k_cache, binfo.actual_seqlen_k); } if (threadIdx.x == 0 &&
  573. // blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p,
  574. // seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache
  575. // + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); }
  576. if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
  577. const int n_blocks_per_split =
  578. ((binfo.actual_seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) /
  579. num_n_splits;
  580. const int n_block_min =
  581. !Is_local ? n_split_idx * n_blocks_per_split
  582. : std::max(n_split_idx * n_blocks_per_split,
  583. (m_block * kBlockM + binfo.actual_seqlen_k -
  584. binfo.actual_seqlen_q - params.window_size_left) /
  585. kBlockN);
  586. int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN),
  587. (n_split_idx + 1) * n_blocks_per_split);
  588. if (Is_causal || Is_local) {
  589. n_block_max = std::min(
  590. n_block_max,
  591. cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k -
  592. binfo.actual_seqlen_q + params.window_size_right,
  593. kBlockN));
  594. }
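// Split-KV partitioning, with assumed numbers for illustration: seqlen_k = 1000
// and kBlockN = 64 give ceil_div(1000, 64) = 16 K/V blocks; with num_n_splits
// = 4, n_blocks_per_split = ceil_div(16, 4) = 4, so split s covers blocks
// [4 * s, 4 * s + 4), further clipped by the causal/local window above.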
  595. if (n_block_min >=
  596. n_block_max) { // This also covers the case where n_block_max <= 0
  597. // We exit early and write 0 to gOaccum and -inf to gLSEaccum.
  598. // Otherwise we might read OOB elements from gK and gV,
  599. // or get wrong results when we combine gOaccum from different blocks.
  600. const index_t row_offset_o =
  601. binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) +
  602. m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
  603. const index_t row_offset_oaccum =
  604. (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q +
  605. m_block * kBlockM) *
  606. params.d_rounded;
  607. const index_t row_offset_lseaccum =
  608. ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q +
  609. m_block * kBlockM;
  610. Tensor gOaccum = make_tensor(
  611. make_gmem_ptr(reinterpret_cast<ElementO*>(Split ? params.oaccum_ptr
  612. : params.o_ptr) +
  613. (Split ? row_offset_oaccum : row_offset_o)),
  614. Shape<Int<kBlockM>, Int<kHeadDim>>{},
  615. make_stride(Split ? kHeadDim : params.o_row_stride, _1{}));
  616. Tensor gLSEaccum = make_tensor(
  617. make_gmem_ptr(
  618. reinterpret_cast<ElementAccum*>(Split ? params.softmax_lseaccum_ptr
  619. : params.softmax_lse_ptr) +
  620. row_offset_lseaccum),
  621. Shape<Int<kBlockM>>{}, Stride<_1>{});
  622. GmemTiledCopyO gmem_tiled_copy_Oaccum;
  623. auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
  624. Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
  625. Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
  626. clear(tOrOaccum);
  627. // Construct identity layout for sO
  628. Tensor cO = make_identity_tensor(make_shape(
  629. size<0>(gOaccum), size<1>(gOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
  630. // Repeat the partitioning with identity layouts
  631. Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO);
  632. Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
  633. if (!Is_even_K) {
  634. #pragma unroll
  635. for (int k = 0; k < size(tOpO); ++k) {
  636. tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d;
  637. }
  638. }
  639. // Clear_OOB_K must be false since we don't want to write zeros to gmem
  640. flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false,
  641. /*Clear_OOB_K=*/false>(
  642. gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO,
  643. binfo.actual_seqlen_q - m_block * kBlockM);
  644. #pragma unroll
  645. for (int m = 0; m < size<1>(tOgOaccum); ++m) {
  646. const int row = get<0>(tOcO(0, m, 0));
  647. if (row < binfo.actual_seqlen_q - m_block * kBlockM &&
  648. get<1>(tOcO(0, m, 0)) == 0) {
  649. gLSEaccum(row) = Split ? -INFINITY : INFINITY;
  650. }
  651. }
  652. return;
  653. }
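// Note the sign difference versus the non-split early exit: an empty split
// writes -INFINITY to its LSE accumulator so that the combine step (in a
// separate kernel, not shown here), which takes a log-sum-exp across splits,
// treats it as zero probability mass, while the non-split path writes INFINITY
// directly to the final LSE.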
  654. // We iterate over the blocks in reverse order. This is because the last block
  655. // is the only one that needs masking when we read K and V from global memory.
  656. // Moreover, iterating in reverse might save us 1 register (we just need
  657. // n_block instead of both n_block and n_block_max).
  658. // We move K and V to the last block.
  659. const int bidb_cache =
  660. params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
  661. const int* block_table =
  662. params.block_table == nullptr
  663. ? nullptr
  664. : params.block_table + bidb * params.block_table_batch_stride;
  665. const index_t row_offset_k =
  666. block_table == nullptr
  667. ? binfo.k_offset(params.k_batch_stride, params.k_row_stride,
  668. bidb_cache) +
  669. (n_block_max - 1) * kBlockN * params.k_row_stride +
  670. (bidh / params.h_h_k_ratio) * params.k_head_stride
  671. : (bidh / params.h_h_k_ratio) *
  672. params.k_head_stride; // block addresses are later resolved
  673. // per-thread
  674. const index_t row_offset_v =
  675. block_table == nullptr
  676. ? binfo.k_offset(params.v_batch_stride, params.v_row_stride,
  677. bidb_cache) +
  678. (n_block_max - 1) * kBlockN * params.v_row_stride +
  679. (bidh / params.h_h_k_ratio) * params.v_head_stride
  680. : (bidh / params.h_h_k_ratio) * params.v_head_stride;
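// Paged KV cache handling (when block_table != nullptr): the K/V rows for this
// sequence are scattered across fixed-size pages of params.page_block_size
// rows, and block_table maps logical block index -> physical page. Only the
// head offset is folded in here; the per-thread, page-resolved addresses are
// computed below via resolve_thread_kv_page_slice_offset (declared elsewhere
// in the flash namespace, not in this file).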
  681. Tensor mQ =
  682. make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) +
  683. binfo.q_offset(params.q_batch_stride,
  684. params.q_row_stride, bidb)),
  685. make_shape(binfo.actual_seqlen_q, params.h, params.d),
  686. make_stride(params.q_row_stride, params.q_head_stride, _1{}));
  687. Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
  688. make_coord(m_block, 0)); // (kBlockM, kHeadDim)
  689. Tensor gK = make_tensor(
  690. make_gmem_ptr(reinterpret_cast<Element*>(params.k_ptr) + row_offset_k),
  691. Shape<Int<kBlockN>, Int<kHeadDim>>{},
  692. make_stride(params.k_row_stride, _1{}));
  693. // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr
  694. // = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k,
  695. // gK.data()); }
  696. Tensor gV = make_tensor(
  697. make_gmem_ptr(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v),
  698. Shape<Int<kBlockN>, Int<kHeadDim>>{},
  699. make_stride(params.v_row_stride, _1{}));
  700. Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element*>(smem_)),
  701. typename Kernel_traits::SmemLayoutQ{});
  702. Tensor sK =
  703. make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{});
  704. Tensor sV =
  705. make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
  706. Tensor sVt =
  707. make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
  708. Tensor sVtNoSwizzle =
  709. make_tensor(sV.data().get(),
  710. typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
  711. typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_Q;
  712. auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
  713. typename Kernel_traits::GmemTiledCopyQKVPaged gmem_tiled_copy_KV;
  714. auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_thread_slice(tidx);
  715. Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
  716. Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
  717. Tensor tKgK_ = gmem_thr_copy_KV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
  718. Tensor tKsK_ = gmem_thr_copy_KV.partition_D(sK);
  719. Tensor tVgV_ = gmem_thr_copy_KV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
  720. Tensor tVsV_ = gmem_thr_copy_KV.partition_D(sV);
  721. Tensor tKgK = make_tensor(tKgK_.data(), reshape_thread_tile(tKgK_.layout()));
  722. Tensor tKsK = make_tensor(tKsK_.data(), reshape_thread_tile(tKsK_.layout()));
  723. Tensor tVgV = make_tensor(tVgV_.data(), reshape_thread_tile(tVgV_.layout()));
  724. Tensor tVsV = make_tensor(tVsV_.data(), reshape_thread_tile(tVsV_.layout()));
  725. if (block_table != nullptr) {
  726. tKgK.data() =
  727. gK.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(
  728. tidx, n_block_max, params.page_block_size, block_table,
  729. params.k_batch_stride, params.k_row_stride);
  730. tVgV.data() =
  731. gV.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(
  732. tidx, n_block_max, params.page_block_size, block_table,
  733. params.v_batch_stride, params.v_row_stride);
  734. }
  735. typename Kernel_traits::TiledMma tiled_mma;
  736. auto thr_mma = tiled_mma.get_thread_slice(tidx);
  737. Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
  738. Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
  739. Tensor tOrVt =
  740. thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
  741. Tensor acc_o = partition_fragment_C(
  742. tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_M, MMA_K
  743. //
  744. // Copy Atom retiling
  745. //
  746. auto smem_tiled_copy_Q =
  747. make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
  748. auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
  749. Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
  750. auto smem_tiled_copy_K =
  751. make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
  752. auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
  753. Tensor tSsK = smem_thr_copy_K.partition_S(sK);
  754. auto smem_tiled_copy_V = make_tiled_copy_B(
  755. typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
  756. auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
  757. Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
  758. // PREDICATES
  759. //
  760. // // Allocate predicate tensors for m and n
  761. // Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)),
  762. // Stride<_1,_0>{}); Tensor tKVpKV =
  763. // make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)),
  764. // Stride<_1,_0>{});
  765. // Construct identity layout for sQ and sK
  766. Tensor cQ = make_identity_tensor(
  767. make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
  768. Tensor cKV = make_identity_tensor(
  769. make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
  770. // Repeat the partitioning with identity layouts
  771. Tensor tQcQ =
  772. gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
  773. Tensor tKVcKV_ = gmem_thr_copy_KV.partition_S(
  774. cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
  775. Tensor tKVcKV =
  776. make_tensor(tKVcKV_.data(), reshape_thread_tile(tKVcKV_.layout()));
  777. // Allocate predicate tensors for k
  778. Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
  779. Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
  780. // Set predicates for k bounds
  781. if (!Is_even_K) {
  782. #pragma unroll
  783. for (int k = 0; k < size(tQpQ); ++k) {
  784. tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d;
  785. }
  786. #pragma unroll
  787. for (int k = 0; k < size(tKVpKV); ++k) {
  788. tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d;
  789. }
  790. }
  791. // Prologue
  792. // Copy from Knew to K, optionally apply rotary embedding.
  793. if constexpr (Append_KV) {
  794. typename Kernel_traits::GmemTiledCopyRotcossinPaged gmem_tiled_copy_rotary;
  795. auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
  796. typename Kernel_traits::GmemTiledCopyRotcossinContPaged
  797. gmem_tiled_copy_rotary_cont;
  798. auto gmem_thr_copy_rotary_cont =
  799. gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
  800. // Even if we have MQA / GQA, all threadblocks responsible for the same KV
  801. // head are writing to gmem. Technically it's a race condition, but they all
  802. // write the same content anyway, and it's safe. We want to do this so that
  803. // all threadblocks can proceed right after they finish writing the KV
  804. // cache.
  805. const index_t row_offset_cossin =
  806. ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2);
  807. Tensor gCos = make_tensor(
  808. make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_cos_ptr) +
  809. row_offset_cossin),
  810. Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
  811. make_stride(params.rotary_dim / 2, _1{}));
  812. Tensor gSin = make_tensor(
  813. make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_sin_ptr) +
  814. row_offset_cossin),
  815. Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
  816. make_stride(params.rotary_dim / 2, _1{}));
  817. Tensor gCosCont = make_tensor(
  818. make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_cos_ptr) +
  819. row_offset_cossin),
  820. Shape<Int<kBlockN>, Int<kHeadDim>>{},
  821. make_stride(params.rotary_dim / 2, _1{}));
  822. Tensor gSinCont = make_tensor(
  823. make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_sin_ptr) +
  824. row_offset_cossin),
  825. Shape<Int<kBlockN>, Int<kHeadDim>>{},
  826. make_stride(params.rotary_dim / 2, _1{}));
  827. Tensor tRgCos_ = gmem_thr_copy_rotary.partition_S(gCos);
  828. Tensor tRgSin_ = gmem_thr_copy_rotary.partition_S(gSin);
  829. Tensor tRgCosCont_ = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
  830. Tensor tRgSinCont_ = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
  831. Tensor tRgCos =
  832. make_tensor(tRgCos_.data(), reshape_thread_tile(tRgCos_.layout()));
  833. Tensor tRgSin =
  834. make_tensor(tRgSin_.data(), reshape_thread_tile(tRgSin_.layout()));
  835. Tensor tRgCosCont = make_tensor(
  836. tRgCosCont_.data(), reshape_flatten_thread_tile(tRgCosCont_.layout()));
  837. Tensor tRgSinCont = make_tensor(
  838. tRgSinCont_.data(), reshape_flatten_thread_tile(tRgSinCont_.layout()));
  839. // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p,
  840. // tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr,
  841. // gCos.data(), tRgCos.data(), params.rotary_dim); } if (cute::thread(8, 0))
  842. // { print_tensor(gCos); } if (cute::thread(0, 0)) { print_tensor(tRgCos); }
  843. const index_t row_offset_knew =
  844. binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) +
  845. ((n_block_max - 1) * kBlockN) * params.knew_row_stride +
  846. (bidh / params.h_h_k_ratio) * params.knew_head_stride;
  847. const index_t row_offset_vnew =
  848. binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) +
  849. ((n_block_max - 1) * kBlockN) * params.vnew_row_stride +
  850. (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
  851. // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew
  852. // "line up". When we access them, e.g. if gK has 128 rows and gKnew has 64
  853. // rows, we access gK[:128] and gKNew[128:128 + 64]. This maps to accessing
  854. // the first 64 rows of knew_ptr.
  855. Tensor gKnew = make_tensor(
  856. make_gmem_ptr(reinterpret_cast<Element*>(params.knew_ptr) +
  857. row_offset_knew -
  858. binfo.seqlen_k_cache * params.knew_row_stride),
  859. Shape<Int<kBlockN>, Int<kHeadDim>>{},
  860. make_stride(params.knew_row_stride, _1{}));
  861. // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
  862. // printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n",
  863. // params.knew_ptr, row_offset_knew, gKnew.data()); }
  864. Tensor gVnew = make_tensor(
  865. make_gmem_ptr(reinterpret_cast<Element*>(params.vnew_ptr) +
  866. row_offset_vnew -
  867. binfo.seqlen_k_cache * params.vnew_row_stride),
  868. Shape<Int<kBlockN>, Int<kHeadDim>>{},
  869. make_stride(params.vnew_row_stride, _1{}));
  870. typename Kernel_traits::GmemTiledCopyQKVPaged gmem_tiled_copy_KV_new;
  871. auto gmem_thr_copy_KV_new = gmem_tiled_copy_KV_new.get_thread_slice(tidx);
  872. Tensor tKgKnew_ =
  873. gmem_thr_copy_KV_new.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K)
  874. Tensor tVgVnew_ =
  875. gmem_thr_copy_KV_new.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K)
  876. auto tKgKnew =
  877. make_tensor(tKgKnew_.data(), reshape_thread_tile(tKgKnew_.layout()));
  878. auto tVgVnew =
  879. make_tensor(tVgVnew_.data(), reshape_thread_tile(tVgVnew_.layout()));
  880. const int n_block_copy_min =
  881. std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);
  882. auto tKgK_data = tKgK.data();
  883. auto tVgV_data = tVgV.data();
  884. for (int n_block = n_block_max - 1; n_block >= n_block_copy_min;
  885. n_block--) {
  886. flash::copy_w_min_idx<Is_even_K>(
  887. tVgVnew, tVgV, tKVcKV, tKVpKV,
  888. binfo.actual_seqlen_k - n_block * kBlockN,
  889. binfo.seqlen_k_cache - n_block * kBlockN);
  890. tVgVnew.data() =
  891. tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
  892. if (params.rotary_dim == 0) {
  893. flash::copy_w_min_idx<Is_even_K>(
  894. tKgKnew, tKgK, tKVcKV, tKVpKV,
  895. binfo.actual_seqlen_k - n_block * kBlockN,
  896. binfo.seqlen_k_cache - n_block * kBlockN);
  897. } else {
  898. if (params.is_rotary_interleaved) {
  899. // Don't clear OOB_K because we're writing to global memory
  900. flash::copy_rotary_interleaved<Is_even_K, /*Clear_OOB_K=*/false>(
  901. tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV,
  902. binfo.actual_seqlen_k - n_block * kBlockN,
  903. binfo.seqlen_k_cache - n_block * kBlockN, params.d,
  904. params.rotary_dim);
  905. tRgCos.data() =
  906. tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2));
  907. tRgSin.data() =
  908. tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2));
  909. } else {
  910. // Don't clear OOB_K because we're writing to global memory
  911. flash::copy_rotary_contiguous<Is_even_K, /*Clear_OOB_K=*/false>(
  912. tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV,
  913. binfo.actual_seqlen_k - n_block * kBlockN,
  914. binfo.seqlen_k_cache - n_block * kBlockN, params.d,
  915. params.rotary_dim);
  916. tRgCosCont.data() =
  917. tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2));
  918. tRgSinCont.data() =
  919. tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2));
  920. }
  921. }
  922. tKgKnew.data() =
  923. tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
  924. if (block_table == nullptr) {
  925. tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
  926. tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
  927. } else {
  928. if (n_block > n_block_copy_min) {
  929. tVgV.data() =
  930. gV.data() +
  931. flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(
  932. tidx, n_block, params.page_block_size, block_table,
  933. params.v_batch_stride, params.v_row_stride);
  934. tKgK.data() =
  935. gK.data() +
  936. flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(
  937. tidx, n_block, params.page_block_size, block_table,
  938. params.k_batch_stride, params.k_row_stride);
  939. }
  940. }
  941. }
  942. // Need this before we can read in K again, so that we'll see the updated K
  943. // values.
  944. __syncthreads();
  945. tKgK.data() = tKgK_data;
  946. tVgV.data() = tVgV_data;
  947. }
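// Summary of the Append_KV prologue above: for blocks in
// [max(n_block_min, seqlen_k_cache / kBlockN), n_block_max) the new K/V tokens
// (knew / vnew) are copied into the KV cache, with rotary embedding applied to
// K when rotary_dim > 0. The __syncthreads() makes sure every thread has
// finished writing the cache before the main loop re-reads it, and the saved
// tKgK_data / tVgV_data pointers are restored so the loop starts again from the
// last block.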
  948. // Read Q from gmem to smem, optionally apply rotary embedding.
  949. if (!Append_KV || params.rotary_dim == 0) {
  950. // We don't need to clear the sQ smem tiles since we'll only write out the
  951. // valid outputs
  952. flash::copy<Is_even_MN, Is_even_K>(
  953. gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
  954. binfo.actual_seqlen_q - m_block * kBlockM);
  955. } else {
  956. typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary;
  957. auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
  958. typename Kernel_traits::GmemTiledCopyRotcossinCont
  959. gmem_tiled_copy_rotary_cont;
  960. auto gmem_thr_copy_rotary_cont =
  961. gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
  962. const index_t row_offset_cossin =
  963. (binfo.seqlen_k_cache +
  964. (Is_causal || Is_local ? m_block * kBlockM : 0)) *
  965. (params.rotary_dim / 2);
  966. // If not causal, all the queries get the same cos/sin, taken at
  967. // location seqlen_k_cache. We do this by setting the row stride of gCos /
  968. // gSin to 0.
  969. Tensor gCos = make_tensor(
  970. make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_cos_ptr) +
  971. row_offset_cossin),
  972. Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
  973. make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
  974. Tensor gSin = make_tensor(
  975. make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_sin_ptr) +
  976. row_offset_cossin),
  977. Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
  978. make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
  979. Tensor gCosCont = make_tensor(
  980. make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_cos_ptr) +
  981. row_offset_cossin),
  982. Shape<Int<kBlockM>, Int<kHeadDim>>{},
  983. make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
  984. Tensor gSinCont = make_tensor(
  985. make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_sin_ptr) +
  986. row_offset_cossin),
  987. Shape<Int<kBlockM>, Int<kHeadDim>>{},
  988. make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
  989. Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
  990. Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
  991. Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
  992. Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
  993. if (params.is_rotary_interleaved) {
  994. flash::copy_rotary_interleaved<Is_even_K>(
  995. tQgQ, tQsQ, tRgCos, tRgSin, tQcQ,
  996. binfo.actual_seqlen_q - m_block * kBlockM, 0, params.d,
  997. params.rotary_dim);
  998. } else {
  999. flash::copy_rotary_contiguous<Is_even_K>(
  1000. tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ,
  1001. binfo.actual_seqlen_q - m_block * kBlockM, 0, params.d,
  1002. params.rotary_dim);
  1003. }
  1004. }
  1005. int n_block = n_block_max - 1;
  1006. // We don't need to clear the sK smem tiles since we'll mask out the scores
  1007. // anyway.
  1008. flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV,
  1009. tKVpKV,
  1010. binfo.actual_seqlen_k - n_block * kBlockN);
  1011. cute::cp_async_fence();
  1012. // flash::cp_async_wait<0>();
  1013. // __syncthreads();
  1014. // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); }
  1015. // __syncthreads();
  1016. clear(acc_o);
  1017. flash::Softmax<2 * size<1>(acc_o)> softmax;
  1018. const float alibi_slope =
  1019. !Has_alibi ? 0.0f
  1020. : reinterpret_cast<float*>(params.alibi_slopes_ptr)
  1021. [bidb * params.alibi_slopes_batch_stride + bidh] /
  1022. params.scale_softmax;
  1023. flash::Mask<Is_causal, Is_local, Has_alibi> mask(
  1024. binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left,
  1025. params.window_size_right, alibi_slope);
  1026. // For performance reasons, we separate out two kinds of iterations:
  1027. // those that need masking on S, and those that don't.
  1028. // We need masking on S for the very last block when the K/V length is not a
  1029. // multiple of kBlockN. We also need masking on S if it's causal, for the last
  1030. // ceil_div(kBlockM, kBlockN) blocks. We will have at least 1 "masking"
  1031. // iteration.
  1032. // If not even_N, then seqlen_k might end in the middle of a block. In that
  1033. // case we need to mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
  1034. constexpr int n_masking_steps =
  1035. (!Is_causal && !Is_local)
  1036. ? 1
  1037. : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN)
  1038. : cute::ceil_div(kBlockM, kBlockN) + 1);
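// A small worked example, assuming kBlockM = 128 and kBlockN = 64:
//   non-causal, non-local:          n_masking_steps = 1 (only the ragged K/V tail)
//   causal, Is_even_MN:             n_masking_steps = ceil_div(128, 64)     = 2
//   causal or local, !Is_even_MN:   n_masking_steps = ceil_div(128, 64) + 1 = 3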
  1039. #pragma unroll
  1040. for (int masking_step = 0; masking_step < n_masking_steps;
  1041. ++masking_step, --n_block) {
  1042. Tensor acc_s = partition_fragment_C(
  1043. tiled_mma,
  1044. Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
  1045. clear(acc_s);
  1046. flash::cp_async_wait<0>();
  1047. __syncthreads();
  1048. // Advance gV
  1049. if (masking_step > 0) {
  1050. if (block_table == nullptr) {
  1051. tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
  1052. } else {
  1053. tVgV.data() =
  1054. gV.data() +
  1055. flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(
  1056. tidx, n_block + 1, params.page_block_size, block_table,
  1057. params.v_batch_stride, params.v_row_stride);
  1058. }
  1059. flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV,
  1060. tVsV, tKVcKV, tKVpKV);
  1061. } else {
  1062. // Clear the smem tiles to account for predicated off loads
  1063. flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
  1064. gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV,
  1065. binfo.actual_seqlen_k - n_block * kBlockN);
  1066. }
  1067. cute::cp_async_fence();
  1068. flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q,
  1069. smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K);
  1070. // if (cute::thread0()) { print(acc_s); }
  1071. if constexpr (Is_softcap) {
  1072. apply_softcap(acc_s, params.softcap);
  1073. }
  1074. mask.template apply_mask<Is_causal, Is_even_MN>(
  1075. acc_s, n_block * kBlockN,
  1076. m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
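// The row index passed to the mask follows the sm80 m16n8k16 accumulator
// layout (roughly): each warp owns a 16-row slab of M ((tidx / 32) * 16), and
// within a warp lane l covers rows l / 4 and l / 4 + 8; kNWarps * 16 is the
// row stride between successive MMA_M tiles handled by the same warp.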
  1077. flash::cp_async_wait<0>();
  1078. __syncthreads();
  1079. // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); }
  1080. // __syncthreads();
  1081. if (n_block > n_block_min) {
  1082. // Advance gK
  1083. if (block_table == nullptr) {
  1084. tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
  1085. } else {
  1086. tKgK.data() = gK.data() +
  1087. flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(
  1088. tidx, n_block, params.page_block_size, block_table,
  1089. params.k_batch_stride, params.k_row_stride);
  1090. }
  1091. flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK,
  1092. tKsK, tKVcKV, tKVpKV);
  1093. // This cp_async_fence needs to be in the if block, otherwise the
  1094. // synchronization isn't right and we get race conditions.
  1095. cute::cp_async_fence();
  1096. }
1097. // We have key_padding_mask, so we'll need Check_inf
  1098. masking_step == 0
  1099. ? softmax.template softmax_rescale_o</*Is_first=*/true,
  1100. /*Check_inf=*/Is_causal ||
  1101. Is_local || !Is_even_MN>(
  1102. acc_s, acc_o, params.scale_softmax_log2)
  1103. : softmax.template softmax_rescale_o</*Is_first=*/false,
  1104. /*Check_inf=*/Is_causal ||
  1105. Is_local || !Is_even_MN>(
  1106. acc_s, acc_o, params.scale_softmax_log2);
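// softmax_rescale_o is the online-softmax update; per query row, roughly:
//   m_new   = max(m_old, rowmax(acc_s))
//   scale   = exp2f((m_old - m_new) * scale_softmax_log2)
//   acc_o  *= scale                      // rescale the partial output
//   row_sum = row_sum * scale
//           + rowsum(exp2f((acc_s - m_new) * scale_softmax_log2))
// Is_first skips the rescaling on the first block; Check_inf handles rows whose
// max is still -inf because every score was masked out.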
  1107. // if (cute::thread0()) { print(scores_max); print(scores_sum);
  1108. // print(scores); }
  1109. // Convert acc_s from fp32 to fp16/bf16
  1110. Tensor rP = flash::convert_type<Element>(acc_s);
  1111. // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
  1112. // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
  1113. Tensor tOrP = make_tensor(
  1114. rP.data(),
  1115. flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
  1116. flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V,
  1117. smem_thr_copy_V);
  1118. // This check is at the end of the loop since we always have at least 1
  1119. // iteration
  1120. if (n_masking_steps > 1 && n_block <= n_block_min) {
  1121. --n_block;
  1122. break;
  1123. }
  1124. }
  1125. // These are the iterations where we don't need masking on S
  1126. for (; n_block >= n_block_min; --n_block) {
  1127. Tensor acc_s = partition_fragment_C(
  1128. tiled_mma,
  1129. Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
  1130. clear(acc_s);
  1131. flash::cp_async_wait<0>();
  1132. __syncthreads();
  1133. // Advance gV
  1134. if (block_table == nullptr) {
  1135. tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
  1136. } else {
  1137. tVgV.data() = gV.data() +
  1138. flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(
  1139. tidx, n_block + 1, params.page_block_size, block_table,
  1140. params.v_batch_stride, params.v_row_stride);
  1141. }
  1142. flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV, tVsV,
  1143. tKVcKV, tKVpKV);
  1144. cute::cp_async_fence();
  1145. flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q,
  1146. smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K);
  1147. if constexpr (Is_softcap) {
  1148. apply_softcap(acc_s, params.softcap);
  1149. }
  1150. flash::cp_async_wait<0>();
  1151. __syncthreads();
  1152. if (n_block > n_block_min) {
  1153. // Advance gK
  1154. if (block_table == nullptr) {
  1155. tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
  1156. } else {
  1157. tKgK.data() = gK.data() +
  1158. flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(
  1159. tidx, n_block, params.page_block_size, block_table,
  1160. params.k_batch_stride, params.k_row_stride);
  1161. }
  1162. flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK,
  1163. tKsK, tKVcKV, tKVpKV);
  1164. // This cp_async_fence needs to be in the if block, otherwise the
  1165. // synchronization isn't right and we get race conditions.
  1166. cute::cp_async_fence();
  1167. }
  1168. mask.template apply_mask</*Causal_mask=*/false>(
  1169. acc_s, n_block * kBlockN,
  1170. m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
  1171. softmax
  1172. .template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(
  1173. acc_s, acc_o, params.scale_softmax_log2);
  1174. Tensor rP = flash::convert_type<Element>(acc_s);
  1175. // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
  1176. // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
  1177. Tensor tOrP = make_tensor(
  1178. rP.data(),
  1179. flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
  1180. flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V,
  1181. smem_thr_copy_V);
  1182. }
  1183. // Epilogue
  1184. Tensor lse =
  1185. softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(
  1186. acc_o, params.scale_softmax);
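// normalize_softmax_lse finishes the online softmax for this (m_block, split):
// roughly lse = m * scale_softmax + logf(row_sum) and acc_o is divided by
// row_sum, so partial results from different splits can later be merged exactly
// in combine_attn_seqk_parallel via a log-sum-exp over splits.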
  1187. // if (cute::thread0()) { print(lse); }
  1188. Tensor sOaccum =
  1189. make_tensor(make_smem_ptr(reinterpret_cast<ElementO*>(smem_)),
  1190. typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
  1191. // Partition sO to match the accumulator partitioning
  1192. using SmemTiledCopyO =
  1193. std::conditional_t<!Split, typename Kernel_traits::SmemCopyAtomO,
  1194. typename Kernel_traits::SmemCopyAtomOaccum>;
  1195. auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma);
  1196. auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
  1197. Tensor rO = flash::convert_type<ElementO>(acc_o);
  1198. Tensor taccOrOaccum =
  1199. smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
  1200. Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(
  1201. sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N)
1202. // sOaccum is larger than sQ, so we need to __syncthreads() here
1203. // TODO: allocate enough smem for sOaccum
  1204. if constexpr (Split) {
  1205. __syncthreads();
  1206. }
  1207. cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);
  1208. const index_t row_offset_o =
  1209. binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) +
  1210. m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
  1211. const index_t row_offset_oaccum =
  1212. (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q +
  1213. m_block * kBlockM) *
  1214. params.d_rounded;
  1215. const index_t row_offset_lseaccum =
  1216. (Split || !params.unpadded_lse
  1217. ? ((n_split_idx * params.b + bidb) * params.h + bidh) *
  1218. params.seqlen_q
  1219. : bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb)) +
  1220. m_block * kBlockM;
  1221. Tensor gOaccum =
  1222. make_tensor(make_gmem_ptr(reinterpret_cast<ElementO*>(
  1223. Split ? params.oaccum_ptr : params.o_ptr) +
  1224. (Split ? row_offset_oaccum : row_offset_o)),
  1225. Shape<Int<kBlockM>, Int<kHeadDim>>{},
  1226. make_stride(Split ? kHeadDim : params.o_row_stride, _1{}));
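// Roughly: when Split, the partial output goes to oaccum_ptr in float with a
// packed [num_splits, b, h, seqlen_q, d_rounded] layout (hence the kHeadDim
// row stride); otherwise we write the final output directly to o_ptr using the
// caller-provided o_row_stride.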
  1227. Tensor gLSEaccum = make_tensor(
  1228. make_gmem_ptr(
  1229. reinterpret_cast<ElementAccum*>(Split ? params.softmax_lseaccum_ptr
  1230. : params.softmax_lse_ptr) +
  1231. row_offset_lseaccum),
  1232. Shape<Int<kBlockM>>{}, Stride<_1>{});
  1233. // if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n",
  1234. // row_offset_o, bidh, gOaccum.data()); }
  1235. GmemTiledCopyO gmem_tiled_copy_Oaccum;
  1236. auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
  1237. Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(
  1238. sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N)
  1239. Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
  1240. __syncthreads();
  1241. Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
  1242. cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);
  1243. Tensor caccO = make_identity_tensor(
  1244. Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
  1245. Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
  1246. static_assert(decltype(size<0>(taccOcO))::value == 4);
  1247. // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
  1248. Tensor taccOcO_row =
  1249. logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);
  1250. CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
  1251. if (get<1>(taccOcO_row(0)) == 0) {
  1252. #pragma unroll
  1253. for (int mi = 0; mi < size(lse); ++mi) {
  1254. const int row = get<0>(taccOcO_row(mi));
  1255. if (row < binfo.actual_seqlen_q - m_block * kBlockM) {
  1256. gLSEaccum(row) = lse(mi);
  1257. }
  1258. }
  1259. }
  1260. // Construct identity layout for sO
  1261. Tensor cO = make_identity_tensor(make_shape(
  1262. size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
  1263. // Repeat the partitioning with identity layouts
  1264. Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(
  1265. cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
  1266. Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
  1267. if (!Is_even_K) {
  1268. #pragma unroll
  1269. for (int k = 0; k < size(tOpO); ++k) {
  1270. tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d;
  1271. }
  1272. }
  1273. // Clear_OOB_K must be false since we don't want to write zeros to gmem
  1274. flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false,
  1275. /*Clear_OOB_K=*/false>(gmem_tiled_copy_Oaccum, tOrOaccum,
  1276. tOgOaccum, tOcO, tOpO,
  1277. binfo.actual_seqlen_q - m_block * kBlockM);
  1278. }
  1279. ////////////////////////////////////////////////////////////////////////////////////////////////////
  1280. template <typename Kernel_traits, bool Is_dropout, bool Is_causal,
  1281. bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K,
  1282. bool Is_softcap, bool Return_softmax, typename Params>
  1283. inline __device__ void compute_attn(const Params& params) {
  1284. const int m_block = blockIdx.x;
  1285. // The block index for the batch.
  1286. const int bidb = blockIdx.y;
  1287. // The block index for the head.
  1288. const int bidh = blockIdx.z;
  1289. // We want the fwd and bwd to generate the same dropout pattern (RNG), without
  1290. // restricting them to have the same number of threads or have to traverse the
  1291. // attention matrix in the same order. In the Philox RNG, we use the offset to
  1292. // store the batch, head, and the lane id (within a warp). We use the
  1293. // subsequence to store the location of the 16 x 32 blocks within the
  1294. // attention matrix. This way, as long as we have the batch, head, and the
  1295. // location of the 16 x 32 block within the attention matrix, we can generate
  1296. // the exact same dropout pattern.
  1297. flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local,
  1298. Has_alibi, Is_even_MN, Is_even_K, Is_softcap,
  1299. Return_softmax>(params, bidb, bidh, m_block);
  1300. }
  1301. ////////////////////////////////////////////////////////////////////////////////////////////////////
  1302. template <typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi,
  1303. bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split,
  1304. bool Append_KV, typename Params>
  1305. inline __device__ void compute_attn_splitkv(const Params& params) {
  1306. const int m_block = blockIdx.x;
  1307. // The block index for the batch.
  1308. const int bidb = Split ? blockIdx.z / params.h : blockIdx.y;
  1309. // The block index for the head.
  1310. const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
  1311. const int n_split_idx = Split ? blockIdx.y : 0;
  1312. const int num_n_splits = Split ? gridDim.y : 1;
  1313. flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local,
  1314. Has_alibi, Is_even_MN, Is_even_K,
  1315. Is_softcap, Split, Append_KV>(
  1316. params, bidb, bidh, m_block, n_split_idx, num_n_splits);
  1317. }
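// Grid mapping assumed by the index math above (a sketch, not the launch code):
//   Split:  grid = dim3(num_m_blocks, num_n_splits, batch * num_heads), so
//           blockIdx.z packs (bidb, bidh) and blockIdx.y is the split index;
//   !Split: grid = dim3(num_m_blocks, batch, num_heads).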
  1318. ////////////////////////////////////////////////////////////////////////////////////////////////////
  1319. template <typename Kernel_traits, int kBlockM, int Log_max_splits,
  1320. bool Is_even_K, typename Params>
  1321. inline __device__ void combine_attn_seqk_parallel(const Params& params) {
  1322. using Element = typename Kernel_traits::Element;
  1323. using ElementAccum = typename Kernel_traits::ElementAccum;
  1324. using index_t = typename Kernel_traits::index_t;
  1325. constexpr int kMaxSplits = 1 << Log_max_splits;
  1326. constexpr int kHeadDim = Kernel_traits::kHeadDim;
  1327. constexpr int kNThreads = Kernel_traits::kNThreads;
  1328. static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128");
  1329. static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32,
  1330. "kBlockM must be 4, 8, 16 or 32");
  1331. static_assert(kNThreads == 128, "We assume that each block has 128 threads");
  1332. // Shared memory.
  1333. // kBlockM + 1 instead of kBlockM to reduce bank conflicts.
  1334. __shared__ ElementAccum sLSE[kMaxSplits][kBlockM + 1];
  1335. // The thread and block index.
  1336. const int tidx = threadIdx.x;
  1337. const int bidx = blockIdx.x;
  1338. const index_t lse_size = params.b * params.h * params.seqlen_q;
  1339. const index_t row_offset_lse = bidx * kBlockM;
  1340. Tensor gLSEaccum = make_tensor(
  1341. make_gmem_ptr(
  1342. reinterpret_cast<ElementAccum*>(params.softmax_lseaccum_ptr) +
  1343. row_offset_lse),
  1344. Shape<Int<kMaxSplits>, Int<kBlockM>>{}, make_stride(lse_size, _1{}));
  1345. // LSE format is different depending on params.unpadded_lse and
  1346. // params.seqlenq_ngroups_swapped, see comment in get_lse_tile. This tensor's
  1347. // layout maps row_offset_lse to {bidb, bidh, q_offset}.
  1348. Tensor gLSE = make_tensor(
  1349. make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr) +
  1350. row_offset_lse),
  1351. Shape<Int<kBlockM>>{}, Stride<_1>{});
  1352. // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb,
  1353. // q_offset}.
  1354. Layout flat_layout = make_layout(lse_size);
  1355. Layout orig_layout =
  1356. make_layout(make_shape(params.seqlen_q, params.h, params.b));
  1357. auto transposed_stride =
  1358. params.seqlenq_ngroups_swapped
  1359. ? make_stride(params.b, params.seqlen_q * params.b, 1)
  1360. : make_stride(1, params.seqlen_q * params.b, params.seqlen_q);
  1361. Layout remapped_layout = make_layout(
  1362. make_shape(params.seqlen_q, params.h, params.b), transposed_stride);
  1363. Layout final_layout = cute::composition(
  1364. remapped_layout, cute::composition(orig_layout, flat_layout));
  1365. Tensor gLSE_unpadded = make_tensor(
  1366. make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
  1367. final_layout);
  1368. constexpr int kNLsePerThread =
  1369. (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads;
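// E.g. with kNThreads = 128, kMaxSplits = 16 and kBlockM = 16:
// kNLsePerThread = ceil_div(16 * 16, 128) = 2 and kRowsPerLoadLSE = 128 / 16 = 8,
// so each thread loads 2 LSE values that sit 8 split-rows apart.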
  1370. // Read the LSE values from gmem and store them in shared memory, then
  1371. // transpose them.
  1372. constexpr int kRowsPerLoadLSE = kNThreads / kBlockM;
  1373. #pragma unroll
  1374. for (int l = 0; l < kNLsePerThread; ++l) {
  1375. const int row = l * kRowsPerLoadLSE + tidx / kBlockM;
  1376. const int col = tidx % kBlockM;
  1377. ElementAccum lse =
  1378. (row < params.num_splits && col < lse_size - bidx * kBlockM)
  1379. ? gLSEaccum(row, col)
  1380. : -INFINITY;
  1381. if (row < kMaxSplits) {
  1382. sLSE[row][col] = lse;
  1383. }
  1384. // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse
  1385. // = %f\n", tidx, row, col, lse); }
  1386. }
  1387. // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse =
  1388. // %f\n", tidx, row_offset_lse, lse_accum(0)); }
  1389. __syncthreads();
  1390. Tensor lse_accum = make_tensor<ElementAccum>(Shape<Int<kNLsePerThread>>{});
  1391. constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits);
1392. // To make sure that kMaxSplits is within 1 warp: we decide how many elements
1393. // within kMaxSplits each thread should hold. If kMaxSplits = 16, then each
1394. // thread holds 2 elements (128 threads, kBlockM rows, so each time we load we
1395. // can load 128 / kBlockM rows).
1396. // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose; static_assert(kThreadsPerSplit <= 32);
  1397. static_assert(kRowsPerLoadTranspose <= 32);
  1398. static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits);
  1399. #pragma unroll
  1400. for (int l = 0; l < kNLsePerThread; ++l) {
  1401. const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
  1402. const int col = tidx / kRowsPerLoadTranspose;
  1403. lse_accum(l) =
  1404. (row < kMaxSplits && col < kBlockM) ? sLSE[row][col] : -INFINITY;
  1405. // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse
  1406. // = %f\n", tidx, row, col, lse_accum(l)); }
  1407. }
  1408. // Compute the logsumexp of the LSE along the split dimension.
  1409. ElementAccum lse_max = lse_accum(0);
  1410. #pragma unroll
  1411. for (int l = 1; l < kNLsePerThread; ++l) {
  1412. lse_max = max(lse_max, lse_accum(l));
  1413. }
  1414. MaxOp<float> max_op;
  1415. lse_max = Allreduce<kRowsPerLoadTranspose>::run(lse_max, max_op);
  1416. lse_max =
  1417. lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf
  1418. float lse_sum = expf(lse_accum(0) - lse_max);
  1419. #pragma unroll
  1420. for (int l = 1; l < kNLsePerThread; ++l) {
  1421. lse_sum += expf(lse_accum(l) - lse_max);
  1422. }
  1423. SumOp<float> sum_op;
  1424. lse_sum = Allreduce<kRowsPerLoadTranspose>::run(lse_sum, sum_op);
  1425. // For the case where all local lse == -INFINITY, we want to set lse_logsum to
  1426. // INFINITY. Otherwise lse_logsum is log(0.0) = -INFINITY and we get NaN when
  1427. // we do lse_accum(l) - lse_logsum.
  1428. ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum)
  1429. ? INFINITY
  1430. : logf(lse_sum) + lse_max;
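// In exact arithmetic this is, per query row,
//   lse_logsum = log(sum_s exp(lse_s)) = lse_max + log(sum_s exp(lse_s - lse_max)),
// i.e. the log-sum-exp over splits computed in the usual max-shifted form for
// numerical stability.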
  1431. // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f,
  1432. // lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); }
  1433. if (tidx % kRowsPerLoadTranspose == 0 &&
  1434. tidx / kRowsPerLoadTranspose < kBlockM) {
  1435. if (params.unpadded_lse) {
  1436. const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose;
  1437. if (lse_offset < lse_size) {
  1438. gLSE_unpadded(lse_offset) = lse_logsum;
  1439. }
  1440. } else {
  1441. gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum;
  1442. }
  1443. }
  1444. // Store the scales exp(lse - lse_logsum) in shared memory.
  1445. #pragma unroll
  1446. for (int l = 0; l < kNLsePerThread; ++l) {
  1447. const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
  1448. const int col = tidx / kRowsPerLoadTranspose;
  1449. if (row < params.num_splits && col < kBlockM) {
  1450. sLSE[row][col] = expf(lse_accum(l) - lse_logsum);
  1451. }
  1452. }
  1453. __syncthreads();
  1454. const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded;
  1455. Tensor gOaccum = make_tensor(
  1456. make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.oaccum_ptr) +
  1457. row_offset_oaccum),
  1458. Shape<Int<kBlockM>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
  1459. constexpr int kBlockN = kNThreads / kBlockM;
  1460. using GmemLayoutAtomOaccum =
  1461. Layout<Shape<Int<kBlockM>, Int<kBlockN>>, Stride<Int<kBlockN>, _1>>;
  1462. using GmemTiledCopyOaccum = decltype(make_tiled_copy(
  1463. Copy_Atom<DefaultCopy, ElementAccum>{}, GmemLayoutAtomOaccum{},
  1464. Layout<Shape<_1, _4>>{})); // Val layout, 4 vals per store
  1465. GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
  1466. auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
  1467. Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);
  1468. Tensor tOrO = make_tensor<ElementAccum>(shape(tOgOaccum));
  1469. Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
  1470. clear(tOrO);
  1471. // Predicates
  1472. Tensor cOaccum = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
  1473. // Repeat the partitioning with identity layouts
  1474. Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum);
  1475. Tensor tOpOaccum = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
  1476. if (!Is_even_K) {
  1477. #pragma unroll
  1478. for (int k = 0; k < size(tOpOaccum); ++k) {
  1479. tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d;
  1480. }
  1481. }
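// The loop below reconstructs the unsplit output row by row:
//   O = sum_s exp(lse_s - lse_logsum) * Oaccum_s
// where the per-split scales exp(lse_s - lse_logsum) were stored in sLSE above;
// this is the standard way to merge partial softmax results over disjoint key
// chunks.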
1482. // Load Oaccum in, then scale and accumulate into O
  1483. for (int split = 0; split < params.num_splits; ++split) {
  1484. flash::copy</*Is_even_MN=*/false, Is_even_K>(
  1485. gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum,
  1486. params.b * params.h * params.seqlen_q - bidx * kBlockM);
  1487. #pragma unroll
  1488. for (int m = 0; m < size<1>(tOrOaccum); ++m) {
  1489. int row = get<0>(tOcOaccum(0, m, 0));
  1490. ElementAccum lse_scale = sLSE[split][row];
  1491. #pragma unroll
  1492. for (int k = 0; k < size<2>(tOrOaccum); ++k) {
  1493. #pragma unroll
  1494. for (int i = 0; i < size<0>(tOrOaccum); ++i) {
  1495. tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k);
  1496. }
  1497. }
  1498. // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0],
  1499. // sLSE[split][1]); print(tOrOaccum); }
  1500. }
  1501. tOgOaccum.data() = tOgOaccum.data() +
  1502. params.b * params.h * params.seqlen_q * params.d_rounded;
  1503. }
  1504. // if (cute::thread0()) { print_tensor(tOrO); }
  1505. Tensor rO = flash::convert_type<Element>(tOrO);
  1506. // Write to gO
  1507. #pragma unroll
  1508. for (int m = 0; m < size<1>(rO); ++m) {
  1509. const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0));
  1510. if (idx < params.b * params.h * params.seqlen_q) {
  1511. const int batch_idx = idx / (params.h * params.seqlen_q);
  1512. const int head_idx =
  1513. (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q;
1514. // The index into the rows of Q
  1515. const int row = idx - batch_idx * (params.h * params.seqlen_q) -
  1516. head_idx * params.seqlen_q;
  1517. auto o_ptr = reinterpret_cast<Element*>(params.o_ptr) +
  1518. batch_idx * params.o_batch_stride +
  1519. head_idx * params.o_head_stride + row * params.o_row_stride;
  1520. #pragma unroll
  1521. for (int k = 0; k < size<2>(rO); ++k) {
  1522. if (Is_even_K || tOpOaccum(k)) {
  1523. const int col = get<1>(tOcOaccum(0, m, k));
  1524. Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col),
  1525. Shape<Int<decltype(size<0>(rO))::value>>{},
  1526. Stride<_1>{});
1527. // TODO: Should check whether this uses a vectorized store, but it seems
1528. // pretty fast
  1529. copy(rO(_, m, k), gO);
  1530. // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d,
  1531. // batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx,
  1532. // batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); }
  1533. // reinterpret_cast<uint64_t *>(o_ptr)[col / 4] =
  1534. // recast<uint64_t>(rO)(0, m, k);
  1535. }
  1536. }
  1537. }
  1538. }
  1539. }
  1540. } // namespace flash