/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

#include <cutlass/cutlass.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/kernel_hardware_info.h>
#include "cutlass/pipeline/pipeline.hpp"

#include "seqlen.h"
#include "utils.h"
#include "softmax.h"

namespace flash {

using namespace cute;

template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
class FlashAttnFwdSm90 {

public:

    // Type Aliases
    using CollectiveMainloop = CollectiveMainloop_;
    using CollectiveEpilogue = CollectiveEpilogue_;
    static constexpr bool Is_causal = CollectiveMainloop::Is_causal;
    static constexpr bool Is_local = CollectiveMainloop::Is_local;
    static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen);
    static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap;
    static constexpr bool Varlen = CollectiveMainloop::Varlen;
    static constexpr bool PagedKV = CollectiveMainloop::PagedKV;
    static constexpr bool Split = CollectiveMainloop::Split;
    static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8;
    static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V;
    static constexpr bool AppendKV = CollectiveMainloop::AppendKV;
    static constexpr bool HasQv = CollectiveMainloop::HasQv;
    static constexpr bool Use_TMA_Q = CollectiveMainloop::Use_TMA_Q;
    static constexpr bool Use_TMA_KV = CollectiveMainloop::Use_TMA_KV;
    static constexpr bool Use_TMA_O = CollectiveEpilogue::Use_TMA_O;
    static constexpr bool PackGQA = CollectiveMainloop::PackGQA;
    static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads;
    static constexpr bool SameHeadDim = CollectiveMainloop::SameHeadDim;
    static constexpr bool LargeHeadDimV = CollectiveMainloop::LargeHeadDimV;
    static_assert(CollectiveMainloop::LargeHeadDimV == CollectiveEpilogue::LargeHeadDimV);
    using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t;

    // Mainloop derived types
    using TileShape_MNK_PV = typename CollectiveMainloop::TileShape_MNK_PV;
    using TiledMmaPV = typename CollectiveMainloop::TiledMmaPV;
    using ArchTag = typename CollectiveMainloop::ArchTag;
    using ClusterShape = typename CollectiveMainloop::ClusterShape;
    using MainloopArguments = typename CollectiveMainloop::Arguments;
    using MainloopParams = typename CollectiveMainloop::Params;
    using BarrierQ = std::conditional_t<Use_TMA_Q, cutlass::arch::ClusterTransactionBarrier, cutlass::arch::ClusterBarrier>;

    // Epilogue derived types
    using EpilogueArguments = typename CollectiveEpilogue::Arguments;
    using EpilogueParams = typename CollectiveEpilogue::Params;

    static_assert(ArchTag::kMinComputeCapability >= 90);

    using TileScheduler = TileScheduler_;
    using TileSchedulerArguments = typename flash::TileSchedulerArguments;
    using TileSchedulerParams = typename TileScheduler::Params;

    static constexpr uint32_t NumLoadWarpGroups = 1;
    static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaPV{})) / cutlass::NumThreadsPerWarpGroup;
    static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaPV{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup);
    static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
    static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3);
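    // Each thread block is made up of NumLoadWarpGroups (= 1) producer warp group that issues the
    // loads, plus NumMmaWarpGroups (1, 2, or 3) consumer warp groups that run the MMAs, for a total
    // of MaxThreadsPerBlock = 256, 384, or 512 threads.
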
    /// Register requirement for Load and Math WGs
    // If we use cp.async to load K and V, we need more registers for the producer WG.
    static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 1 ? 56 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 24 : 40) : 32);
    static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 1 ? 256 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 240 : 232) : 160);
    // If you want to print from the producer warp, you'd need to increase the number of registers.
    // Otherwise you'll get a CUDA error.
    // static constexpr uint32_t LoadRegisterRequirement = 40;
    // static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 232 : 152;

    // Kernel level shared memory storage
    // We overlap the shared memory for the mainloop and epilogue. However, we only want smem_o to overlap with smem_v
    // and nothing else, so we'll pad in case sizeof(smem_o) > sizeof(smem_v).
    static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage)) - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v)));
    static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 0 : mainloop_smem_padding_;
    struct SharedStorage {
        struct TensorStorage : cute::aligned_struct<128, _1> {
            union {
                struct {
                    cute::array<uint32_t, mainloop_smem_padding / sizeof(uint32_t)> padding_;
                    typename CollectiveMainloop::TensorStorage mainloop;
                };
                // We want smem_o to line up with the start of smem_v
                typename CollectiveEpilogue::TensorStorage epilogue;
            };
        } tensors;

        struct PipelineStorage : cute::aligned_struct<16, _1> {
            alignas(16) BarrierQ barrier_Q;
            alignas(16) BarrierQ barrier_Qv;
            alignas(16) cutlass::arch::ClusterBarrier barrier_O;
            alignas(16) typename CollectiveMainloop::MainloopPipelineK::SharedStorage pipeline_k;
            alignas(16) typename CollectiveMainloop::MainloopPipelineV::SharedStorage pipeline_v;
            alignas(16) typename CollectiveMainloop::MainloopPipelineVt::SharedStorage pipeline_vt;
            alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_k_new;
            alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_v_new;
            alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
        } pipelines;
    };

    static constexpr int SharedStorageSize = sizeof(SharedStorage);
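    // Note: if SharedStorageSize exceeds the 48 KB static limit, the host-side launch has to opt in to
    // larger dynamic shared memory before launching this kernel, e.g. (a sketch, not part of this file):
    //   cudaFuncSetAttribute(kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, SharedStorageSize);
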
    // Device side arguments
    struct Arguments {
        MainloopArguments mainloop{};
        EpilogueArguments epilogue{};
        cutlass::KernelHardwareInfo hw_info{};
        TileSchedulerArguments scheduler{};
    };

    // Kernel entry point API
    struct Params {
        MainloopParams mainloop{};
        EpilogueParams epilogue{};
        cutlass::KernelHardwareInfo hw_info{};
        TileSchedulerParams scheduler{};
    };

    //
    // Methods
    //

    // Convert to underlying arguments. In this case, a simple copy for the aliased type.
    static
    Params
    to_underlying_arguments(Arguments const& args) {
        CUTLASS_TRACE_HOST("to_underlying_arguments():");

        // Get SM count if needed, otherwise use user supplied SM count
        int sm_count = args.hw_info.sm_count;
        if (sm_count <= 0) {
            CUTLASS_TRACE_HOST("  WARNING: Arguments do not include a valid SM count.\n"
                "  For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
            sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
        }

        CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);

        cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};

        return {
            CollectiveMainloop::to_underlying_arguments(args.mainloop),
            CollectiveEpilogue::to_underlying_arguments(args.epilogue),
            hw_info,
            TileScheduler::to_underlying_arguments(args.scheduler)
        };
    }
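    // The grid shape is determined by the tile scheduler; with a persistent scheduler it is typically
    // sized to the SM count (hence hw_info.sm_count above), and each block then loops over work tiles
    // handed out by the scheduler rather than being mapped one-to-one to output tiles.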
    // Computes the kernel launch grid shape based on runtime parameters
    static dim3
    get_grid_shape(Params const& params) {
        return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);
    }

    static dim3
    get_block_shape() {
        return dim3(MaxThreadsPerBlock, 1, 1);
    }
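    // The block is one-dimensional: threads [0, 128) form the producer (load) warp group and
    // threads [128, MaxThreadsPerBlock) form the MMA warp groups (see MmaThreadOffset below).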
    CUTLASS_DEVICE
    void
    operator()(Params const& params, char* smem_buf) {

        static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
        static constexpr int MmaThreadOffset = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup;
        static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});

        using MainloopPipelineK = typename CollectiveMainloop::MainloopPipelineK;
        using MainloopPipelineV = typename CollectiveMainloop::MainloopPipelineV;
        using MainloopPipelineVt = typename CollectiveMainloop::MainloopPipelineVt;
        using MainloopPipelineKVNew = typename CollectiveMainloop::MainloopPipelineKVNew;
        using PipelineState = typename CollectiveMainloop::PipelineState;
        using PipelineParamsK = typename MainloopPipelineK::Params;
        using PipelineParamsV = typename MainloopPipelineV::Params;
        using PipelineParamsVt = typename MainloopPipelineVt::Params;
        using PipelineParamsKVNew = typename MainloopPipelineKVNew::Params;

        SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);

        int const lane_predicate = cute::elect_one_sync();
        int const warp_idx = cutlass::canonical_warp_idx_sync();

        // Issue Tma Descriptor Prefetch from a single thread
        if (warp_idx == 0 && lane_predicate) {
            CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
            CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
        }

        // Obtain warp index
        int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
        int warp_group_idx = cutlass::canonical_warp_group_idx();

        if (warp_idx == 0 && lane_predicate) {
            shared_storage.pipelines.barrier_Q.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/);
            if constexpr (HasQv) {
                shared_storage.pipelines.barrier_Qv.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/);
            }
            shared_storage.pipelines.barrier_O.init(size(ClusterShape{}) * (Use_TMA_O ? 1 : NumMmaThreads) /*numThreads*/);
        }
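        // barrier_Q / barrier_Qv signal that Q (and Qv) have arrived in smem: a single arrival when
        // loaded via TMA (the transaction itself), or one arrival per producer thread otherwise.
        // barrier_O signals that the epilogue is done with smem_o (which aliases smem_v), so the
        // producer can reuse that buffer; its arrival count is scaled by the cluster size.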
        // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init();
        PipelineParamsK pipeline_params_k;
        pipeline_params_k.role = warp_group_idx == 0
            ? MainloopPipelineK::ThreadCategory::Producer
            : MainloopPipelineK::ThreadCategory::Consumer;
        if constexpr (Use_TMA_KV) {
            pipeline_params_k.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
            pipeline_params_k.is_leader = warp_group_thread_idx == 0;
            pipeline_params_k.num_consumers = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup;
        } else {
            pipeline_params_k.consumer_arv_count = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup;
            pipeline_params_k.producer_arv_count = NumProducerThreads;
        }

        static_assert(is_same_v<PipelineParamsK, PipelineParamsVt>);
        PipelineParamsVt pipeline_params_vt = pipeline_params_k;
        if constexpr (Use_TMA_KV && !SameHeadDim) {
            pipeline_params_vt.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV;
            if constexpr (LargeHeadDimV) { pipeline_params_vt.num_consumers = NumMmaThreads; }
        } else {
            if constexpr (LargeHeadDimV) { pipeline_params_vt.consumer_arv_count = NumMmaThreads; }
        }

        MainloopPipelineK pipeline_k = [&] {
            if constexpr (Use_TMA_KV) {
                return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k, ClusterShape{});
            } else {
                return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k);
            }
        }();
        // MainloopPipelineV pipeline_v(shared_storage.pipelines.pipeline_v, pipeline_params_v, ClusterShape{});
        MainloopPipelineV pipeline_v = [&] {
            if constexpr (!Transpose_V) {
                static_assert(is_same_v<PipelineParamsK, PipelineParamsV>);
                if constexpr (Use_TMA_KV) {
                    return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt, ClusterShape{});
                } else {
                    return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt);
                }
            } else {
                PipelineParamsV pipeline_params_v;
                pipeline_params_v.role = warp_group_idx == 0
                    ? MainloopPipelineV::ThreadCategory::Producer
                    : MainloopPipelineV::ThreadCategory::Consumer;
                pipeline_params_v.producer_arv_count = NumProducerThreads;
                pipeline_params_v.consumer_arv_count = NumMmaThreads;
                return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_v);
            }
        }();

        // If we need to transpose V (e.g. FP8 and V is row-major), we use pipeline_vt for the TMA, then
        // the producer WG will read from pipeline_vt and write to pipeline_v.
        // If we don't need to transpose V, we use pipeline_v for the TMA, and pipeline_vt won't be used.
        // Technically for pipeline_params_vt, warp0 of WG0 is the producer and all of WG0 are consumers.
        // However, the thread role isn't used in the pipeline implementation.
        MainloopPipelineVt pipeline_vt = [&] {
            if constexpr (Use_TMA_KV) {
                pipeline_params_vt.num_consumers = NumProducerThreads;  // TMA_V is only consumed by the producer WG
                return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt, ClusterShape{});
            } else {
                pipeline_params_vt.consumer_arv_count = NumProducerThreads;  // TMA_V is only consumed by the producer WG
                return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt);
            }
        }();

        PipelineParamsKVNew pipeline_params_kv_new;
        pipeline_params_kv_new.role = warp_group_idx == 0
            ? MainloopPipelineKVNew::ThreadCategory::Producer
            : MainloopPipelineKVNew::ThreadCategory::Consumer;
        pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
        pipeline_params_kv_new.is_leader = warp_group_thread_idx == 0;
        pipeline_params_kv_new.num_consumers = NumMmaThreads;
        auto pipeline_k_new = cute::conditional_return<AppendKV>(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_k_new, pipeline_params_kv_new, ClusterShape{}), nullptr);
        if constexpr (!SameHeadDim) {
            pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV;
        }
        auto pipeline_v_new = cute::conditional_return<AppendKV>(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_v_new, pipeline_params_kv_new, ClusterShape{}), nullptr);
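        // pipeline_k_new / pipeline_v_new are only used when AppendKV (loading the new K/V tokens that are
        // appended to the cache); otherwise cute::conditional_return yields nullptr and they compile away.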
        CollectiveMainloop mainloop;
        CollectiveEpilogue epilogue;

        // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster
        if constexpr (size(ClusterShape{}) > 1) {
            cute::cluster_arrive_relaxed();
            cute::cluster_wait();
        } else {
            __syncthreads();
        }

        TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.pipelines.smem_scheduler));

        if (warp_group_idx == 0) {  // Producer
            cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();

            // The pipelines for AppendKV and main attention are different, since e.g. main attention
            // might use cp.async to load KV (if PagedKV) while AppendKV always uses TMA to load
            // KV_new. Since the pipeline states are different, we have to manually sync to make
            // sure the two pipelines don't race when accessing smem_k and smem_v.
            PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipelineK>();
            PipelineState smem_pipe_write_new = cutlass::make_producer_start_state<MainloopPipelineKVNew>();
            int work_idx = 0;

            int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
            static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp;
            if constexpr (SingleProducerWarp) {
                if (warp_idx_in_warpgroup != 0) { return; }
            }
            if (!SingleProducerWarp && warp_idx_in_warpgroup != 0) { scheduler.init_consumer(); }
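            // When a single producer warp does all the loading (NumProducerThreads == 32), warps 1-3 of
            // the load warp group have already exited above; otherwise they only participate in the tile
            // scheduler as consumers.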
            // Load Q, K, V
            for (auto work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler) : scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
                 work_tile_info.is_valid(params.scheduler);
                 work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info) : scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {

                auto block_coord = work_tile_info.get_block_coord(params.scheduler);
                SeqlenInfo_t seqlen_info{
                    get<2>(block_coord) /*bidb*/,
                    get<0>(params.mainloop.shape_Q),
                    !PagedKV ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),
                    get<0>(params.mainloop.shape_K_new),
                    params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
                    params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
                };
                if constexpr (AppendKV) {
                    bool tile_new_valid = mainloop.load_kv_new(
                        params.mainloop, pipeline_k_new, pipeline_v_new,
                        smem_pipe_write_new, shared_storage, seqlen_info, block_coord, work_idx);
                    if (tile_new_valid) {
                        // if (threadIdx.x == 0) { printf("Producer: Before sync\n"); }
                        cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast<uint32_t>(FwdNamedBarriers::AppendKV) /*id*/);
                        // if (threadIdx.x == 0) { printf("Producer: After sync\n"); }
                    }
                }
                auto scheduler_prefetch = [&scheduler, &params, &work_tile_info]() {
                    scheduler.prefetch_next_work(params.scheduler, work_tile_info);
                };
                // pipeline_vt won't be used if we don't need to transpose V.
                mainloop.load(params.mainloop, pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write,
                              shared_storage, scheduler_prefetch, seqlen_info, block_coord, work_idx);
            }
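            // Drain the K/V pipelines: wait until the consumers have released all in-flight stages
            // before the producer warp group exits.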
            mainloop.load_tail(pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, shared_storage, work_idx);
        } else {  // Consumer
            cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();

            // Initialize matmul objects.
            TiledMmaPV tiled_mma_pv;

            PipelineState smem_pipe_read;
            PipelineState smem_pipe_read_new;
            // We don't need separate variables smem_pipe_release_k and smem_pipe_release_v
            // (like in Cutlass's gemm) because the read and release pipeline states are always the same.

            scheduler.init_consumer();
            mainloop.mma_init();

            int work_idx = 0;
            CUTLASS_PRAGMA_NO_UNROLL
            for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
                 work_tile_info.is_valid(params.scheduler);
                 work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {

                // Attention output (GEMM-II) accumulator.
                Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_MNK_PV{}));
                float softmax_scale_log2 = params.mainloop.softmax_scale_log2;
                // If there's tanh softcap, the scaling will be done before tanh.
                auto block_coord = work_tile_info.get_block_coord(params.scheduler);
                int const bidb = get<2>(block_coord);
                if constexpr (Is_FP8 && !Has_softcap) {
                    int const bidh = get<1>(block_coord);
                    int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh;
                    float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)];
                    float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)];
                    softmax_scale_log2 *= q_descale * k_descale;
                }
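                // For FP8 (without softcap), the per-(batch, kv-head) Q and K descale factors scale the
                // S = Q K^T logits uniformly, so they can be folded into the log2 softmax scale here
                // instead of rescaling the accumulator.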
                flash::Softmax<!LargeHeadDimV ? 2 * (2 * kBlockM / NumMmaThreads) : 2, /*Max_offset=*/!Is_FP8 ? 0 : 8> softmax(softmax_scale_log2);

                SeqlenInfo_t seqlen_info{
                    bidb,
                    get<0>(params.mainloop.shape_Q),
                    !PagedKV ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),
                    get<0>(params.mainloop.shape_K_new),
                    params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
                    params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
                };
                if constexpr (AppendKV) {
                    bool tile_new_valid = mainloop.store_kv_new(
                        params.mainloop, pipeline_k_new, pipeline_v_new, smem_pipe_read_new,
                        threadIdx.x - MmaThreadOffset, shared_storage, seqlen_info, block_coord);
                    if (tile_new_valid) {
                        // if (threadIdx.x == 128) { printf("Consumer: Before sync\n"); }
                        // We need this sync so that the gmem write from the consumers is visible to the producer
                        // that might do TMA read after that.
                        asm volatile ("fence.proxy.async.global;");
                        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<uint32_t>(FwdNamedBarriers::AppendKV) /*id*/);
                        // arrive is enough, we don't need sync. The producer will sync, which means that
                        // after that sync we're guaranteed that the AppendKV pipelines have finished
                        // loading and consuming smem_k and smem_v.
                        // if (threadIdx.x == 128) { printf("Consumer: After sync\n"); }
                    }
                }
                bool tile_valid;
                if constexpr (!LargeHeadDimV) {
                    tile_valid = mainloop.mma(
                        params.mainloop, pipeline_k, pipeline_v, smem_pipe_read,
                        tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage);
                } else {  // mma_pv might not compile if !LargeHeadDimV
                    if (warp_group_idx == 1) {
                        tile_valid = mainloop.mma(
                            params.mainloop, pipeline_k, pipeline_v, smem_pipe_read,
                            tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage);
                    } else {
                        tile_valid = mainloop.mma_pv(
                            params.mainloop, pipeline_v, smem_pipe_read,
                            tOrO, softmax, threadIdx.x - MmaThreadOffset, seqlen_info, block_coord, shared_storage);
                    }
                }
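                // With LargeHeadDimV, only warp group 1 runs the full QK^T + softmax + PV pipeline via
                // mainloop.mma; the remaining MMA warp group(s) only run the PV GEMM via mma_pv.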
                if (tile_valid) {
                    // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); }
                    epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma_pv,
                                   threadIdx.x - MmaThreadOffset, block_coord);
                } else {
                    // Write 0 to gO and -inf to gLSE.
                    // If Split, we don't have to write 0 to O if the mha_combine kernel is used, since it will
                    // not use the value of O if LSE is -inf.
                    epilogue.template store_zero<!Split /*Clear_O*/>(params.epilogue, threadIdx.x - MmaThreadOffset, block_coord);
                    // epilogue.store_zero(params.epilogue, threadIdx.x - MmaThreadOffset, block_coord);
                }
            }
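            // store_tail waits for any outstanding epilogue stores (e.g. TMA stores of O) to complete
            // before the consumer warp groups exit.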
            epilogue.store_tail();
        }

    }

};

} // namespace flash