// epilogue_fwd_sm90_tma.hpp
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/
  4. #pragma once
  5. #include <cutlass/cutlass.h>
  6. #include "cute/tensor.hpp"
  7. #include "cutlass/gemm/collective/collective_builder.hpp"
  8. #include "named_barrier.hpp"
  9. #include "utils.h"
  10. namespace flash {
  11. using namespace cute;
  12. // template <int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename Element_>
  13. template <typename Ktraits, typename Seqlen_traits>
  14. struct CollectiveEpilogueFwd {
  15. using InputType = typename Ktraits::Element;
  16. using Element = typename Ktraits::OutputType;
  17. static constexpr int kBlockM = Ktraits::kBlockM;
  18. static constexpr int kBlockN = Ktraits::kBlockN;
  19. static constexpr int kBlockH = Ktraits::kBlockH;
  20. static constexpr int kHeadDim = Ktraits::kHeadDim;
  21. using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
  22. static constexpr int kNWarps = Ktraits::kNWarps;
  23. static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
  24. static constexpr bool Is_WS = Ktraits::Is_WS;
  25. static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup;
  26. static constexpr int NumMmaThreads = kNThreads - NumCopyThreads;
  27. static constexpr bool Is_split = Ktraits::Is_split;
  28. static constexpr bool No_smem_O = Ktraits::No_smem_O;
  29. #ifndef NO_FP8_COLUMN_PERMUTE
  30. static constexpr bool epi_column_permute = is_same_v<InputType, cutlass::float_e4m3_t>;
  31. #else
  32. static constexpr bool epi_column_permute = false;
  33. #endif
  34. using GmemShapeOT = std::conditional_t<
  35. Is_split,
  36. typename Seqlen_traits::ShapeOAccumT,
  37. typename Seqlen_traits::ShapeT
  38. >;
  39. using GmemStrideOT = std::conditional_t<
  40. Is_split,
  41. typename Seqlen_traits::StrideOAccumT,
  42. typename Seqlen_traits::StrideT
  43. >;
  44. using GmemLayoutOT = std::conditional_t<
  45. Is_split,
  46. typename Seqlen_traits::LayoutOAccumT,
  47. typename Seqlen_traits::LayoutT
  48. >;
  49. using GmemLayoutLseT = std::conditional_t<
  50. Is_split,
  51. typename Seqlen_traits::LayoutLseAccumT,
  52. typename Seqlen_traits::LayoutLseT
  53. >;
  54. using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  55. decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  56. using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
  57. using SmemLayoutOCopy = typename Ktraits::SmemLayoutOCopy;
  58. using TileShapeOCopy = typename Ktraits::TileShapeOCopy;
  59. using SmemCopyAtomO = std::conditional_t<Is_split,
  60. Copy_Atom<UniversalCopy<Element>, Element>, Copy_Atom<cute::SM90_U32x4_STSM_N, Element>>;
  61. using SharedStorage = cute::array_aligned<Element, cute::cosize_v<SmemLayoutO>>;
  62. using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;
  63. using TMA_O = decltype(make_tma_copy(
  64. GmemTiledCopyOTMA{},
  65. make_tensor(
  66. make_gmem_ptr(static_cast<Element*>(nullptr)),
  67. GmemShapeOT{},
  68. GmemStrideOT{}
  69. ),
  70. SmemLayoutOCopy{},
  71. TileShapeOCopy{},
  72. _1{})); // no mcast for O
  73. // These are for storing the output tensor without TMA (e.g., for setting output to zero and var-seq-len)
  74. static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<Element>);
  75. static_assert(kHeadDim % kNumVecElem == 0);
  76. static constexpr int kNumThreadsPerRow = kHeadDim / kNumVecElem;
  77. static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
  78. static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
  79. using TiledCopyOAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, Element>;
  80. using TiledCopyOThrLayout = decltype(cute::make_layout(
  81. cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
  82. LayoutRight{}));
  83. using TiledCopyOValLayout = decltype(cute::make_layout(
  84. cute::make_shape(_1{}, Int<kNumVecElem>{}),
  85. LayoutRight{}));
  86. using TiledCopyO = decltype(make_tiled_copy(
  87. TiledCopyOAtom{},
  88. TiledCopyOThrLayout{}, // Thr layout
  89. TiledCopyOValLayout{} // Val layout
  90. ));
  91. // used for rmem -> smem O copy in fp8 kernel to undo column permutation
  92. using ThreadLayoutrO = Layout<Shape<_8, Int<kBlockM/16>, _4, _1>,
  93. Stride<_4, _32, _1, _0>>;
  94. using ValueLayoutrO = Layout<Shape<_1, _2, Shape<_2, _2>, Int<kHeadDim/16>>,
  95. Stride<_0, _2, Stride<_4, _1>, _8>>;
  96. using TiledCopyrO = decltype(make_tiled_copy(Copy_Atom<UniversalCopy<Element>, Element>{},
  97. ThreadLayoutrO{}, ValueLayoutrO{}));
  98. using TiledCopyShaperO = Shape<_8, Int<kBlockM/8>, _16, Int<kHeadDim/16>>;
  99. using SmemLayoutrO = decltype(composition(SmemLayoutO{}, Layout<TiledCopyShaperO>{}));
  100. // Host side kernel arguments
  101. struct Arguments {
  102. Element* ptr_O;
  103. GmemLayoutOT const layout_O;
  104. float* ptr_LSE;
  105. GmemLayoutLseT const layout_LSE;
  106. };
  107. // Device side kernel params
  108. struct Params {
  109. Element* ptr_O;
  110. GmemLayoutOT const layout_O;
  111. float* ptr_LSE;
  112. GmemLayoutLseT const layout_LSE;
  113. TMA_O tma_store_O;
  114. };
  115. static Params
  116. to_underlying_arguments(Arguments const& args) {
  117. Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.layout_O);
  118. TMA_O tma_store_O = make_tma_copy(
  119. GmemTiledCopyOTMA{},
  120. mO,
  121. SmemLayoutOCopy{},
  122. TileShapeOCopy{},
  123. _1{}); // no mcast for O
  124. return {args.ptr_O, args.layout_O, args.ptr_LSE, args.layout_LSE, tma_store_O};
  125. }
  126. /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
  127. CUTLASS_DEVICE
  128. static void prefetch_tma_descriptors(Params const& epilogue_params) {
  129. if constexpr (!Seqlen_traits::UseVarSeqLen && !No_smem_O) {
  130. cute::prefetch_tma_descriptor(epilogue_params.tma_store_O.get_tma_descriptor());
  131. }
  132. }
  133. template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE, typename TiledMma>
  134. CUTLASS_DEVICE void
  135. store(Params const& epilogue_params,
  136. FrgTensorO const& tOrO,
  137. FrgTensorLSE const& lse,
  138. SharedStorage& shared_storage,
  139. TiledMma tiled_mma,
  140. int thread_idx,
  141. cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord,
  142. const Seqlen_traits& seqlen_traits_q,
  143. const cutlass::FastDivmod& qhead_per_khead_divmod
  144. ) {
  145. auto [m_block, n_split_idx, bidh, bidb] = block_coord;
  146. const int bidh_kv = qhead_per_khead_divmod.divide(bidh);
  147. const int h_block = bidh % int(qhead_per_khead_divmod);
  148. Tensor tOrO_out = flash::convert_type<Element>(tOrO);
  149. if constexpr(!No_smem_O) {
  150. if constexpr (!epi_column_permute) {
  151. Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
  152. auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
  153. auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
  154. Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N)
  155. Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
  156. // Make sure all WGs have finished reading V
  157. cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(FwdNamedBarriers::ValueEmpty) /*id*/);
  158. cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
  159. cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
  160. cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp,
  161. cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  162. } else {
  163. TiledCopyrO rmem_tiled_copy_O;
  164. Tensor sOacc = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutrO{});
  165. auto rmem_thr_copy_O = rmem_tiled_copy_O.get_thread_slice(thread_idx);
  166. Tensor taccOsO = rmem_thr_copy_O.partition_D(sOacc);
  167. Tensor taccOrO = make_tensor(tOrO_out.data(), shape(taccOsO));
  168. // Make sure all WGs have finished reading V
  169. cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(FwdNamedBarriers::ValueEmpty) /*id*/);
  170. cute::copy(rmem_tiled_copy_O, taccOrO, taccOsO);
  171. cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
  172. cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp,
  173. cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
  174. }
  175. }
  176. Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), epilogue_params.layout_LSE);
  177. Tensor caccO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}));
  178. auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
  179. Tensor taccOcO = thread_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
  180. static_assert(decltype(size<0, 0>(taccOcO))::value == 2);
  181. static_assert(decltype(size<0, 1>(taccOcO))::value == 2);
  182. // taccOcO has shape ((2, 2, V), MMA_M, MMA_K), we only take only the row indices.
  183. Tensor taccOcO_row = taccOcO(make_coord(_0{}, _, _0{}), _, _0{});
  184. CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // 2 * MMA_M
  185. if constexpr(!Seqlen_traits::UseGQAPacking) {
  186. Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor<Is_split>(
  187. mLSE, Shape<Int<kBlockM>>{}, bidh, bidb, n_split_idx)(_, m_block);
  188. if (get<1>(taccOcO_row(_0{})) == 0) {
  189. #pragma unroll
  190. for (int mi = 0; mi < size(lse); ++mi) {
  191. const int row = get<0>(taccOcO_row(mi));
  192. if (row < seqlen_traits_q.actual_seq_len - m_block * kBlockM) {
  193. gLSE(row) = lse(mi);
  194. }
  195. }
  196. }
  197. } else {
  198. // shape<1>(epilogue_params.layout_O) == h/h_k
  199. // In common case where ceil_div(h/h_k, kBlockH) == 1,
  200. // int(qhead_per_khead_divmod) == 1, bidh_kv == bidh, h_block == 0
  201. const int h_offset = shape<1>(epilogue_params.layout_O) * bidh_kv +
  202. h_block * kBlockH;
  203. const int m_bound = seqlen_traits_q.actual_seq_len - m_block * (kBlockM/kBlockH);
  204. const int h_bound = shape<1>(epilogue_params.layout_O) - h_block * kBlockH;
  205. #pragma unroll
  206. for (int mi = 0; mi < size(lse); ++mi) {
  207. const int row = get<0>(taccOcO_row(mi));
  208. const int h_local = row % kBlockH;
  209. const int m_local = row/kBlockH;
  210. if(h_local < h_bound && m_local < m_bound) {
  211. Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor<Is_split>(mLSE,
  212. Shape<Int<kBlockM/kBlockH>>{}, h_offset + h_local, bidb, n_split_idx)
  213. (_, m_block);
  214. gLSE(m_local) = lse(mi);
  215. }
  216. }
  217. }
  218. if constexpr (No_smem_O) {
  219. flash::write_rmem_to_gmem<Seqlen_traits::UseGQAPacking, epi_column_permute>(
  220. tOrO_out, epilogue_params.ptr_O, epilogue_params.layout_O, TileShapeOCopy{},
  221. m_block, h_block, bidh, bidh_kv, bidb, n_split_idx,
  222. tiled_mma, seqlen_traits_q, thread_idx);
  223. } else {
  224. int write_warp_idx = kNWarps - 1;
  225. if (cutlass::canonical_warp_idx_sync() == write_warp_idx) {
  226. cutlass::arch::NamedBarrier::sync(
  227. NumMmaThreads + cutlass::NumThreadsPerWarp,
  228. cutlass::arch::ReservedNamedBarriers::EpilogueBarrier
  229. );
  230. }
  231. TiledCopyO gmem_tiled_copy_O;
  232. Tensor sO_out = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutOCopy{});
  233. if constexpr(!Seqlen_traits::UseGQAPacking) {
  234. flash::write_O<!Seqlen_traits::UseVarSeqLen, No_smem_O, Is_split, NumCopyThreads>(
  235. epilogue_params.ptr_O, epilogue_params.tma_store_O, gmem_tiled_copy_O,
  236. epilogue_params.layout_O, TileShapeOCopy{}, sO_out,
  237. m_block, bidh, bidb, n_split_idx, seqlen_traits_q, write_warp_idx, tiled_mma, tOrO_out
  238. );
  239. } else {
  240. Tensor mO = epilogue_params.tma_store_O.get_tma_tensor(epilogue_params.layout_O.shape());
  241. Tensor gO = seqlen_traits_q.get_o_local_tile_tensor<Is_split>(
  242. mO, TileShapeOCopy{}, bidh_kv, bidb, n_split_idx)
  243. (_, _, _, m_block, h_block); // (bM/bH, bH, K)
  244. auto block_tma_O = epilogue_params.tma_store_O.get_slice(_0{});
  245. Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K)
  246. Tensor tOsO = block_tma_O.partition_S(sO_out); // (TMA, TMA_M, TMA_K)
  247. int const lane_predicate = cute::elect_one_sync();
  248. int const warp_idx = cutlass::canonical_warp_idx_sync();
  249. if (warp_idx == write_warp_idx && lane_predicate) {
  250. cute::copy(epilogue_params.tma_store_O, tOsO, tOgO);
  251. tma_store_arrive();
  252. }
  253. }
  254. }
  255. }
  256. CUTLASS_DEVICE void
  257. store_tail() {
  258. if constexpr(!No_smem_O) { tma_store_wait<0>(); }
  259. }
  260. // Write 0 to output and -inf to LSE
  261. template<typename SharedStorage>
  262. CUTLASS_DEVICE void
  263. store_zero(
  264. Params const& epilogue_params,
  265. SharedStorage& shared_storage,
  266. int thread_idx,
  267. cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord,
  268. const Seqlen_traits& seqlen_traits_q
  269. ) {
  270. static_assert(!Seqlen_traits::UseGQAPacking, "Don't call store_zero for gqa packed layouts.");
  271. auto [m_block, n_split_idx, bidh, bidb] = block_coord;
  272. if constexpr(!Is_split) {
  273. Tensor mO = make_tensor(make_gmem_ptr(epilogue_params.ptr_O), epilogue_params.layout_O);
  274. Tensor gO = seqlen_traits_q.get_o_local_tile_tensor<Is_split>(
  275. mO, select<0, 2>(TileShape_MNK{}), bidh, bidb, n_split_idx
  276. )(_, _, m_block); // (M, K)
  277. TiledCopyO gmem_tiled_copy_O;
  278. auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
  279. Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
  280. Tensor tOrO = make_fragment_like(tOgO);
  281. clear(tOrO);
  282. // Construct identity layout for sO
  283. Tensor cO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k)
  284. // Repeat the partitioning with identity layouts
  285. Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
  286. Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
  287. #pragma unroll
  288. for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(epilogue_params.layout_O.shape()); }
  289. // Clear_OOB_K must be false since we don't want to write zeros to gmem
  290. flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
  291. gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_traits_q.actual_seq_len - m_block * kBlockM
  292. );
  293. }
  294. Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), epilogue_params.layout_LSE);
  295. Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor<Is_split>(
  296. mLSE, Shape<Int<kBlockM>>{}, bidh, bidb, n_split_idx)(_, m_block);
  297. static_assert(kBlockM <= NumMmaThreads);
  298. if (thread_idx < min(kBlockM, seqlen_traits_q.actual_seq_len - m_block * kBlockM)) {
  299. gLSE(thread_idx) = !Is_split ? INFINITY : -INFINITY;
  300. }
  301. }
  302. // Write 0 to output and -inf to LSE
  303. template<typename SharedStorage>
  304. CUTLASS_DEVICE void
  305. store_zero_gqa(
  306. Params const& epilogue_params,
  307. SharedStorage& shared_storage,
  308. int thread_idx,
  309. cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord,
  310. const Seqlen_traits& seqlen_traits_q,
  311. const cutlass::FastDivmod& qhead_per_khead_divmod
  312. ) {
  313. static_assert(Seqlen_traits::UseGQAPacking, "Special store_zero method for GQA packed layouts.");
  314. auto [m_block, n_split_idx, bidh, bidb] = block_coord;
  315. const int bidh_kv = qhead_per_khead_divmod.divide(bidh);
  316. const int h_block = bidh % int(qhead_per_khead_divmod);
  317. const int h_bound = min(shape<1>(epilogue_params.layout_O) - h_block * kBlockH, kBlockH);
  318. const int m_bound = min(seqlen_traits_q.actual_seq_len - m_block * (kBlockM/kBlockH), kBlockM/kBlockH);
  319. if constexpr(!Is_split) {
  320. Tensor mO = make_tensor(make_gmem_ptr(epilogue_params.ptr_O), epilogue_params.layout_O);
  321. Tensor gO = seqlen_traits_q.get_o_local_tile_tensor<Is_split>(
  322. mO, TileShapeOCopy{}, bidh_kv, bidb, n_split_idx)
  323. (_, _, _, m_block, h_block); // (bM/bH, bH, K)
  324. TiledCopyO gmem_tiled_copy_O;
  325. auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
  326. if constexpr(kNumRows <= kBlockH) {
  327. // slice into bM/bH and write out zero tiles (bH, K)
  328. Tensor tOgO = gmem_thr_copy_O.partition_D(gO(0,_,_));
  329. Tensor tOrO = make_fragment_like(tOgO);
  330. clear(tOrO);
  331. Tensor cO = cute::make_identity_tensor(select<1, 2>(TileShapeOCopy{}));
  332. Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
  333. // dummy predicate, unused since Is_even_K=true
  334. Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
  335. #pragma unroll
  336. for(int m = 0; m < m_bound; ++m) {
  337. tOgO = gmem_thr_copy_O.partition_D(gO(m,_,_));
  338. flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true,
  339. /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
  340. gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, h_bound
  341. );
  342. }
  343. } else {
  344. // slice into bH and write out zero tiles (bM/bH, K)
  345. Tensor tOgO = gmem_thr_copy_O.partition_D(gO(_,0,_));
  346. Tensor tOrO = make_fragment_like(tOgO);
  347. clear(tOrO);
  348. Tensor cO = cute::make_identity_tensor(select<0, 2>(TileShapeOCopy{}));
  349. Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
  350. // dummy predicate, unused since Is_even_K=true
  351. Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
  352. #pragma unroll
  353. for(int h = 0; h < h_bound; ++h) {
  354. tOgO = gmem_thr_copy_O.partition_D(gO(_,h,_));
  355. flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true,
  356. /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
  357. gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, m_bound
  358. );
  359. }
  360. }
  361. }
  362. const int h_offset = shape<1>(epilogue_params.layout_O) * bidh_kv + h_block * kBlockH;
  363. const int thread_idx_h = thread_idx % kBlockH;
  364. const int thread_idx_m = thread_idx / kBlockH;
  365. Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), epilogue_params.layout_LSE);
  366. Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor<Is_split>(
  367. mLSE, Shape<Int<kBlockM/kBlockH>>{}, h_offset + thread_idx_h, bidb, n_split_idx)(_, m_block);
  368. if(thread_idx_h < h_bound && thread_idx_m < m_bound) {
  369. gLSE(thread_idx_m) = !Is_split ? INFINITY : -INFINITY;
  370. }
  371. }
  372. };
  373. } // namespace flash