flash_fwd_combine_kernel.h

/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

#include <cutlass/cutlass.h>
#include <cutlass/arch/memory.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>

#include "seqlen.h"
#include "utils.h"

namespace flash {

using namespace cute;
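
// This kernel combines the partial outputs and partial LSEs produced by a split-KV forward pass:
// for each query row it computes the final log-sum-exp over the splits and the matching rescaled
// sum of the partial outputs (see Steps 1-7 in operator() below).
//
// Illustrative instantiation sketch (hypothetical values, not prescribed by this header):
//
//     using TileShape_MK = cute::Shape<cute::Int<64>, cute::Int<128>>;  // 64 rows, headdim 128
//     using CombineKernel = flash::FlashAttnFwdCombine<
//         TileShape_MK, /*kLogMaxSplits_=*/7, /*kNThreads=*/256, /*AlignmentLSE_=*/4,
//         /*Is_even_K=*/true, /*Varlen=*/false,
//         cutlass::bfloat16_t /*Element*/, float /*ElementPartial*/, cutlass::arch::Sm80>;
//
// A launch would use CombineKernel::MaxThreadsPerBlock threads per block and
// CombineKernel::SharedStorageSize bytes of dynamic shared memory.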
template <class TileShape_MK_, int kLogMaxSplits_, int kNThreads, int AlignmentLSE_,
          bool Is_even_K, bool Varlen, class Element, class ElementPartial, class ArchTag_>
class FlashAttnFwdCombine {

public:

    // Type Aliases
    using TileShape_MK = TileShape_MK_;
    using ArchTag = ArchTag_;
    static constexpr int kMaxSplits = 1 << kLogMaxSplits_;
    static constexpr int AlignmentLSE = std::min(AlignmentLSE_, int(128 / 8 / sizeof(float)));
    static_assert(AlignmentLSE >= 1);
    static constexpr int kStages = 4;

    static_assert(ArchTag::kMinComputeCapability >= 75);
    static constexpr bool Has_cp_async = ArchTag::kMinComputeCapability >= 80;

    static constexpr uint32_t MaxThreadsPerBlock = kNThreads;
    static constexpr uint32_t MinBlocksPerMultiprocessor = 2;

    static constexpr int kBlockM = get<0>(TileShape_MK{});
    static constexpr int kHeadDim = get<1>(TileShape_MK{});

    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(ElementPartial);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
    static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
    static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;
    static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
    using GmemCopyAtom = std::conditional_t<
        Has_cp_async,
        cute::Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<uint128_t>, ElementPartial>,
        cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementPartial>
    >;
    using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
    static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0);
    using GmemTiledCopyAccum = decltype(
        make_tiled_copy(GmemCopyAtom{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, kGmemElemsPerLoad vals per load
    using GmemTiledCopy = decltype(
        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, kGmemElemsPerLoad vals per store

    using AlignmentTypeLSE = cute::uint_byte_t<static_cast<int>(sizeof(float)) * AlignmentLSE>;
    static constexpr int kGmemElemsPerLoadLSE = sizeof(AlignmentTypeLSE) / sizeof(float);
    static_assert(kBlockM % kGmemElemsPerLoadLSE == 0, "kBlockM must be a multiple of kGmemElemsPerLoadLSE");
    static_assert(kBlockM % 8 == 0, "kBlockM must be a multiple of 8");
    static constexpr int kBlockMSmem = kBlockM % 128 == 0 ? 128 : (kBlockM % 64 == 0 ? 64 : (kBlockM % 32 == 0 ? 32 : (kBlockM % 16 == 0 ? 16 : 8)));
    static constexpr int kGmemThreadsPerRowLSE = kBlockMSmem / kGmemElemsPerLoadLSE;
    static_assert(MaxThreadsPerBlock % kGmemThreadsPerRowLSE == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRowLSE");
    using GmemLayoutAtomLSE = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRowLSE>, Int<kGmemThreadsPerRowLSE>>,
                                     Stride<Int<kGmemThreadsPerRowLSE>, _1>>;
    static_assert(kMaxSplits % CUTE_STATIC_V(shape<0>(GmemLayoutAtomLSE{})) == 0);
    using GmemCopyAtomLSE = std::conditional_t<
        Has_cp_async,
        cute::Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<AlignmentTypeLSE>, float>,
        cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<AlignmentLSE * sizeof(float) * 8>, float>
    >;
    using GmemTiledCopyLSE = decltype(
        make_tiled_copy(GmemCopyAtomLSE{},
                        GmemLayoutAtomLSE{},
                        Layout<Shape<_1, Int<kGmemElemsPerLoadLSE>>>{}));  // Val layout, kGmemElemsPerLoadLSE vals per load

    // Otherwise we get an illegal memory access (IMA) when some threads access sLSE, as we're not doing any masking
    static_assert((kBlockM * kMaxSplits * AlignmentLSE) % kNThreads == 0, "kNThreads must divide kBlockM * kMaxSplits * AlignmentLSE");

    // This works for kBlockMSmem = 8, 16, 32, 64, 128, no bank conflicts
    using SmemLSESwizzle = std::conditional_t<
        kBlockMSmem == 8,
        Swizzle<5, 0, 5>,
        std::conditional_t<kBlockMSmem == 16, Swizzle<4, 0, 4>, Swizzle<3, 2, 3>>
    >;
    using SmemLayoutAtomLSE =
        decltype(composition(SmemLSESwizzle{},
                             Layout<Shape<Int<8>, Int<kBlockMSmem>>,
                                    Stride<Int<kBlockMSmem>, _1>>{}));
    using SmemLayoutLSE = decltype(tile_to_shape(SmemLayoutAtomLSE{}, Shape<Int<kMaxSplits>, Int<kBlockM>>{}));

    using SmemLayoutO = Layout<Shape<Int<kBlockM>, Int<kHeadDim>, Int<kStages>>,
                               Stride<Int<kHeadDim>, _1, Int<kBlockM * kHeadDim>>>;
    // We want each column (kMaxSplits) to be processed by threads in the same warp.
    // To reduce the number of shuffles, we want as few threads on the same column as possible.
    // E.g., if kBlockM is divisible by 64 and there are 256 threads, we want 4 threads (0, 1, 2, 3)
    // per column, and have 64 such quads.
    static_assert(MaxThreadsPerBlock % kBlockMSmem == 0, "MaxThreadsPerBlock must be a multiple of kBlockMSmem");
    static constexpr int kSmemThreadsPerColLSEt = MaxThreadsPerBlock / kBlockMSmem;
    static_assert(cutlass::NumThreadsPerWarp % kSmemThreadsPerColLSEt == 0, "kSmemThreadsPerColLSEt must divide NumThreadsPerWarp");
    using S2RLayoutAtomLSE = Layout<Shape<Int<kSmemThreadsPerColLSEt>, Int<MaxThreadsPerBlock / kSmemThreadsPerColLSEt>>>;
    using S2RTiledCopyLSE = decltype(make_tiled_copy(cute::Copy_Atom<cute::DefaultCopy, float>{}, S2RLayoutAtomLSE{}, Layout<_1>{}));
    using ShapeOPartial = cute::Shape<int32_t, int32_t, int32_t, int32_t, int32_t>;  // (seqlen, d, num_splits, head, batch)
    using StrideOPartial = cute::Stride<int64_t, _1, int64_t, int64_t, int64_t>;
    using ShapeLSEPartial = cute::Shape<int32_t, int32_t, int32_t, int32_t>;  // (seqlen, num_splits, head, batch)
    using StrideLSEPartial = cute::Stride<_1, int64_t, int64_t, int64_t>;  // (seqlen, num_splits, head, batch)
    using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t>;  // (seqlen, d, head, batch)
    using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t>;
    using ShapeLSE = cute::Shape<int32_t, int32_t, int32_t>;  // (seqlen, head, batch)
    using StrideLSE = cute::Stride<_1, int64_t, int64_t>;  // (seqlen, head, batch)

    struct SharedStorage : cute::aligned_struct<128> {
        cute::array_aligned<float, cute::cosize_v<SmemLayoutLSE>> smem_lse_partial;
        cute::array_aligned<int, kBlockM> smem_max_valid_split;
        cute::array_aligned<ElementPartial, cute::cosize_v<SmemLayoutO>> smem_o_partial;
    };

    static constexpr int SharedStorageSize = sizeof(SharedStorage);

    // Device side arguments
    struct Arguments {
        ElementPartial const* ptr_O_partial;
        ShapeOPartial const shape_O_partial;
        StrideOPartial const stride_O_partial;
        float const* ptr_LSE_partial;
        ShapeLSEPartial const shape_LSE_partial;
        StrideLSEPartial const stride_LSE_partial;
        Element* ptr_O;
        StrideO const stride_O;
        float* ptr_LSE;
        StrideLSE const stride_LSE;
        int const* cu_seqlens = nullptr;
        int const* seqused = nullptr;
    };

    // Kernel entry point API
    struct Params {
        ElementPartial const* ptr_O_partial;
        ShapeOPartial const shape_O_partial;
        StrideOPartial const stride_O_partial;
        float const* ptr_LSE_partial;
        ShapeLSEPartial const shape_LSE_partial;
        StrideLSEPartial const stride_LSE_partial;
        Element* ptr_O;
        StrideO const stride_O;
        float* ptr_LSE;
        StrideLSE const stride_LSE;
        cutlass::FastDivmod seqlen_divmod, head_divmod;
        int const* cu_seqlens = nullptr;
        int const* seqused = nullptr;
    };
    // Convert to underlying arguments. In this case, a simple copy for the aliased type.
    static
    Params
    to_underlying_arguments(Arguments const& args) {
        assert(get<1>(args.shape_LSE_partial) <= kMaxSplits);
        return {
            args.ptr_O_partial,
            args.shape_O_partial,
            args.stride_O_partial,
            args.ptr_LSE_partial,
            args.shape_LSE_partial,
            args.stride_LSE_partial,
            args.ptr_O,
            args.stride_O,
            args.ptr_LSE,
            args.stride_LSE,
            cutlass::FastDivmod(get<0>(args.shape_LSE_partial)), cutlass::FastDivmod(get<2>(args.shape_LSE_partial)),
            args.cu_seqlens,
            args.seqused
        };
    }
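
    // Host-side usage sketch (illustrative only; pointer and shape variable names are hypothetical):
    //
    //     Arguments args {o_partial_ptr, shape_O_partial, stride_O_partial,
    //                     lse_partial_ptr, shape_LSE_partial, stride_LSE_partial,
    //                     o_ptr, stride_O, lse_ptr, stride_LSE,
    //                     /*cu_seqlens=*/nullptr, /*seqused=*/nullptr};
    //     Params params = to_underlying_arguments(args);
    //
    // The kernel is expected to be launched with one block per kBlockM rows of the flattened
    // (seqlen, head, batch) index (plus a batch grid dimension when Varlen), passing params and
    // the dynamic smem buffer to operator().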
    CUTLASS_DEVICE
    void
    operator()(Params const& params, char* smem_buf) {

        SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
        Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse_partial.data()), SmemLayoutLSE{});
        Tensor sMaxValidSplit = make_tensor(make_smem_ptr(shared_storage.smem_max_valid_split.data()), Shape<Int<kBlockM>>{});
        Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o_partial.data()), SmemLayoutO{});

        int const thread_idx = threadIdx.x;
        int const m_block = blockIdx.x;
        int const batch = !Varlen ? 0 : blockIdx.y;
        int const num_splits = get<1>(params.shape_LSE_partial);
        flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused};
        int const offset = seqlen_info.offset;
        int const seqlen = seqlen_info.seqlen;
        int max_idx = seqlen * get<2>(params.shape_LSE_partial) * get<3>(params.shape_LSE_partial);
        cutlass::FastDivmod seqlen_divmod_dynamic(seqlen);

        // Step 1: load LSE_partial from gmem -> smem
        Tensor mLSEpartial = make_tensor(make_gmem_ptr(params.ptr_LSE_partial + offset * get<0>(params.stride_LSE_partial)),
                                         select<1, 0, 2, 3>(params.shape_LSE_partial),
                                         select<1, 0, 2, 3>(params.stride_LSE_partial));  // (num_splits, seqlen, head, batch)
        Tensor mLSEpartial_copy = cute::tiled_divide(mLSEpartial, Shape<_1, Int<kGmemElemsPerLoadLSE>>{});
        GmemTiledCopyLSE gmem_tiled_copy_LSE;
        auto gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_thread_slice(thread_idx);
        Tensor tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE);

        // Construct identity layout for sLSE
        Tensor cLSE = make_identity_tensor(make_shape(size<0>(sLSE), size<1>(sLSE)));  // (NUM_SPLITS, BLK_M) -> (num_splits, blk_m)
        // Repeat the partitioning with identity layouts
        Tensor tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE);
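        // Each row mi of this block corresponds to the flattened index idx = m_block * kBlockM + mi
        // over (seqlen, head, batch). In the non-varlen case the divmods below invert
        // idx = (bidb * num_head + bidh) * seqlen + m_idx; with Varlen, the batch comes from
        // blockIdx.y and its sequence offset is already folded into the gmem pointer, so only
        // (m_idx, bidh) are recovered and bidb is set to 0.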
        #pragma unroll
        for (int m = 0; m < size<2>(tLSEcLSE); ++m) {
            int mi = int(get<1>(tLSEcLSE(_0{}, _0{}, m)));
            int idx = m_block * kBlockM + mi;
            if (idx < max_idx) {
                int m_idx, bidh, bidb;
                if constexpr (!Varlen) {
                    bidb = params.head_divmod.divmod(bidh, params.seqlen_divmod.divmod(m_idx, idx));
                } else {
                    bidh = seqlen_divmod_dynamic.divmod(m_idx, idx);
                    bidb = 0;
                }
                Tensor mLSEpartial_cur_copy = mLSEpartial_copy(_, _, m_idx, bidh, bidb);
                #pragma unroll
                for (int s = 0; s < size<1>(tLSEcLSE); ++s) {
                    int si = get<0>(tLSEcLSE(_0{}, s, _0{}));
                    // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && thread_idx < 32) { printf("thread_idx = %d, m = %d, s = %d, addr = %p, bank = %d\n", thread_idx, m, s, reinterpret_cast<float *>(&(tLSEsLSE(_0{}, s, m))), reinterpret_cast<int>(&(tLSEsLSE(_0{}, s, m))) / 4 % 32);}
                    if (si < num_splits) {
                        cute::copy(gmem_tiled_copy_LSE, mLSEpartial_cur_copy(_, si), tLSEsLSE(_, s, m));
                    } else {
                        cute::fill(tLSEsLSE(_, s, m), -INFINITY);
                    }
                }
            } else {
                // We don't need to zero out the rest of the LSEs, as we will not write the output to gmem
                // cute::fill(tLSEsLSE(_, _, m), -INFINITY);
            }
        }
        if constexpr (Has_cp_async) { cute::cp_async_fence(); }

        // Step 2: Load O_partial from gmem -> smem for split = 0, 1, ..., kStages - 2.
        // We want these async loads to be in flight as we compute the LSE.
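        // sO holds kStages buffers: splits 0, ..., kStages - 2 are prefetched here, and the
        // remaining splits are streamed in one stage ahead of the compute loop in Step 6, so the
        // cp.async traffic overlaps with the LSE reduction and the accumulation.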
        GmemTiledCopyAccum gmem_tiled_copy_O_partial;
        auto gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_thread_slice(thread_idx);
        // Construct identity layout for gO
        Tensor cO = cute::make_identity_tensor(TileShape_MK{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
        // Repeat the partitioning with identity layouts
        Tensor tOcO = gmem_thr_copy_O_partial.partition_D(cO);
        Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset * get<0>(params.stride_O_partial)),
                                       params.shape_O_partial, params.stride_O_partial);  // (seqlen, d, num_splits, head, batch)

        // Precompute these values to avoid recomputing them in the loop
        Tensor tOmidx = make_tensor<int>(make_shape(size<1>(tOcO)));
        Tensor tObidh = make_tensor<int>(make_shape(size<1>(tOcO)));
        Tensor tObidb = make_tensor<int>(make_shape(size<1>(tOcO)));
        Tensor tOrOptr = make_tensor<ElementPartial const*>(make_shape(size<1>(tOcO)));
        #pragma unroll
        for (int m = 0; m < size<1>(tOcO); ++m) {
            int mi = get<0>(tOcO(_0{}, m, _0{}));
            int idx = m_block * kBlockM + mi;
            if constexpr (!Varlen) {
                tObidb[m] = params.head_divmod.divmod(tObidh(m), params.seqlen_divmod.divmod(tOmidx(m), idx));
            } else {
                tObidh[m] = seqlen_divmod_dynamic.divmod(tOmidx(m), idx);
                tObidb[m] = 0;
            }
            tOrOptr[m] = &mOpartial(tOmidx(m), _0{}, _0{}, tObidh(m), tObidb(m));
            if (idx >= max_idx) {
                tObidb[m] = -1;
            }
        }

        Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOcO)));
        if constexpr (!(Is_even_K)) {
            #pragma unroll
            for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O_partial); }
        }

        Tensor tOsOpartial = gmem_thr_copy_O_partial.partition_D(sO);

        auto load_O_partial = [&] (int split, int stage) {
            Tensor tOsOpartial_cur = tOsOpartial(_, _, _, stage);
            #pragma unroll
            for (int m = 0; m < size<1>(tOcO); ++m) {
                if (tObidb(m) >= 0) {
                    Tensor mOpartial_cur = make_tensor(make_gmem_ptr(tOrOptr[m]), mOpartial(_0{}, _, _, _0{}, _0{}).layout());
                    Tensor mOpartial_cur_copy = cute::tiled_divide(mOpartial_cur, Shape<Int<kGmemElemsPerLoad>>{});
                    #pragma unroll
                    for (int k = 0; k < size<2>(tOcO); ++k) {
                        int k_idx = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerLoad;
                        if (Is_even_K || tOpO(k)) {
                            cute::copy(gmem_tiled_copy_O_partial, mOpartial_cur_copy(_, k_idx, split), tOsOpartial_cur(_, m, k));
                        }
                    }
                }
            }
        };

        for (int s = 0; s < kStages - 1; ++s) {
            if (s < num_splits) { load_O_partial(s, s); }
            if constexpr (Has_cp_async) { cute::cp_async_fence(); }
        }

        // Step 3: load and transpose LSE_partial from smem -> rmem
        if constexpr (Has_cp_async) { cutlass::arch::cp_async_wait<kStages - 1>(); }
        __syncthreads();
        S2RTiledCopyLSE s2r_tiled_copy_LSE;
        auto s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_thread_slice(thread_idx);
        Tensor ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE);
        Tensor ts2rrLSE = make_fragment_like(ts2rsLSE);
        cute::copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE);

        // Step 4: compute the final LSE along the split dimension
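        // For each row, the splits are combined with a numerically stable log-sum-exp:
        //     lse_final = lse_max + log( sum_s exp(lse_s - lse_max) )
        // and the per-split rescaling factors written back to smem below are
        //     scale_s = exp(lse_s - lse_max) / sum_s' exp(lse_s' - lse_max) = exp(lse_s - lse_final),
        // so that Step 6 can accumulate O_final = sum_s scale_s * O_partial_s.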
        Tensor lse_sum = make_tensor<float>(make_shape(size<2>(ts2rrLSE)));
        Tensor ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE);
        // We compute the max valid split for each row to short-circuit the computation later
        Tensor max_valid_split = make_tensor<int>(make_shape(size<2>(ts2rrLSE)));
        static_assert(CUTE_STATIC_V(size<0>(ts2rrLSE)) == 1);
        #pragma unroll
        for (int m = 0; m < size<2>(ts2rrLSE); ++m) {
            float lse_max = ts2rrLSE(_0{}, _0{}, m);
            #pragma unroll
            for (int s = 1; s < size<1>(ts2rrLSE); ++s) { lse_max = max(lse_max, ts2rrLSE(_0{}, s, m)); }
            MaxOp<float> max_op;
            lse_max = Allreduce<kSmemThreadsPerColLSEt>::run(lse_max, max_op);
            int max_valid_idx = -1;
            #pragma unroll
            for (int s = 0; s < size<1>(ts2rrLSE); ++s) {
                if (ts2rrLSE(_0{}, s, m) != -INFINITY) { max_valid_idx = get<0>(ts2rcLSE(_0{}, s, _0{})); }
            }
            MaxOp<int> max_int_op;
            max_valid_split[m] = Allreduce<kSmemThreadsPerColLSEt>::run(max_valid_idx, max_int_op);
            float lse_max_cur = lse_max == -INFINITY ? 0.0f : lse_max;  // In case all local LSEs are -inf
            float lse_sum_cur = 0.f;
            #pragma unroll
            for (int s = 0; s < size<1>(ts2rrLSE); ++s) {
                float scale = expf(ts2rrLSE(_0{}, s, m) - lse_max_cur);
                lse_sum_cur += scale;
                // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && thread_idx < 32) { printf("thread_idx = %d, m = %d, s = %d, addr = %p, bank = %d\n", thread_idx, m, s, reinterpret_cast<float *>(&(ts2rsLSE(_0{}, s, m))), reinterpret_cast<int>(&(ts2rsLSE(_0{}, s, m))) / 4 % 32);}
                // ts2rsLSE(_0{}, m, s) = scale;
                ts2rrLSE(_0{}, s, m) = scale;
            }
            SumOp<float> sum_op;
            lse_sum_cur = Allreduce<kSmemThreadsPerColLSEt>::run(lse_sum_cur, sum_op);
            lse_sum(m) = logf(lse_sum_cur) + lse_max;
            float inv_sum = (lse_sum_cur == 0.f || lse_sum_cur != lse_sum_cur) ? 0.f : 1.f / lse_sum_cur;
            #pragma unroll
            for (int s = 0; s < size<1>(ts2rrLSE); ++s) { ts2rrLSE(_0{}, s, m) *= inv_sum; }
        }
        // Store the scales exp(lse - lse_logsum) back to smem
        cute::copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE);

        // Step 5: store final LSE back to gmem
        auto shape_LSE = select<0, 2, 3>(params.shape_LSE_partial);
        Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset * get<0>(params.stride_LSE)), shape_LSE, params.stride_LSE);
        #pragma unroll
        for (int m = 0; m < size<2>(ts2rrLSE); ++m) {
            if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) {  // Only the thread responsible for s=0 writes to gmem
                int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m)));
                int idx = m_block * kBlockM + mi;
                if (idx < max_idx) {
                    int m_idx, bidh, bidb;
                    if constexpr (!Varlen) {
                        bidb = params.head_divmod.divmod(bidh, params.seqlen_divmod.divmod(m_idx, idx));
                    } else {
                        bidh = seqlen_divmod_dynamic.divmod(m_idx, idx);
                        bidb = 0;
                    }
                    // printf("thread_idx = %d, m = %d, mi = %d, idx = %d, m_idx = %d, bidh = %d, bidb = %d, lse_sum = %f\n", thread_idx, m, mi, idx, m_idx, bidh, bidb, lse_sum(m));
                    mLSE(m_idx, bidh, bidb) = lse_sum(m);
                }
                if (mi < kBlockM) { sMaxValidSplit[mi] = max_valid_split[m]; }
            }
        }

        // Step 6: read O_partial from gmem -> smem -> rmem and accumulate the final O
        __syncthreads();
        int thr_max_valid_split = sMaxValidSplit[get<0>(tOcO(_0{}, _0{}, _0{}))];
        #pragma unroll
        for (int m = 1; m < size<1>(tOcO); ++m) { thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[get<0>(tOcO(_0{}, m, _0{}))]); }
        Layout tOrOpartial_layout = gmem_thr_copy_O_partial.partition_S(make_tensor<ElementPartial>(TileShape_MK{})).layout();
        Tensor tOrOpartial = make_fragment_like<ElementPartial>(tOrOpartial_layout);
        Tensor tOrO = make_fragment_like<float>(tOrOpartial);
        clear(tOrO);
        int stage_load = kStages - 1, stage_compute = 0;
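        // stage_load / stage_compute walk the kStages smem buffers as a circular queue: iteration s
        // issues the gmem -> smem load of split s + kStages - 1 into stage_load, then consumes
        // split s from stage_compute, keeping the async copies kStages - 1 splits ahead of the
        // accumulation.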
        #pragma unroll 4  // Already tuned for speed
        for (int s = 0; s <= thr_max_valid_split; ++s) {
            Tensor scale = make_tensor<float>(make_shape(size<1>(tOrOpartial)));
            #pragma unroll
            for (int m = 0; m < size<1>(tOrOpartial); ++m) { scale(m) = sLSE(s, get<0>(tOcO(_0{}, m, _0{}))); }
            if (s + kStages - 1 <= thr_max_valid_split) { load_O_partial(s + kStages - 1, stage_load); }
            if constexpr (Has_cp_async) { cute::cp_async_fence(); }
            stage_load = stage_load < kStages - 1 ? stage_load + 1 : 0;
            if constexpr (Has_cp_async) { cutlass::arch::cp_async_wait<kStages - 1>(); }
            // We don't need __syncthreads() because each thread is just reading its own data from smem
            cute::copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementPartial>{},
                       tOsOpartial(_, _, _, stage_compute), tOrOpartial);
            stage_compute = stage_compute < kStages - 1 ? stage_compute + 1 : 0;
            #pragma unroll
            for (int m = 0; m < size<1>(tOrOpartial); ++m) {
                if (tObidb(m) >= 0 && scale(m) > 0.f) {
                    #pragma unroll
                    for (int k = 0; k < size<2>(tOrOpartial); ++k) {
                        if (Is_even_K || tOpO(k)) {
                            Tensor rOpartial = make_tensor_like<float>(tOrOpartial(_, m, k));
                            flash::convert_type_out(tOrOpartial(_, m, k), rOpartial);
                            #pragma unroll
                            for (int i = 0; i < size<0>(tOrOpartial); ++i) {
                                tOrO(i, m, k) += scale(m) * rOpartial[i];
                            }
                        }
                    }
                }
            }
        }

        // Step 7: Write the final O to gmem
        Tensor rO = make_tensor_like<Element>(tOrO);
        flash::convert_type_out(tOrO, rO);
        auto shape_O = select<0, 1, 3, 4>(params.shape_O_partial);
        Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset * get<0>(params.stride_O)), shape_O, params.stride_O);
        Tensor mO_copy = cute::tiled_divide(mO, Shape<_1, Int<kGmemElemsPerLoad>>{});
        GmemTiledCopy gmem_tiled_copy_O;
        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
        #pragma unroll
        for (int m = 0; m < size<1>(tOcO); ++m) {
            if (tObidb(m) >= 0) {
                #pragma unroll
                for (int k = 0; k < size<2>(tOcO); ++k) {
                    int k_idx = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerLoad;
                    if (Is_even_K || tOpO(k)) {
                        cute::copy(gmem_tiled_copy_O, rO(_, m, k), mO_copy(_, tOmidx(m), k_idx, tObidh(m), tObidb(m)));
                    }
                }
            }
        }
    }

};

} // namespace flash