/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

#include <cutlass/cutlass.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include "cutlass/pipeline/pipeline.hpp"

#include "flash.h"
#include "utils.h"
#include "softmax.h"
#include "tile_scheduler.hpp"
#include "mainloop_fwd_sm90_tma_gmma_ws.hpp"
#include "epilogue_fwd_sm90_tma.hpp"
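
// Warp-specialized forward attention kernels for SM90 (Hopper): a producer
// warpgroup loads Q/K/V tiles with TMA while the remaining warpgroups run the
// WGMMA mainloop, softmax, and epilogue. compute_attn_ws handles fp16/bf16
// inputs; compute_attn_ws_fp8 adds the extra machinery needed for fp8 inputs.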

namespace flash {

using namespace cute;

template <typename Ktraits, bool Is_causal, bool Is_local, typename TileScheduler, typename Seqlen_traits, typename Seqlen_traits_Q = Seqlen_traits>
__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
compute_attn_ws(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits, Is_causal, Is_local, Seqlen_traits, Seqlen_traits_Q>::Params const mainloop_params,
                CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd<Ktraits, Seqlen_traits_Q>::Params const epilogue_params,
                CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params,
                Seqlen_traits_Q seqlen_traits_q, Seqlen_traits seqlen_traits_k
                ) {
    using Element = typename Ktraits::Element;
    using TileShape_MNK = typename Ktraits::TileShape_MNK;
    using ClusterShape = typename Ktraits::ClusterShape_MNK;

    static_assert(Ktraits::Is_WS);
    static constexpr bool Is_WS = Ktraits::Is_WS;
    static constexpr bool No_smem_O = Ktraits::No_smem_O;

    static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
    static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup;
    static constexpr int kBlockM = Ktraits::kBlockM;
    static constexpr int kBlockH = Ktraits::kBlockH;
    // static constexpr int kBlockN = Ktraits::kBlockN;
    // static constexpr int kHeadDim = Ktraits::kHeadDim;

    using CollectiveMainloop = CollectiveMainloopFwd<Ktraits, Is_causal, Is_local, Seqlen_traits, Seqlen_traits_Q>;
    using CollectiveEpilogue = CollectiveEpilogueFwd<Ktraits, Seqlen_traits_Q>;

    using MainloopPipeline = typename Ktraits::MainloopPipeline;
    using PipelineParams = typename MainloopPipeline::Params;
    using PipelineState = typename MainloopPipeline::PipelineState;
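
    // Thread organization: warpgroup 0 (NumCopyThreads = 128 threads) is the
    // producer that issues the TMA loads; the remaining warpgroups (NumMmaThreads
    // threads in total) are consumers that run the WGMMA mainloop, softmax, and
    // epilogue.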

    extern __shared__ char shared_memory[];
    auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);

    int const lane_predicate = cute::elect_one_sync();
    int const warp_idx = cutlass::canonical_warp_idx_sync();

    // Issue Tma Descriptor Prefetch from a single thread
    if (warp_idx == 0 && lane_predicate) {
        CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
        CollectiveEpilogue::prefetch_tma_descriptors(epilogue_params);
    }
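    // (Prefetching the descriptors from a single elected thread hides the latency
    // of fetching them from global memory before the first TMA copy uses them.)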

    // Obtain this thread's index within its warp group
    int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;

    PipelineParams pipeline_params;
    pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
    int warp_group_idx = cutlass::canonical_warp_group_idx();
    pipeline_params.role = warp_group_idx == 0
        ? MainloopPipeline::ThreadCategory::Producer
        : MainloopPipeline::ThreadCategory::Consumer;
    pipeline_params.is_leader = warp_group_thread_idx == 0;
    pipeline_params.num_consumers = NumMmaThreads;

    if (warp_idx == 0 && lane_predicate) {
        shared_storage.barrier_Q.init(1 /*numThreads*/);
        if constexpr (!No_smem_O) {
            shared_storage.barrier_O.init(size(ClusterShape{}) /*numThreads*/);
            // if constexpr(seqlen_traits_q.UseVarSeqLen) {
            //     shared_storage.barrier_O.init(NumMmaThreads /*numThreads*/);
            // } else {
            //     shared_storage.barrier_O.init(size(ClusterShape{}) /*numThreads*/);
            // }
        }
    }
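    // barrier_Q is a transaction barrier completed by the TMA load of Q (hence a
    // single arrival). barrier_O is only needed when O is staged through shared
    // memory; it expects one arrival per CTA in the cluster and is used to know
    // when the output smem buffer from the previous work tile can be reused.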

    // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init();
    MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{});
    MainloopPipeline pipeline_v(shared_storage.pipeline_v, pipeline_params, ClusterShape{});

    CollectiveMainloop collective_mainloop;
    CollectiveEpilogue collective_epilogue;

    // We need this to guarantee that the pipeline init is visible to all producer and consumer blocks in the cluster.
    if constexpr (size(ClusterShape{}) > 1) {
        cute::cluster_arrive_relaxed();
        cute::cluster_wait();
    } else {
        __syncthreads();
    }

    // static_assert(Ktraits::kNWarps == 12 || Ktraits::kNWarps == 16);
    static_assert(Ktraits::kNWarps == 8 || Ktraits::kNWarps == 12 || Ktraits::kNWarps == 16);
    if (warp_group_idx == 0) {  // Producer
        cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 12 ? 24 : 32>();
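        // The producer warpgroup gives most of its registers back (down to 24 or 32
        // per thread) so the consumer warpgroups can claim a larger budget via
        // warpgroup_reg_alloc below; issuing TMA copies needs very few registers.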

        int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
        if (warp_idx_in_warpgroup == 0) {  // Load Q, K, V
            PipelineState smem_pipe_write_k = cutlass::make_producer_start_state<MainloopPipeline>();
            PipelineState smem_pipe_write_v = cutlass::make_producer_start_state<MainloopPipeline>();

            int work_idx = 0;

            TileScheduler scheduler(&shared_storage.tile_count_semaphore);
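            // Persistent producer loop: each iteration claims a work tile
            // (m_block, n_split_idx, bidh, bidb), skips it if it is empty, and
            // otherwise issues the TMA loads of Q and of the K/V tiles in
            // [n_block_min, n_block_max) through pipeline_k / pipeline_v.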
            for (auto work_tile_info = scheduler.get_initial_work();
                 work_tile_info.is_valid(scheduler_params);
                 work_tile_info = scheduler.template get_next_work</*IsProducer=*/true>(scheduler_params, work_tile_info)) {
                auto block_coord = work_tile_info.get_block_coord(scheduler_params);
                auto [m_block, n_split_idx, bidh, bidb] = block_coord;

                seqlen_traits_q.init(bidb);
                seqlen_traits_k.init(bidb);
                if constexpr(seqlen_traits_q.UseVarSeqLen) {
                    // NOTE: to support GQA packed layouts in the future, kBlockM was changed to kBlockM/kBlockH here.
                    if (m_block * (kBlockM/kBlockH) >= seqlen_traits_q.actual_seq_len) {
                        scheduler.prefetch_next_work(scheduler_params, work_tile_info);
                        scheduler.broadcast_next_work(work_tile_info);
                        continue;
                    }
                }
                int n_block_min = 0, n_block_max;
                collective_mainloop.get_n_block_min_max(
                    mainloop_params, m_block, n_split_idx, seqlen_traits_q, seqlen_traits_k,
                    n_block_min, n_block_max);
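                // [n_block_min, n_block_max) is the range of K/V column blocks this
                // query block has to visit, taking into account causal/local masking,
                // the actual (variable) sequence length, and the split index when
                // split-KV is enabled.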
                if constexpr (Is_causal || Is_local || seqlen_traits_k.UseVarSeqLen || Ktraits::Is_split) {
                    if(n_block_max <= n_block_min) {
                        scheduler.prefetch_next_work(scheduler_params, work_tile_info);
                        scheduler.broadcast_next_work(work_tile_info);
                        continue;
                    }
                }
                collective_mainloop.load(
                    mainloop_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v,
                    shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx,
                    seqlen_traits_q, seqlen_traits_k, n_block_min, n_block_max);
                ++work_idx;
            }
            collective_mainloop.load_tail(pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v);
        }
    } else {  // Consumer
        cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 16 ? 160 : Ktraits::kNWarps == 12 ? 240 : 256>();

        TileScheduler scheduler(&shared_storage.tile_count_semaphore);

        // Initialize matmul objects.
        typename Ktraits::TiledMma1 tiled_mma1;

        PipelineState smem_pipe_read_k, smem_pipe_read_v;
        // We don't need separate variables smem_pipe_release_k and smem_pipe_release_v
        // (like in Cutlass's gemm) because the read and release pipeline states are always the same.

        collective_mainloop.mma_init();
        scheduler.init_consumer();

        int work_idx = 0;
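
        // Persistent consumer loop: for each work tile, allocate the O accumulator
        // fragment and softmax state, run the softmax-rescaled QK^T / PV mainloop over
        // the K/V blocks, then have the epilogue write O (and LSE) out.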
        CUTLASS_PRAGMA_NO_UNROLL
        for (auto work_tile_info = scheduler.get_initial_work();
             work_tile_info.is_valid(scheduler_params);
             work_tile_info = scheduler.template get_next_work</*IsProducer=*/false>(scheduler_params, work_tile_info)) {
            // Attention output (GEMM-II) accumulator.
            Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{}));
            flash::Softmax<2 * (2 * kBlockM / NumMmaThreads)> softmax(mainloop_params.softmax_scale_log2);
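            // Per-thread softmax state (running row max and row sum); its size matches
            // the number of accumulator rows each MMA thread owns for this kBlockM tile.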

            auto block_coord = work_tile_info.get_block_coord(scheduler_params);
            auto [m_block, n_split_idx, bidh, bidb] = block_coord;

            seqlen_traits_q.init(bidb);
            seqlen_traits_k.init(bidb);
            if constexpr(seqlen_traits_q.UseVarSeqLen) {
                // NOTE: to support GQA packed layouts in the future, kBlockM was changed to kBlockM/kBlockH here.
                if (m_block * (kBlockM/kBlockH) >= seqlen_traits_q.actual_seq_len) {
                    continue;
                }
            }
            int n_block_max, n_block_min = 0;
            collective_mainloop.get_n_block_min_max(
                mainloop_params, m_block, n_split_idx, seqlen_traits_q, seqlen_traits_k,
                n_block_min, n_block_max);
            if constexpr (Is_causal || Is_local || seqlen_traits_k.UseVarSeqLen || Ktraits::Is_split) {
                if(n_block_max <= n_block_min) {  // We exit early and write 0 to gO and -inf to gLSE.
                    if constexpr(!Seqlen_traits_Q::UseGQAPacking) {
                        collective_epilogue.store_zero(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads,
                                                       block_coord, seqlen_traits_q);
                    } else {
                        collective_epilogue.store_zero_gqa(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads,
                                                           block_coord, seqlen_traits_q, mainloop_params.qhead_per_khead_divmod);
                    }
                    continue;
                }
            }

            collective_mainloop.mma(
                mainloop_params, pipeline_k, pipeline_v, smem_pipe_read_k, smem_pipe_read_v,
                tOrO, softmax, n_block_min, n_block_max, threadIdx.x - NumCopyThreads, work_idx,
                m_block, shared_storage, seqlen_traits_q, seqlen_traits_k);
                // tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads + (work_idx >> 30), work_idx, shared_storage);
            collective_epilogue.store(
                epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1,
                threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q, mainloop_params.qhead_per_khead_divmod);

            // if constexpr(!No_smem_O && seqlen_traits_q.UseVarSeqLen) { shared_storage.barrier_O.arrive(); }
            if constexpr(!No_smem_O && seqlen_traits_q.UseVarSeqLen) {
                cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(FwdNamedBarriers::OutputEmpty) /*id*/);
            }

            ++work_idx;
        }
        collective_epilogue.store_tail();
    }
}
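
// A minimal launch sketch for compute_attn_ws, for orientation only: the real
// host-side launch lives elsewhere in the repo, and the grid-size helper named
// below is a hypothetical stand-in.
//
//     auto kernel = &flash::compute_attn_ws<Ktraits, Is_causal, Is_local, TileScheduler, Seqlen_traits>;
//     size_t smem_size = sizeof(typename Ktraits::SharedStorage);
//     if (smem_size >= 48 * 1024) {
//         cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
//     }
//     auto cluster_shape = typename Ktraits::ClusterShape_MNK{};
//     dim3 block(Ktraits::kNWarps * cutlass::NumThreadsPerWarp);  // must match __launch_bounds__
//     dim3 cluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape));
//     dim3 grid = /* chosen by the tile scheduler: e.g. one CTA per SM for the persistent
//                    scheduler, or one per (m_block, head, batch) tile otherwise */;
//     cutlass::ClusterLaunchParams launch_params{grid, block, cluster, smem_size};
//     cutlass::launch_kernel_on_cluster(launch_params, (void const*) kernel,
//                                       mainloop_params, epilogue_params, scheduler_params,
//                                       seqlen_traits_q, seqlen_traits_k);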

template <typename Ktraits, bool Is_causal, bool Is_local, typename TileScheduler, typename Seqlen_traits, typename Seqlen_traits_Q = Seqlen_traits>
__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
compute_attn_ws_fp8(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits, Is_causal, Is_local, Seqlen_traits, Seqlen_traits_Q>::Params const mainloop_params,
                    CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd<Ktraits, Seqlen_traits_Q>::Params const epilogue_params,
                    CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params,
                    Seqlen_traits_Q seqlen_traits_q, Seqlen_traits seqlen_traits_k
                    ) {
    using Element = typename Ktraits::Element;
    static_assert(cutlass::sizeof_bits_v<Element> == 8);
    using TileShape_MNK = typename Ktraits::TileShape_MNK;
    using ClusterShape = typename Ktraits::ClusterShape_MNK;

    static_assert(Ktraits::Is_WS);
    static constexpr bool Is_WS = Ktraits::Is_WS;
    static constexpr bool No_smem_O = Ktraits::No_smem_O;

    static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
    static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup;
    static constexpr int kBlockM = Ktraits::kBlockM;
    static constexpr int kBlockH = Ktraits::kBlockH;
    // static constexpr int kBlockN = Ktraits::kBlockN;
    // static constexpr int kHeadDim = Ktraits::kHeadDim;

    static constexpr bool Delay_V_release = Is_causal && Ktraits::kHeadDim == 128 && Ktraits::kNWarps != 8;
    static constexpr bool Use_max_offset = true;
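    // FP8-specific knobs: Delay_V_release is a performance heuristic that holds on to
    // V's pipeline stage a little longer for the causal, hdim-128 configurations, and
    // Use_max_offset selects the Softmax variant (see softmax.h) that applies a fixed
    // offset to the running max so the exponentials make better use of FP8's limited
    // dynamic range.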

    using CollectiveMainloop = CollectiveMainloopFwd<Ktraits, Is_causal, Is_local, Seqlen_traits, Seqlen_traits_Q>;
    using CollectiveEpilogue = CollectiveEpilogueFwd<Ktraits, Seqlen_traits_Q>;

    using MainloopPipeline = typename Ktraits::MainloopPipeline;
    using MainloopPipelineVt = typename Ktraits::MainloopPipelineNoTMA;
    using PipelineParams = typename MainloopPipeline::Params;
    using PipelineParamsVt = typename MainloopPipelineVt::Params;
    using PipelineState = typename MainloopPipeline::PipelineState;

    extern __shared__ char shared_memory[];
    auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);

    int const lane_predicate = cute::elect_one_sync();
    int const warp_idx = cutlass::canonical_warp_idx_sync();

    // Issue Tma Descriptor Prefetch from a single thread
    if (warp_idx == 0 && lane_predicate) {
        CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
        CollectiveEpilogue::prefetch_tma_descriptors(epilogue_params);
    }

    // Obtain this thread's index within its warp group
    int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;

    // Additional pipeline to synchronize the out-of-place smem transpose of V
    PipelineParamsVt pipeline_params_vt;
    pipeline_params_vt.producer_arv_count = NumCopyThreads;
    pipeline_params_vt.consumer_arv_count = NumMmaThreads;
    MainloopPipelineVt pipeline_vt(shared_storage.pipeline_vt, pipeline_params_vt);
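    // With FP8 inputs, V arrives from TMA in the same layout as K, but the second
    // WGMMA needs it transposed (FP8 WGMMA only takes K-major operands from smem),
    // so the producer warpgroup transposes V into a separate smem buffer (Vt).
    // pipeline_vt is a software pipeline (no TMA transaction bytes) that hands the
    // transposed tiles from the producer threads to the MMA threads.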

    PipelineParams pipeline_params;
    pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
    int warp_group_idx = cutlass::canonical_warp_group_idx();
    pipeline_params.role = warp_group_idx == 0
        ? MainloopPipeline::ThreadCategory::Producer
        : MainloopPipeline::ThreadCategory::Consumer;
    pipeline_params.is_leader = warp_group_thread_idx == 0;
    pipeline_params.num_consumers = NumMmaThreads;

    if constexpr(seqlen_traits_q.UseVarSeqLen || seqlen_traits_k.UseVarSeqLen) {
        static_assert(size(ClusterShape{}) == 1, "Clusters must be disabled for varlen.");
    }

    if (warp_idx == 0 && lane_predicate) {
        shared_storage.barrier_Q.init(1 /*numThreads*/);
        if constexpr (!No_smem_O) {
            shared_storage.barrier_O.init(size(ClusterShape{}) /*numThreads*/);
            // if constexpr(seqlen_traits_q.UseVarSeqLen) {
            //     shared_storage.barrier_O.init(Ktraits::kNWarps - 4 /*numMmaWarps*/);
            //     shared_storage.barrier_O.init(NumMmaThreads /*numThreads*/);
            // } else {
            //     shared_storage.barrier_O.init(size(ClusterShape{}) /*numThreads*/);
            // }
        }
    }

    // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init();
    MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{});
    // In the fp8 kernel, the producer warpgroup is also the consumer of pipeline_v
    // (it reads V back out of smem to transpose it).
    pipeline_params.num_consumers = NumCopyThreads;
    pipeline_params.role = MainloopPipeline::ThreadCategory::ProducerConsumer;
    MainloopPipeline pipeline_v(shared_storage.pipeline_v, pipeline_params, ClusterShape{});

    CollectiveMainloop collective_mainloop;
    CollectiveEpilogue collective_epilogue;

    float descale_q = *mainloop_params.descale_q_ptr;
    float descale_k = *mainloop_params.descale_k_ptr;
    float descale_v = *mainloop_params.descale_v_ptr;
    shared_storage.softmax_scale_qk_log2 = mainloop_params.softmax_scale_log2 * descale_q * descale_k;
    shared_storage.descale_v = descale_v;
    shared_storage.seqlen_init_k = seqlen_traits_k.UseVarSeqLen || bool(seqlen_traits_k.seq_used);
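    // The FP8 descale factors are folded in up front: descale_q * descale_k is merged
    // into the log2 softmax scale, so one multiply both dequantizes the QK^T accumulator
    // and applies the softmax scaling, while descale_v is stashed in smem and applied
    // later when the PV accumulator is rescaled. seqlen_init_k records whether K's
    // per-batch sequence length has to be re-initialized for every work tile.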

    // We need this to guarantee that the pipeline init is visible to all producer and consumer blocks in the cluster.
    if constexpr (size(ClusterShape{}) > 1) {
        cute::cluster_arrive_relaxed();
        cute::cluster_wait();
    } else {
        __syncthreads();
    }

    static_assert(Ktraits::kNWarps == 8 || Ktraits::kNWarps == 12 || Ktraits::kNWarps == 16);
    if (warp_group_idx == 0) {  // Producer
        cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 16 ? 32 : Ktraits::kNWarps == 12 ? 40 : 56>();

        PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipeline>();
        PipelineState smem_pipe_read, smem_pipe_release;

        int work_idx = 0;

        TileScheduler scheduler(&shared_storage.tile_count_semaphore);
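        // Same persistent producer loop as the bf16/fp16 kernel, except that the whole
        // producer warpgroup participates (it also performs the in-smem V transpose
        // inside load_fp8), so skipped tiles must re-synchronize the warpgroup on the
        // ProducerWG named barrier before moving on.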
        for (auto work_tile_info = scheduler.get_initial_work();
             work_tile_info.is_valid(scheduler_params);
             work_tile_info = scheduler.template get_next_work</*IsProducer=*/true>(scheduler_params, work_tile_info)) {
            // cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumCopyThreads, static_cast<int>(FwdNamedBarriers::NextWorkTile) /*id*/);
            auto block_coord = work_tile_info.get_block_coord(scheduler_params);
            auto [m_block, n_split_idx, bidh, bidb] = block_coord;

            if constexpr (seqlen_traits_q.UseVarSeqLen) { seqlen_traits_q.init(bidb); }
            if constexpr (seqlen_traits_k.UseVarSeqLen) { seqlen_traits_k.init(bidb); }
            else if (shared_storage.seqlen_init_k) { seqlen_traits_k.init_no_guard(bidb); }

            if constexpr(seqlen_traits_q.UseVarSeqLen) {
                // NOTE: to support GQA packed layouts in the future, kBlockM was changed to kBlockM/kBlockH here.
                if (m_block * (kBlockM/kBlockH) >= seqlen_traits_q.actual_seq_len) {
                    scheduler.prefetch_next_work(scheduler_params, work_tile_info);
                    scheduler.broadcast_next_work(work_tile_info);
                    // need to sync producer warpgroup
                    cutlass::arch::NamedBarrier::sync(NumCopyThreads, static_cast<int>(FwdNamedBarriers::ProducerWG) /*id*/);
                    continue;
                }
            }

            int n_block_min = 0, n_block_max;
            collective_mainloop.get_n_block_min_max(
                mainloop_params, m_block, n_split_idx, seqlen_traits_q, seqlen_traits_k,
                n_block_min, n_block_max);
            if constexpr (Is_causal || Is_local || seqlen_traits_k.UseVarSeqLen || Ktraits::Is_split) {
                if(n_block_max <= n_block_min) {
                    scheduler.prefetch_next_work(scheduler_params, work_tile_info);
                    scheduler.broadcast_next_work(work_tile_info);
                    // need to sync producer warpgroup
                    cutlass::arch::NamedBarrier::sync(NumCopyThreads, static_cast<int>(FwdNamedBarriers::ProducerWG) /*id*/);
                    continue;
                }
            }
            collective_mainloop.load_fp8(
                mainloop_params, pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, smem_pipe_read,
                shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx,
                seqlen_traits_q, seqlen_traits_k, n_block_min, n_block_max);
            ++work_idx;
            // don't need to sync producer warpgroup here
            // if constexpr (Is_causal) {
            //     cutlass::arch::NamedBarrier::sync(NumCopyThreads, static_cast<int>(FwdNamedBarriers::ProducerWG) /*id*/); }
            // cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumCopyThreads, static_cast<int>(FwdNamedBarriers::OutputEmpty) /*id*/);
        }
        collective_mainloop.load_tail_one_write(pipeline_k, pipeline_v, smem_pipe_write);
    } else {  // Consumer
        cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 16 ? 160 : Ktraits::kNWarps == 12 ? 232 : 256>();

        TileScheduler scheduler(&shared_storage.tile_count_semaphore);

        // Initialize matmul objects.
        typename Ktraits::TiledMma1 tiled_mma1;

        PipelineState smem_pipe_read;
        PipelineState smem_pipe_release;

        collective_mainloop.mma_init();
        // if constexpr(!No_smem_O && seqlen_traits_q.UseVarSeqLen) {
        //     cutlass::arch::NamedBarrier::arrive(NumMmaThreads + Ktraits::NumProducerThreads, static_cast<int>(FwdNamedBarriers::OutputEmpty) /*id*/);
        // }
        scheduler.init_consumer();

        int work_idx = 0;
        CUTLASS_PRAGMA_NO_UNROLL
        for (auto work_tile_info = scheduler.get_initial_work();
             work_tile_info.is_valid(scheduler_params);
             work_tile_info = scheduler.template get_next_work</*IsProducer=*/false>(scheduler_params, work_tile_info)) {
            // cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumCopyThreads, static_cast<int>(FwdNamedBarriers::NextWorkTile) /*id*/);
            // Attention output (GEMM-II) accumulator.
            Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{}));
            flash::Softmax<2 * (2 * kBlockM / NumMmaThreads), Use_max_offset> softmax(shared_storage.softmax_scale_qk_log2);

            auto block_coord = work_tile_info.get_block_coord(scheduler_params);
            auto [m_block, n_split_idx, bidh, bidb] = block_coord;

            if constexpr (seqlen_traits_q.UseVarSeqLen) { seqlen_traits_q.init(bidb); }
            if constexpr (seqlen_traits_k.UseVarSeqLen) { seqlen_traits_k.init(bidb); }
            else if (shared_storage.seqlen_init_k) { seqlen_traits_k.init_no_guard(bidb); }

            if constexpr(seqlen_traits_q.UseVarSeqLen) {
                // NOTE: to support GQA packed layouts in the future, kBlockM was changed to kBlockM/kBlockH here.
                if (m_block * (kBlockM/kBlockH) >= seqlen_traits_q.actual_seq_len) {
                    continue;
                }
            }

            int n_block_max, n_block_min = 0;
            collective_mainloop.get_n_block_min_max(
                mainloop_params, m_block, n_split_idx, seqlen_traits_q, seqlen_traits_k,
                n_block_min, n_block_max);
            if constexpr (Is_causal || Is_local || seqlen_traits_k.UseVarSeqLen || Ktraits::Is_split) {
                if(n_block_max <= n_block_min) {  // We exit early and write 0 to gO and -inf to gLSE.
                    if constexpr(!Seqlen_traits_Q::UseGQAPacking) {
                        collective_epilogue.store_zero(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads,
                                                       block_coord, seqlen_traits_q);
                    } else {
                        collective_epilogue.store_zero_gqa(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads,
                                                           block_coord, seqlen_traits_q, mainloop_params.qhead_per_khead_divmod);
                    }
                    continue;
                }
            }

            collective_mainloop.mma_fp8<Delay_V_release>(
                mainloop_params, pipeline_k, pipeline_vt, smem_pipe_read, smem_pipe_release,
                tOrO, softmax, n_block_min, n_block_max, threadIdx.x - NumCopyThreads, work_idx, m_block,
                shared_storage, seqlen_traits_q, seqlen_traits_k);

            collective_epilogue.store(
                epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1,
                threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q, mainloop_params.qhead_per_khead_divmod);

            // if constexpr(!No_smem_O && seqlen_traits_q.UseVarSeqLen) {
            //     int lane_predicate = cute::elect_one_sync();
            //     if(lane_predicate) { shared_storage.barrier_O.arrive(0 /*cta_id*/, lane_predicate); }
            //     shared_storage.barrier_O.arrive();
            //     cutlass::arch::NamedBarrier::arrive(NumMmaThreads + Ktraits::NumProducerThreads, static_cast<int>(FwdNamedBarriers::OutputEmpty) /*id*/);
            //     cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(FwdNamedBarriers::OutputEmpty) /*id*/);
            // }

            ++work_idx;
            // cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumCopyThreads, static_cast<int>(FwdNamedBarriers::OutputEmpty) /*id*/);
        }
        collective_epilogue.store_tail();
    }
}

}  // namespace flash