flash_fwd_kernel.h 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include "cute/tensor.hpp"
  6. #include <cutlass/cutlass.h>
  7. #include <cutlass/arch/reg_reconfig.h>
  8. #include <cutlass/array.h>
  9. #include <cutlass/numeric_types.h>
  10. #include <cutlass/numeric_conversion.h>
  11. #include <cutlass/kernel_hardware_info.h>
  12. #include "cutlass/pipeline/pipeline.hpp"
  13. #include "seqlen.h"
  14. #include "utils.h"
  15. #include "softmax.h"
  16. namespace flash {
  17. using namespace cute;
  18. template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
  19. class FlashAttnFwd {
  20. public:
  21. // Type Aliases
  22. using CollectiveMainloop = CollectiveMainloop_;
  23. using CollectiveEpilogue = CollectiveEpilogue_;
  24. static constexpr bool Is_causal = CollectiveMainloop::Is_causal;
  25. static constexpr bool Is_local = CollectiveMainloop::Is_local;
  26. static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen);
  27. static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap;
  28. static constexpr bool Varlen = CollectiveMainloop::Varlen;
  29. static constexpr bool PagedKV = CollectiveMainloop::PagedKV;
  30. static constexpr bool Split = CollectiveMainloop::Split;
  31. static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8;
  32. static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V;
  33. static constexpr bool AppendKV = CollectiveMainloop::AppendKV;
  34. static constexpr bool Use_TMA_Q = CollectiveMainloop::Use_TMA_Q;
  35. static constexpr bool Use_TMA_KV = CollectiveMainloop::Use_TMA_KV;
  36. static constexpr bool Use_TMA_O = CollectiveEpilogue::Use_TMA_O;
  37. static constexpr bool PackGQA = CollectiveMainloop::PackGQA;
  38. static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads;
  39. using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t;
  40. // Mainloop derived types
  41. using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;
  42. using TiledMma0 = typename CollectiveMainloop::TiledMma0;
  43. using TiledMma1 = typename CollectiveMainloop::TiledMma1;
  44. using ArchTag = typename CollectiveMainloop::ArchTag;
  45. using ClusterShape = typename CollectiveMainloop::ClusterShape;
  46. using MainloopArguments = typename CollectiveMainloop::Arguments;
  47. using MainloopParams = typename CollectiveMainloop::Params;
  48. using BarrierQ = std::conditional_t<Use_TMA_Q, cutlass::arch::ClusterTransactionBarrier, cutlass::arch::ClusterBarrier>;
  49. // Epilogue derived types
  50. using EpilogueArguments = typename CollectiveEpilogue::Arguments;
  51. using EpilogueParams = typename CollectiveEpilogue::Params;
  52. static_assert(ArchTag::kMinComputeCapability >= 90);
  53. using TileScheduler = TileScheduler_;
  54. using TileSchedulerArguments = typename flash::TileSchedulerArguments;
  55. using TileSchedulerParams = typename TileScheduler::Params;
  56. static constexpr uint32_t NumLoadWarpGroups = 1;
  57. static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMma0{})) / cutlass::NumThreadsPerWarpGroup;
  58. static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma0{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup);
  59. static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
  60. static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3);
  61. /// Register requirement for Load and Math WGs
  62. // If we use cp.async to load K and V, we need more registers for the producer WG.
  63. static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 1 ? 56 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 24 : 40) : 32);
  64. static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 1 ? 256 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 240 : 232) : 160);
  65. // If you want to print from the producer warp, you'd need to increase the number of registers
  66. // Otherwise you'll get CUDA error.
  67. // static constexpr uint32_t LoadRegisterRequirement = 40;
  68. // static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 232 : 152;
  69. // Kernel level shared memory storage
  70. // We overlap the shared memory for the mainloop and epilogue. However, we only want smem_o to overlap with smem_v
  71. // and nothing else, so we'll pad in case sizeof(smem_o) > sizeof(smem_v).
  72. static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage)) - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v)));
  73. static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 0 : mainloop_smem_padding_;
  74. struct SharedStorage {
  75. struct TensorStorage : cute::aligned_struct<128> {
  76. union {
  77. struct {
  78. cute::array<uint32_t, mainloop_smem_padding / sizeof(uint32_t)> padding_;
  79. typename CollectiveMainloop::TensorStorage mainloop;
  80. };
  81. // We want smem_o to line up with the start of smem_v
  82. typename CollectiveEpilogue::TensorStorage epilogue;
  83. };
  84. } tensors;
  85. struct PipelineStorage : cute::aligned_struct<16> {
  86. alignas(16) BarrierQ barrier_Q;
  87. alignas(16) cutlass::arch::ClusterBarrier barrier_O;
  88. alignas(16) typename CollectiveMainloop::MainloopPipelineK::SharedStorage pipeline_k;
  89. alignas(16) typename CollectiveMainloop::MainloopPipelineV::SharedStorage pipeline_v;
  90. alignas(16) typename CollectiveMainloop::MainloopPipelineVt::SharedStorage pipeline_vt;
  91. alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_k_new;
  92. alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_v_new;
  93. alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
  94. } pipelines;
  95. };
  96. static constexpr int SharedStorageSize = sizeof(SharedStorage);
  97. // Device side arguments
  98. struct Arguments {
  99. MainloopArguments mainloop{};
  100. EpilogueArguments epilogue{};
  101. cutlass::KernelHardwareInfo hw_info{};
  102. TileSchedulerArguments scheduler{};
  103. };
  104. // Kernel entry point API
  105. struct Params {
  106. MainloopParams mainloop{};
  107. EpilogueParams epilogue{};
  108. cutlass::KernelHardwareInfo hw_info{};
  109. TileSchedulerParams scheduler{};
  110. };
  111. //
  112. // Methods
  113. //
  114. // Convert to underlying arguments. In this case, a simple copy for the aliased type.
  115. static
  116. Params
  117. to_underlying_arguments(Arguments const& args) {
  118. CUTLASS_TRACE_HOST("to_underlying_arguments():");
  119. // Get SM count if needed, otherwise use user supplied SM count
  120. int sm_count = args.hw_info.sm_count;
  121. if (sm_count <= 0) {
  122. CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
  123. " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
  124. sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
  125. }
  126. CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
  127. cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
  128. return {
  129. CollectiveMainloop::to_underlying_arguments(args.mainloop),
  130. CollectiveEpilogue::to_underlying_arguments(args.epilogue),
  131. hw_info,
  132. TileScheduler::to_underlying_arguments(args.scheduler)
  133. };
  134. }
  135. // Computes the kernel launch grid shape based on runtime parameters
  136. static dim3
  137. get_grid_shape(Params const& params) {
  138. return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);
  139. }
  140. static dim3
  141. get_block_shape() {
  142. return dim3(MaxThreadsPerBlock, 1, 1);
  143. }
  144. CUTLASS_DEVICE
  145. void
  146. operator()(Params const& params, char* smem_buf) {
  147. static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
  148. static constexpr int MmaThreadOffset = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup;
  149. static constexpr int kBlockM = get<0>(TileShape_MNK{});
  150. using MainloopPipelineK = typename CollectiveMainloop::MainloopPipelineK;
  151. using MainloopPipelineV = typename CollectiveMainloop::MainloopPipelineV;
  152. using MainloopPipelineVt = typename CollectiveMainloop::MainloopPipelineVt;
  153. using MainloopPipelineKVNew = typename CollectiveMainloop::MainloopPipelineKVNew;
  154. using PipelineState = typename CollectiveMainloop::PipelineState;
  155. using PipelineParamsK = typename MainloopPipelineK::Params;
  156. using PipelineParamsV = typename MainloopPipelineV::Params;
  157. using PipelineParamsVt = typename MainloopPipelineVt::Params;
  158. using PipelineParamsKVNew = typename MainloopPipelineKVNew::Params;
  159. SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
  160. int const lane_predicate = cute::elect_one_sync();
  161. int const warp_idx = cutlass::canonical_warp_idx_sync();
  162. // Issue Tma Descriptor Prefetch from a single thread
  163. if (warp_idx == 0 && lane_predicate) {
  164. CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
  165. CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
  166. }
  167. // Obtain warp index
  168. int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
  169. int warp_group_idx = cutlass::canonical_warp_group_idx();
  170. if (warp_idx == 0 && lane_predicate) {
  171. shared_storage.pipelines.barrier_Q.init(Use_TMA_Q ? 1 : NumMmaThreads /*numThreads*/);
  172. shared_storage.pipelines.barrier_O.init(size(ClusterShape{}) * (Use_TMA_O ? 1 : NumMmaThreads) /*numThreads*/);
  173. }
  174. // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init();
  175. PipelineParamsK pipeline_params_k;
  176. pipeline_params_k.role = warp_group_idx == 0
  177. ? MainloopPipelineK::ThreadCategory::Producer
  178. : MainloopPipelineK::ThreadCategory::Consumer;
  179. if constexpr (Use_TMA_KV) {
  180. pipeline_params_k.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
  181. pipeline_params_k.is_leader = warp_group_thread_idx == 0;
  182. pipeline_params_k.num_consumers = NumMmaThreads;
  183. } else {
  184. pipeline_params_k.consumer_arv_count = NumMmaThreads;
  185. pipeline_params_k.producer_arv_count = NumProducerThreads;
  186. }
  187. MainloopPipelineK pipeline_k = [&] {
  188. if constexpr (Use_TMA_KV) {
  189. return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k, ClusterShape{});
  190. } else {
  191. return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k);
  192. }
  193. }();
  194. // MainloopPipelineV pipeline_v(shared_storage.pipelines.pipeline_v, pipeline_params_v, ClusterShape{});
  195. MainloopPipelineV pipeline_v = [&] {
  196. if constexpr (!Transpose_V) {
  197. static_assert(is_same_v<PipelineParamsK, PipelineParamsV>);
  198. if constexpr (Use_TMA_KV) {
  199. return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_k, ClusterShape{});
  200. } else {
  201. return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_k);
  202. }
  203. } else {
  204. PipelineParamsV pipeline_params_v;
  205. pipeline_params_v.role = warp_group_idx == 0
  206. ? MainloopPipelineV::ThreadCategory::Producer
  207. : MainloopPipelineV::ThreadCategory::Consumer;
  208. pipeline_params_v.producer_arv_count = NumProducerThreads;
  209. pipeline_params_v.consumer_arv_count = NumMmaThreads;
  210. return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_v);
  211. }
  212. }();
  213. static_assert(is_same_v<PipelineParamsK, PipelineParamsVt>);
  214. // If we need to transpose V (e.g. FP8 and V is row-major), we use pipeline_vt for the TMA, then
  215. // the producer WG will read from pipeline_vt and write to pipeline_v.
  216. // If we don't need to transpose V, we use pipeline_v for the TMA, and pipeline_vt won't be used.
  217. // Technically for pipeline_params_vt, warp0 of WG0 is the producer and all of WG0 are consumers.
  218. // However, the thread role isn't used in the pipeline implementation.
  219. MainloopPipelineVt pipeline_vt = [&] {
  220. if constexpr (Use_TMA_KV) {
  221. pipeline_params_k.num_consumers = NumProducerThreads; // TMA_V is only consumed by the producer WG
  222. return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_k, ClusterShape{});
  223. } else {
  224. pipeline_params_k.consumer_arv_count = NumProducerThreads; // TMA_V is only consumed by the producer WG
  225. return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_k);
  226. }
  227. }();
  228. PipelineParamsKVNew pipeline_params_kv_new;
  229. pipeline_params_kv_new.role = warp_group_idx == 0
  230. ? MainloopPipelineKVNew::ThreadCategory::Producer
  231. : MainloopPipelineKVNew::ThreadCategory::Consumer;
  232. pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
  233. pipeline_params_kv_new.is_leader = warp_group_thread_idx == 0;
  234. pipeline_params_kv_new.num_consumers = NumMmaThreads;
  235. auto pipeline_k_new = cute::conditional_return<AppendKV>(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_k_new, pipeline_params_kv_new, ClusterShape{}), nullptr);
  236. auto pipeline_v_new = cute::conditional_return<AppendKV>(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_v_new, pipeline_params_kv_new, ClusterShape{}), nullptr);
  237. CollectiveMainloop collective_mainloop;
  238. CollectiveEpilogue collective_epilogue;
  239. // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster
  240. if constexpr (size(ClusterShape{}) > 1) {
  241. cute::cluster_arrive_relaxed();
  242. cute::cluster_wait();
  243. } else {
  244. __syncthreads();
  245. }
  246. if (warp_group_idx == 0) { // Producer
  247. cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
  248. PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipelineK>();
  249. int work_idx = 0;
  250. TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.pipelines.smem_scheduler));
  251. int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
  252. static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp;
  253. if constexpr (SingleProducerWarp) {
  254. if (warp_idx_in_warpgroup != 0) { return; }
  255. }
  256. if (!SingleProducerWarp && warp_idx_in_warpgroup != 0) { scheduler.init_consumer(); }
  257. // Load Q, K, V
  258. for (auto work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler) : scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
  259. work_tile_info.is_valid(params.scheduler);
  260. work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info) : scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
  261. auto block_coord = work_tile_info.get_block_coord(params.scheduler);
  262. SeqlenInfo_t seqlen_info{
  263. get<2>(block_coord) /*bidb*/,
  264. get<0>(params.mainloop.shape_Q),
  265. !PagedKV ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),
  266. get<0>(params.mainloop.shape_K_new),
  267. params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
  268. params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
  269. };
  270. if constexpr (AppendKV) {
  271. bool tile_new_valid = collective_mainloop.load_kv_new(
  272. params.mainloop, pipeline_k_new, pipeline_v_new,
  273. smem_pipe_write, shared_storage, seqlen_info, block_coord, work_idx);
  274. if (tile_new_valid) {
  275. // if (threadIdx.x == 0) { printf("Producer: Before sync\n"); }
  276. cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::AppendKV) /*id*/);
  277. // if (threadIdx.x == 0) { printf("Producer: After sync\n"); }
  278. // If we don't reset the state, the loads for the main attention might have the wrong phase.
  279. smem_pipe_write = cutlass::make_producer_start_state<MainloopPipelineK>();
  280. }
  281. }
  282. auto scheduler_prefetch = [&scheduler, &params, &work_tile_info]() {
  283. scheduler.prefetch_next_work(params.scheduler, work_tile_info);
  284. };
  285. // pipeline_vt won't be used if we don't need to transpose V.
  286. collective_mainloop.load(params.mainloop, pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write,
  287. shared_storage, scheduler_prefetch, seqlen_info, block_coord, work_idx);
  288. }
  289. collective_mainloop.load_tail(pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, shared_storage, work_idx);
  290. } else { // Consumer
  291. cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
  292. TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.pipelines.smem_scheduler));
  293. // Initialize matmul objects.
  294. TiledMma1 tiled_mma1;
  295. PipelineState smem_pipe_read;
  296. // We don't need separate variables smem_pipe_release_k and smem_pipe_release_v
  297. // (like in Cutlass's gemm) because the read and release pipeline states are always the same.
  298. scheduler.init_consumer();
  299. collective_mainloop.mma_init();
  300. int work_idx = 0;
  301. CUTLASS_PRAGMA_NO_UNROLL
  302. for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
  303. work_tile_info.is_valid(params.scheduler);
  304. work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
  305. // Attention output (GEMM-II) accumulator.
  306. Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{}));
  307. float softmax_scale_log2 = params.mainloop.softmax_scale_log2;
  308. // If there's tanh softcap, the scaling will be done before tanh.
  309. auto block_coord = work_tile_info.get_block_coord(params.scheduler);
  310. int const bidb = get<2>(block_coord);
  311. if constexpr (Is_FP8 && !Has_softcap) {
  312. int const bidh = get<1>(block_coord);
  313. int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh;
  314. float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)];
  315. float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)];
  316. softmax_scale_log2 *= q_descale * k_descale;
  317. }
  318. flash::Softmax<2 * (2 * kBlockM / NumMmaThreads), /*Max_offset=*/!Is_FP8 ? 0 : 8> softmax(softmax_scale_log2);
  319. SeqlenInfo_t seqlen_info{
  320. bidb,
  321. get<0>(params.mainloop.shape_Q),
  322. !PagedKV ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),
  323. get<0>(params.mainloop.shape_K_new),
  324. params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
  325. params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
  326. };
  327. if constexpr (AppendKV) {
  328. bool tile_new_valid = collective_mainloop.store_kv_new(
  329. params.mainloop, pipeline_k_new, pipeline_v_new, smem_pipe_read,
  330. threadIdx.x - MmaThreadOffset, shared_storage, seqlen_info, block_coord);
  331. if (tile_new_valid) {
  332. // if (threadIdx.x == 128) { printf("Consumer: Before sync\n"); }
  333. // We need this sync so that the gmem write from the consumers is visible to the producer
  334. // that might do TMA read after that.
  335. asm volatile ("fence.proxy.async.global;");
  336. cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::AppendKV) /*id*/);
  337. // if (threadIdx.x == 128) { printf("Consumer: After sync\n"); }
  338. smem_pipe_read = PipelineState{};
  339. }
  340. }
  341. bool tile_valid = collective_mainloop.mma(
  342. params.mainloop, pipeline_k, pipeline_v, smem_pipe_read,
  343. tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage);
  344. if (tile_valid) {
  345. // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); }
  346. collective_epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma1,
  347. threadIdx.x - MmaThreadOffset, block_coord);
  348. } else {
  349. // Write 0 to gO and -inf to gLSE.
  350. // If Split, we don't have to write 0 to O if the mha_combine kernel is used, since it will
  351. // not use the value of O if LSE is -inf.
  352. collective_epilogue.template store_zero<!Split /*Clear_O*/>(params.epilogue, threadIdx.x - MmaThreadOffset, block_coord);
  353. // collective_epilogue.store_zero(params.epilogue, threadIdx.x - MmaThreadOffset, block_coord);
  354. }
  355. }
  356. collective_epilogue.store_tail();
  357. }
  358. }
  359. };
  360. } // namespace flash