1
0

flash_fwd_kernel.h 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include "cute/tensor.hpp"
  6. #include <cutlass/cutlass.h>
  7. #include <cutlass/arch/reg_reconfig.h>
  8. #include <cutlass/array.h>
  9. #include <cutlass/numeric_types.h>
  10. #include <cutlass/numeric_conversion.h>
  11. #include "cutlass/pipeline/pipeline.hpp"
  12. #include "flash.h"
  13. #include "utils.h"
  14. #include "softmax.h"
  15. #include "tile_scheduler.hpp"
  16. #include "mainloop_fwd_sm90_tma_gmma_ws.hpp"
  17. #include "epilogue_fwd_sm90_tma.hpp"
  18. namespace flash {
  19. using namespace cute;
  20. template <typename Ktraits, bool Is_causal, bool Is_local, typename TileScheduler, typename Seqlen_traits, typename Seqlen_traits_Q = Seqlen_traits>
  21. __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
  22. compute_attn_ws(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits, Is_causal, Is_local, Seqlen_traits, Seqlen_traits_Q>::Params const mainloop_params,
  23. CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd<Ktraits, Seqlen_traits_Q>::Params const epilogue_params,
  24. CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params,
  25. Seqlen_traits_Q seqlen_traits_q, Seqlen_traits seqlen_traits_k
  26. ) {
  27. using Element = typename Ktraits::Element;
  28. using TileShape_MNK = typename Ktraits::TileShape_MNK;
  29. using ClusterShape = typename Ktraits::ClusterShape_MNK;
  30. static_assert(Ktraits::Is_WS);
  31. static constexpr bool Is_WS = Ktraits::Is_WS;
  32. static constexpr bool No_smem_O = Ktraits::No_smem_O;
  33. static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
  34. static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup;
  35. static constexpr int kBlockM = Ktraits::kBlockM;
  36. static constexpr int kBlockH = Ktraits::kBlockH;
  37. // static constexpr int kBlockN = Ktraits::kBlockN;
  38. // static constexpr int kHeadDim = Ktraits::kHeadDim;
  39. using CollectiveMainloop = CollectiveMainloopFwd<Ktraits, Is_causal, Is_local, Seqlen_traits, Seqlen_traits_Q>;
  40. using CollectiveEpilogue = CollectiveEpilogueFwd<Ktraits, Seqlen_traits_Q>;
  41. using MainloopPipeline = typename Ktraits::MainloopPipeline;
  42. using PipelineParams = typename MainloopPipeline::Params;
  43. using PipelineState = typename MainloopPipeline::PipelineState;
  44. extern __shared__ char shared_memory[];
  45. auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
  46. int const lane_predicate = cute::elect_one_sync();
  47. int const warp_idx = cutlass::canonical_warp_idx_sync();
  48. // Issue Tma Descriptor Prefetch from a single thread
  49. if (warp_idx == 0 && lane_predicate) {
  50. CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
  51. CollectiveEpilogue::prefetch_tma_descriptors(epilogue_params);
  52. }
  53. // Obtain warp index
  54. int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
  55. PipelineParams pipeline_params;
  56. pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
  57. int warp_group_idx = cutlass::canonical_warp_group_idx();
  58. pipeline_params.role = warp_group_idx == 0
  59. ? MainloopPipeline::ThreadCategory::Producer
  60. : MainloopPipeline::ThreadCategory::Consumer;
  61. pipeline_params.is_leader = warp_group_thread_idx == 0;
  62. pipeline_params.num_consumers = NumMmaThreads;
  63. if (warp_idx == 0 && lane_predicate) {
  64. shared_storage.barrier_Q.init(1 /*numThreads*/);
  65. if constexpr (!No_smem_O) { shared_storage.barrier_O.init(size(ClusterShape{}) /*numThreads*/); }
  66. }
  67. // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init();
  68. MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{});
  69. MainloopPipeline pipeline_v(shared_storage.pipeline_v, pipeline_params, ClusterShape{});
  70. CollectiveMainloop collective_mainloop;
  71. CollectiveEpilogue collective_epilogue;
  72. // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster
  73. if constexpr (size(ClusterShape{}) > 1) {
  74. cute::cluster_arrive_relaxed();
  75. cute::cluster_wait();
  76. } else {
  77. __syncthreads();
  78. }
  79. // static_assert(Ktraits::kNWarps == 12 || Ktraits::kNWarps == 16);
  80. static_assert(Ktraits::kNWarps == 8 || Ktraits::kNWarps == 12 || Ktraits::kNWarps == 16);
  81. if (warp_group_idx == 0) { // Producer
  82. cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 12 ? 24 : 32>();
  83. int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
  84. if (warp_idx_in_warpgroup == 0) { // Load Q, K, V
  85. PipelineState smem_pipe_write_k = cutlass::make_producer_start_state<MainloopPipeline>();
  86. PipelineState smem_pipe_write_v = cutlass::make_producer_start_state<MainloopPipeline>();
  87. int work_idx = 0;
  88. TileScheduler scheduler(&shared_storage.tile_count_semaphore);
  89. for (auto work_tile_info = scheduler.get_initial_work();
  90. work_tile_info.is_valid(scheduler_params);
  91. work_tile_info = scheduler.template get_next_work</*IsProducer=*/true>(scheduler_params, work_tile_info)) {
  92. auto block_coord = work_tile_info.get_block_coord(scheduler_params);
  93. auto [m_block, n_split_idx, bidh, bidb] = block_coord;
  94. seqlen_traits_q.init(bidb);
  95. seqlen_traits_k.init(bidb);
  96. if constexpr(seqlen_traits_q.UseVarSeqLen) {
  97. // NOTE: to support in future with gqa packed layouts, changed kBlockM to kBlockM/kBlockH
  98. if (m_block * (kBlockM/kBlockH) >= seqlen_traits_q.actual_seq_len) {
  99. continue;
  100. }
  101. }
  102. int n_block_min = 0, n_block_max;
  103. collective_mainloop.get_n_block_min_max(
  104. mainloop_params, m_block, n_split_idx, seqlen_traits_q, seqlen_traits_k,
  105. n_block_min, n_block_max);
  106. if constexpr (Is_causal || Is_local || seqlen_traits_k.UseVarSeqLen || Ktraits::Is_split) {
  107. if(n_block_max <= n_block_min) {
  108. scheduler.prefetch_next_work(scheduler_params, work_tile_info);
  109. scheduler.broadcast_next_work(work_tile_info);
  110. continue;
  111. }
  112. }
  113. collective_mainloop.load(
  114. mainloop_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v,
  115. shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx,
  116. seqlen_traits_q, seqlen_traits_k, n_block_min, n_block_max);
  117. ++work_idx;
  118. }
  119. collective_mainloop.load_tail(pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v);
  120. }
  121. } else { // Consumer
  122. cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 16 ? 160 : Ktraits::kNWarps == 12 ? 240 : 256>();
  123. TileScheduler scheduler(&shared_storage.tile_count_semaphore);
  124. // Initialize matmul objects.
  125. typename Ktraits::TiledMma1 tiled_mma1;
  126. PipelineState smem_pipe_read_k, smem_pipe_read_v;
  127. // We don't need separate variables smem_pipe_release_k and smem_pipe_release_v
  128. // (like in Cutlass's gemm) because the read and release pipeline states are always the same.
  129. collective_mainloop.mma_init();
  130. scheduler.init_consumer();
  131. int work_idx = 0;
  132. CUTLASS_PRAGMA_NO_UNROLL
  133. for (auto work_tile_info = scheduler.get_initial_work();
  134. work_tile_info.is_valid(scheduler_params);
  135. work_tile_info = scheduler.template get_next_work</*IsProducer=*/false>(scheduler_params, work_tile_info)) {
  136. // Attention output (GEMM-II) accumulator.
  137. Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{}));
  138. flash::Softmax<2 * (2 * kBlockM / NumMmaThreads)> softmax(mainloop_params.softmax_scale_log2);
  139. auto block_coord = work_tile_info.get_block_coord(scheduler_params);
  140. auto [m_block, n_split_idx, bidh, bidb] = block_coord;
  141. seqlen_traits_q.init(bidb);
  142. seqlen_traits_k.init(bidb);
  143. if constexpr(seqlen_traits_q.UseVarSeqLen) {
  144. // NOTE: to support in future with gqa packed layouts, changed kBlockM to kBlockM/kBlockH
  145. if (m_block * (kBlockM/kBlockH) >= seqlen_traits_q.actual_seq_len) {
  146. continue;
  147. }
  148. }
  149. int n_block_max, n_block_min = 0;
  150. collective_mainloop.get_n_block_min_max(
  151. mainloop_params, m_block, n_split_idx, seqlen_traits_q, seqlen_traits_k,
  152. n_block_min, n_block_max);
  153. if constexpr (Is_causal || Is_local || seqlen_traits_k.UseVarSeqLen || Ktraits::Is_split) {
  154. if(n_block_max <= n_block_min) { // We exit early and write 0 to gO and -inf to gLSE.
  155. if constexpr(!Seqlen_traits_Q::UseGQAPacking) {
  156. collective_epilogue.store_zero(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads,
  157. block_coord, seqlen_traits_q);
  158. } else {
  159. collective_epilogue.store_zero_gqa(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads,
  160. block_coord, seqlen_traits_q, mainloop_params.qhead_per_khead_divmod);
  161. }
  162. continue;
  163. }
  164. }
  165. collective_mainloop.mma(
  166. mainloop_params, pipeline_k, pipeline_v, smem_pipe_read_k, smem_pipe_read_v,
  167. tOrO, softmax, n_block_min, n_block_max, threadIdx.x - NumCopyThreads, work_idx,
  168. m_block, shared_storage, seqlen_traits_q, seqlen_traits_k);
  169. // tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads + (work_idx >> 30), work_idx, shared_storage);
  170. collective_epilogue.store(
  171. epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1,
  172. threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q, mainloop_params.qhead_per_khead_divmod);
  173. ++work_idx;
  174. }
  175. collective_epilogue.store_tail();
  176. }
  177. }
  178. template <typename Ktraits, bool Is_causal, bool Is_local, typename TileScheduler, typename Seqlen_traits, typename Seqlen_traits_Q = Seqlen_traits>
  179. __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
  180. compute_attn_ws_fp8(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits, Is_causal, Is_local, Seqlen_traits, Seqlen_traits_Q>::Params const mainloop_params,
  181. CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd<Ktraits, Seqlen_traits_Q>::Params const epilogue_params,
  182. CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params,
  183. Seqlen_traits_Q seqlen_traits_q, Seqlen_traits seqlen_traits_k
  184. ) {
  185. using Element = typename Ktraits::Element;
  186. static_assert(cutlass::sizeof_bits_v<Element> == 8);
  187. using TileShape_MNK = typename Ktraits::TileShape_MNK;
  188. using ClusterShape = typename Ktraits::ClusterShape_MNK;
  189. static_assert(Ktraits::Is_WS);
  190. static constexpr bool Is_WS = Ktraits::Is_WS;
  191. static constexpr bool No_smem_O = Ktraits::No_smem_O;
  192. static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
  193. static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup;
  194. static constexpr int kBlockM = Ktraits::kBlockM;
  195. static constexpr int kBlockH = Ktraits::kBlockH;
  196. // static constexpr int kBlockN = Ktraits::kBlockN;
  197. // static constexpr int kHeadDim = Ktraits::kHeadDim;
  198. static constexpr bool Delay_V_release = Is_causal && Ktraits::kHeadDim == 128 && Ktraits::kNWarps != 8;
  199. static constexpr bool Use_max_offset = true;
  200. using CollectiveMainloop = CollectiveMainloopFwd<Ktraits, Is_causal, Is_local, Seqlen_traits, Seqlen_traits_Q>;
  201. using CollectiveEpilogue = CollectiveEpilogueFwd<Ktraits, Seqlen_traits_Q>;
  202. using MainloopPipeline = typename Ktraits::MainloopPipeline;
  203. using MainloopPipelineVt = typename Ktraits::MainloopPipelineNoTMA;
  204. using PipelineParams = typename MainloopPipeline::Params;
  205. using PipelineParamsVt = typename MainloopPipelineVt::Params;
  206. using PipelineState = typename MainloopPipeline::PipelineState;
  207. extern __shared__ char shared_memory[];
  208. auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
  209. int const lane_predicate = cute::elect_one_sync();
  210. int const warp_idx = cutlass::canonical_warp_idx_sync();
  211. // Issue Tma Descriptor Prefetch from a single thread
  212. if (warp_idx == 0 && lane_predicate) {
  213. CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
  214. CollectiveEpilogue::prefetch_tma_descriptors(epilogue_params);
  215. }
  216. // Obtain warp index
  217. int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
  218. // additional pipeline to synchronize out-of-place smem transpose of V
  219. PipelineParamsVt pipeline_params_vt;
  220. pipeline_params_vt.producer_arv_count = NumCopyThreads;
  221. pipeline_params_vt.consumer_arv_count = NumMmaThreads;
  222. MainloopPipelineVt pipeline_vt(shared_storage.pipeline_vt, pipeline_params_vt);
  223. PipelineParams pipeline_params;
  224. pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
  225. int warp_group_idx = cutlass::canonical_warp_group_idx();
  226. pipeline_params.role = warp_group_idx == 0
  227. ? MainloopPipeline::ThreadCategory::Producer
  228. : MainloopPipeline::ThreadCategory::Consumer;
  229. pipeline_params.is_leader = warp_group_thread_idx == 0;
  230. pipeline_params.num_consumers = NumMmaThreads;
  231. if (warp_idx == 0 && lane_predicate) {
  232. shared_storage.barrier_Q.init(1 /*numThreads*/);
  233. if constexpr (!No_smem_O) { shared_storage.barrier_O.init(size(ClusterShape{}) /*numThreads*/); }
  234. }
  235. // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init();
  236. MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{});
  237. // pipeline_v has producer warpgroup for its consumer in fp8 kernel
  238. pipeline_params.num_consumers = NumCopyThreads;
  239. pipeline_params.role = MainloopPipeline::ThreadCategory::ProducerConsumer;
  240. MainloopPipeline pipeline_v(shared_storage.pipeline_v, pipeline_params, ClusterShape{});
  241. CollectiveMainloop collective_mainloop;
  242. CollectiveEpilogue collective_epilogue;
  243. float descale_q = *mainloop_params.descale_q_ptr;
  244. float descale_k = *mainloop_params.descale_k_ptr;
  245. float descale_v = *mainloop_params.descale_v_ptr;
  246. shared_storage.softmax_scale_qk_log2 = mainloop_params.softmax_scale_log2 * descale_q * descale_k;
  247. shared_storage.descale_v = descale_v;
  248. shared_storage.seqlen_init_k = seqlen_traits_k.UseVarSeqLen || bool(seqlen_traits_k.seq_used);
  249. // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster
  250. if constexpr (size(ClusterShape{}) > 1) {
  251. cute::cluster_arrive_relaxed();
  252. cute::cluster_wait();
  253. } else {
  254. __syncthreads();
  255. }
  256. static_assert(Ktraits::kNWarps == 8 || Ktraits::kNWarps == 12 || Ktraits::kNWarps == 16);
  257. if (warp_group_idx == 0) { // Producer
  258. cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 16 ? 32 : Ktraits::kNWarps == 12 ? 40 : 56>();
  259. PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipeline>();
  260. PipelineState smem_pipe_read, smem_pipe_release;
  261. int work_idx = 0;
  262. TileScheduler scheduler(&shared_storage.tile_count_semaphore);
  263. for (auto work_tile_info = scheduler.get_initial_work();
  264. work_tile_info.is_valid(scheduler_params);
  265. work_tile_info = scheduler.template get_next_work</*IsProducer=*/true>(scheduler_params, work_tile_info)) {
  266. auto block_coord = work_tile_info.get_block_coord(scheduler_params);
  267. auto [m_block, n_split_idx, bidh, bidb] = block_coord;
  268. if constexpr (seqlen_traits_q.UseVarSeqLen) { seqlen_traits_q.init(bidb); }
  269. if (shared_storage.seqlen_init_k) { seqlen_traits_k.init_no_guard(bidb); }
  270. if constexpr(seqlen_traits_q.UseVarSeqLen) {
  271. // NOTE: to support in future with gqa packed layout, changed kBlockM to kBlockM/kBlockH
  272. if (m_block * (kBlockM/kBlockH) >= seqlen_traits_q.actual_seq_len) {
  273. continue;
  274. }
  275. }
  276. int n_block_min = 0, n_block_max;
  277. collective_mainloop.get_n_block_min_max(
  278. mainloop_params, m_block, n_split_idx, seqlen_traits_q, seqlen_traits_k,
  279. n_block_min, n_block_max);
  280. if constexpr (Is_causal || Is_local ||seqlen_traits_k.UseVarSeqLen || Ktraits::Is_split) {
  281. if(n_block_max <= n_block_min) {
  282. scheduler.prefetch_next_work(scheduler_params, work_tile_info);
  283. scheduler.broadcast_next_work(work_tile_info);
  284. // need to sync producer warpgroup
  285. cutlass::arch::NamedBarrier::sync(NumCopyThreads, static_cast<int>(FwdNamedBarriers::ProducerWG) /*id*/);
  286. continue;
  287. }
  288. }
  289. collective_mainloop.load_fp8(
  290. mainloop_params, pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, smem_pipe_read,
  291. shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx,
  292. seqlen_traits_q, seqlen_traits_k, n_block_min, n_block_max);
  293. ++work_idx;
  294. // don't need to sync producer warpgroup here
  295. // if constexpr (Is_causal) {
  296. // cutlass::arch::NamedBarrier::sync(NumCopyThreads, static_cast<int>(FwdNamedBarriers::ProducerWG) /*id*/); }
  297. }
  298. collective_mainloop.load_tail_one_write(pipeline_k, pipeline_v, smem_pipe_write);
  299. } else { // Consumer
  300. cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 16 ? 160 : Ktraits::kNWarps == 12 ? 232 : 256>();
  301. TileScheduler scheduler(&shared_storage.tile_count_semaphore);
  302. // Initialize matmul objects.
  303. typename Ktraits::TiledMma1 tiled_mma1;
  304. PipelineState smem_pipe_read;
  305. PipelineState smem_pipe_release;
  306. collective_mainloop.mma_init();
  307. scheduler.init_consumer();
  308. int work_idx = 0;
  309. CUTLASS_PRAGMA_NO_UNROLL
  310. for (auto work_tile_info = scheduler.get_initial_work();
  311. work_tile_info.is_valid(scheduler_params);
  312. work_tile_info = scheduler.template get_next_work</*IsProducer=*/false>(scheduler_params, work_tile_info)) {
  313. // Attention output (GEMM-II) accumulator.
  314. Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{}));
  315. flash::Softmax<2 * (2 * kBlockM / NumMmaThreads), Use_max_offset> softmax(shared_storage.softmax_scale_qk_log2);
  316. auto block_coord = work_tile_info.get_block_coord(scheduler_params);
  317. auto [m_block, n_split_idx, bidh, bidb] = block_coord;
  318. if constexpr (seqlen_traits_q.UseVarSeqLen) { seqlen_traits_q.init(bidb); }
  319. if (shared_storage.seqlen_init_k) { seqlen_traits_k.init_no_guard(bidb); }
  320. if constexpr(seqlen_traits_q.UseVarSeqLen) {
  321. // NOTE: to support in future with gqa packed layout, changed kBlockM to kBlockM/kBlockH
  322. if (m_block * (kBlockM/kBlockH) >= seqlen_traits_q.actual_seq_len) {
  323. continue;
  324. }
  325. }
  326. int n_block_max, n_block_min = 0;
  327. collective_mainloop.get_n_block_min_max(
  328. mainloop_params, m_block, n_split_idx, seqlen_traits_q, seqlen_traits_k,
  329. n_block_min, n_block_max);
  330. if constexpr (Is_causal || Is_local || seqlen_traits_k.UseVarSeqLen || Ktraits::Is_split) {
  331. if(n_block_max <= n_block_min) { // We exit early and write 0 to gO and -inf to gLSE.
  332. if constexpr(!Seqlen_traits_Q::UseGQAPacking) {
  333. collective_epilogue.store_zero(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads,
  334. block_coord, seqlen_traits_q);
  335. } else {
  336. collective_epilogue.store_zero_gqa(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads,
  337. block_coord, seqlen_traits_q, mainloop_params.qhead_per_khead_divmod);
  338. }
  339. continue;
  340. }
  341. }
  342. collective_mainloop.mma_fp8<Delay_V_release>(
  343. mainloop_params, pipeline_k, pipeline_vt, smem_pipe_read, smem_pipe_release,
  344. tOrO, softmax, n_block_min, n_block_max, threadIdx.x - NumCopyThreads, work_idx, m_block,
  345. shared_storage, seqlen_traits_q, seqlen_traits_k);
  346. collective_epilogue.store(
  347. epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1,
  348. threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q, mainloop_params.qhead_per_khead_divmod);
  349. ++work_idx;
  350. }
  351. collective_epilogue.store_tail();
  352. }
  353. }
  354. } // namespace flash