flash_fwd_kernel.h 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include "cute/tensor.hpp"
  6. #include <cutlass/cutlass.h>
  7. #include <cutlass/arch/reg_reconfig.h>
  8. #include <cutlass/array.h>
  9. #include <cutlass/numeric_types.h>
  10. #include <cutlass/numeric_conversion.h>
  11. #include <cutlass/kernel_hardware_info.h>
  12. #include "cutlass/pipeline/pipeline.hpp"
  13. #include "utils.h"
  14. #include "softmax.h"
  15. namespace flash {
  16. using namespace cute;
  17. template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
  18. class FlashAttnFwd {
  19. public:
  20. // Type Aliases
  21. static constexpr bool Is_causal = CollectiveMainloop_::Is_causal;
  22. static constexpr bool Is_local = CollectiveMainloop_::Is_local;
  23. static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen);
  24. static constexpr bool Varlen = CollectiveMainloop_::Varlen;
  25. static constexpr bool Is_FP8 = CollectiveMainloop_::Is_FP8;
  26. // Mainloop derived types
  27. using CollectiveMainloop = CollectiveMainloop_;
  28. using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;
  29. using TiledMma0 = typename CollectiveMainloop::TiledMma0;
  30. using TiledMma1 = typename CollectiveMainloop::TiledMma1;
  31. using ArchTag = typename CollectiveMainloop::ArchTag;
  32. using ClusterShape = typename CollectiveMainloop::ClusterShape;
  33. using MainloopArguments = typename CollectiveMainloop::Arguments;
  34. using MainloopParams = typename CollectiveMainloop::Params;
  35. // Epilogue derived types
  36. using CollectiveEpilogue = CollectiveEpilogue_;
  37. using EpilogueArguments = typename CollectiveEpilogue::Arguments;
  38. using EpilogueParams = typename CollectiveEpilogue::Params;
  39. static_assert(ArchTag::kMinComputeCapability >= 90);
  40. using TileScheduler = TileScheduler_;
  41. using TileSchedulerArguments = typename TileScheduler::Arguments;
  42. using TileSchedulerParams = typename TileScheduler::Params;
  43. static constexpr uint32_t NumLoadWarpGroups = 1;
  44. static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMma0{})) / cutlass::NumThreadsPerWarpGroup;
  45. static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma0{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup);
  46. static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
  47. static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3);
  48. /// Register requirement for Load and Math WGs
  49. static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 2 ? 24 : 32;
  50. static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 240 : 160;
  51. // If you want to print from the producer warp, you'd need to increase the number of registers
  52. // Otherwise you'll get CUDA error.
  53. // static constexpr uint32_t LoadRegisterRequirement = 40;
  54. // static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 232 : 152;
  55. // Kernel level shared memory storage
  56. struct SharedStorage {
  57. struct TensorStorage : cute::aligned_struct<128> {
  58. union {
  59. typename CollectiveMainloop::TensorStorage mainloop;
  60. // We want smem_o to line up with the start of smem_v
  61. typename CollectiveEpilogue::TensorStorage epilogue;
  62. static_assert(cute::cosize_v<typename CollectiveEpilogue::SmemLayoutO> * sizeof(typename CollectiveEpilogue::Element)
  63. <= cute::cosize_v<typename CollectiveMainloop::SmemLayoutVt> * sizeof(typename CollectiveMainloop::Element));
  64. };
  65. } tensors;
  66. struct PipelineStorage : cute::aligned_struct<16> {
  67. alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_Q;
  68. alignas(16) cutlass::arch::ClusterBarrier barrier_O;
  69. alignas(16) typename CollectiveMainloop::MainloopPipelineK::SharedStorage pipeline_k;
  70. alignas(16) typename CollectiveMainloop::MainloopPipelineV::SharedStorage pipeline_v;
  71. alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
  72. } pipelines;
  73. };
  74. static constexpr int SharedStorageSize = sizeof(SharedStorage);
  75. // Device side arguments
  76. struct Arguments {
  77. MainloopArguments mainloop{};
  78. EpilogueArguments epilogue{};
  79. cutlass::KernelHardwareInfo hw_info{};
  80. TileSchedulerArguments scheduler{};
  81. };
  82. // Kernel entry point API
  83. struct Params {
  84. MainloopParams mainloop{};
  85. EpilogueParams epilogue{};
  86. cutlass::KernelHardwareInfo hw_info{};
  87. TileSchedulerParams scheduler{};
  88. };
  89. //
  90. // Methods
  91. //
  92. // Convert to underlying arguments. In this case, a simple copy for the aliased type.
  93. static
  94. Params
  95. to_underlying_arguments(Arguments const& args) {
  96. CUTLASS_TRACE_HOST("to_underlying_arguments():");
  97. // Get SM count if needed, otherwise use user supplied SM count
  98. int sm_count = args.hw_info.sm_count;
  99. if (sm_count <= 0) {
  100. CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
  101. " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
  102. sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
  103. }
  104. CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
  105. cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
  106. return {
  107. CollectiveMainloop::to_underlying_arguments(args.mainloop),
  108. CollectiveEpilogue::to_underlying_arguments(args.epilogue),
  109. hw_info,
  110. TileScheduler::to_underlying_arguments(args.scheduler)
  111. };
  112. }
  113. // Computes the kernel launch grid shape based on runtime parameters
  114. static dim3
  115. get_grid_shape(Params const& params) {
  116. return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);
  117. }
  118. static dim3
  119. get_block_shape() {
  120. return dim3(MaxThreadsPerBlock, 1, 1);
  121. }
  122. CUTLASS_DEVICE
  123. void
  124. operator()(Params const& params, char* smem_buf) {
  125. static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
  126. static constexpr int NumCopyThreads = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup;
  127. static constexpr int kBlockM = get<0>(TileShape_MNK{});
  128. using MainloopPipelineK = typename CollectiveMainloop::MainloopPipelineK;
  129. using MainloopPipelineV = typename CollectiveMainloop::MainloopPipelineV;
  130. using PipelineState = typename CollectiveMainloop::PipelineState;
  131. using PipelineParamsK = typename MainloopPipelineK::Params;
  132. using PipelineParamsV = typename MainloopPipelineV::Params;
  133. SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
  134. int const lane_predicate = cute::elect_one_sync();
  135. int const warp_idx = cutlass::canonical_warp_idx_sync();
  136. // Issue Tma Descriptor Prefetch from a single thread
  137. if (warp_idx == 0 && lane_predicate) {
  138. CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
  139. CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
  140. }
  141. // Obtain warp index
  142. int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
  143. int warp_group_idx = cutlass::canonical_warp_group_idx();
  144. PipelineParamsK pipeline_params_k;
  145. pipeline_params_k.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
  146. pipeline_params_k.role = warp_group_idx == 0
  147. ? MainloopPipelineK::ThreadCategory::Producer
  148. : MainloopPipelineK::ThreadCategory::Consumer;
  149. pipeline_params_k.is_leader = warp_group_thread_idx == 0;
  150. pipeline_params_k.num_consumers = NumMmaThreads;
  151. // PipelineParamsV pipeline_params_v;
  152. // pipeline_params_v.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV;
  153. // pipeline_params_v.role = warp_group_idx == 0
  154. // ? MainloopPipelineV::ThreadCategory::Producer
  155. // : MainloopPipelineV::ThreadCategory::Consumer;
  156. // pipeline_params_v.is_leader = warp_group_thread_idx == 0;
  157. // pipeline_params_v.num_consumers = NumMmaThreads;
  158. if (warp_idx == 0 && lane_predicate) {
  159. shared_storage.pipelines.barrier_Q.init(1 /*numThreads*/);
  160. shared_storage.pipelines.barrier_O.init(size(ClusterShape{}) /*numThreads*/);
  161. }
  162. // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init();
  163. MainloopPipelineK pipeline_k(shared_storage.pipelines.pipeline_k, pipeline_params_k, ClusterShape{});
  164. // MainloopPipelineV pipeline_v(shared_storage.pipelines.pipeline_v, pipeline_params_v, ClusterShape{});
  165. static_assert(is_same_v<PipelineParamsK, PipelineParamsV>);
  166. MainloopPipelineV pipeline_v(shared_storage.pipelines.pipeline_v, pipeline_params_k, ClusterShape{});
  167. CollectiveMainloop collective_mainloop;
  168. CollectiveEpilogue collective_epilogue;
  169. // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster
  170. if constexpr (size(ClusterShape{}) > 1) {
  171. cute::cluster_arrive_relaxed();
  172. cute::cluster_wait();
  173. } else {
  174. __syncthreads();
  175. }
  176. if (warp_group_idx == 0) { // Producer
  177. cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
  178. int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
  179. if (warp_idx_in_warpgroup == 0) { // Load Q, K, V
  180. PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipelineK>();
  181. int work_idx = 0;
  182. TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.pipelines.smem_scheduler));
  183. for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler);
  184. work_tile_info.is_valid(params.scheduler);
  185. work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info)) {
  186. auto block_coord = work_tile_info.get_block_coord(params.scheduler);
  187. auto [m_block, bidh, bidb] = block_coord;
  188. // With Varlen it's possible to have n_block_max == 0. Loading K can cause illegal memory access.
  189. if constexpr (Is_causal || Is_local || Varlen) {
  190. int n_block_max = collective_mainloop.get_n_block_max(params.mainloop, m_block, bidb);
  191. int n_block_min = collective_mainloop.get_n_block_min(params.mainloop, m_block, bidb);
  192. if (n_block_max <= n_block_min) {
  193. scheduler.prefetch_next_work(params.scheduler, work_tile_info);
  194. continue;
  195. }
  196. }
  197. auto scheduler_prefetch = [&scheduler, &params, &work_tile_info]() {
  198. scheduler.prefetch_next_work(params.scheduler, work_tile_info);
  199. };
  200. collective_mainloop.load(params.mainloop, pipeline_k, pipeline_v, smem_pipe_write,
  201. shared_storage, scheduler_prefetch, block_coord, work_idx);
  202. ++work_idx;
  203. }
  204. collective_mainloop.load_tail(pipeline_k, pipeline_v, smem_pipe_write);
  205. }
  206. } else { // Consumer
  207. cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
  208. TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.pipelines.smem_scheduler));
  209. // Initialize matmul objects.
  210. TiledMma1 tiled_mma1;
  211. PipelineState smem_pipe_read;
  212. // We don't need separate variables smem_pipe_release_k and smem_pipe_release_v
  213. // (like in Cutlass's gemm) because the read and release pipeline states are always the same.
  214. collective_mainloop.mma_init();
  215. scheduler.init_consumer();
  216. int work_idx = 0;
  217. CUTLASS_PRAGMA_NO_UNROLL
  218. for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
  219. work_tile_info.is_valid(params.scheduler);
  220. work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
  221. // Attention output (GEMM-II) accumulator.
  222. Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{}));
  223. float softmax_scale_log2 = params.mainloop.softmax_scale_log2;
  224. if constexpr (Is_FP8) {
  225. float const q_scale = params.mainloop.ptr_q_scale == nullptr ? 1.0f : *params.mainloop.ptr_q_scale;
  226. float const k_scale = params.mainloop.ptr_k_scale == nullptr ? 1.0f : *params.mainloop.ptr_k_scale;
  227. softmax_scale_log2 *= q_scale * k_scale;
  228. }
  229. flash::Softmax<2 * (2 * kBlockM / NumMmaThreads), /*Max_offset=*/!Is_FP8 ? 0 : 8> softmax(softmax_scale_log2);
  230. auto block_coord = work_tile_info.get_block_coord(params.scheduler);
  231. auto [m_block, bidh, bidb] = block_coord;
  232. if constexpr (Is_causal || Is_local || Varlen) {
  233. int n_block_max = collective_mainloop.get_n_block_max(params.mainloop, m_block, bidb);
  234. int n_block_min = collective_mainloop.get_n_block_min(params.mainloop, m_block, bidb);
  235. if (n_block_max <= n_block_min) { // We exit early and write 0 to gO and -inf to gLSE.
  236. collective_epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord);
  237. continue;
  238. }
  239. }
  240. collective_mainloop.mma(params.mainloop, pipeline_k, pipeline_v, smem_pipe_read,
  241. tOrO, softmax, threadIdx.x - NumCopyThreads, work_idx, block_coord, shared_storage);
  242. // tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads + (work_idx >> 30), work_idx, shared_storage);
  243. collective_epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma1,
  244. threadIdx.x - NumCopyThreads, block_coord);
  245. ++work_idx;
  246. }
  247. collective_epilogue.store_tail();
  248. }
  249. }
  250. };
  251. template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_,
  252. class Base=FlashAttnFwd<CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_>>
  253. class FlashAttnFwdFP8TransposeV : public Base {
  254. public:
  255. using CollectiveMainloop = CollectiveMainloop_;
  256. using CollectiveEpilogue = CollectiveEpilogue_;
  257. using TileScheduler = TileScheduler_;
  258. // Type Aliases
  259. static constexpr bool Is_causal = CollectiveMainloop::Is_causal;
  260. static constexpr bool Is_local = CollectiveMainloop_::Is_local;
  261. using TileShape_MNK = typename Base::TileShape_MNK;
  262. using ClusterShape = typename Base::ClusterShape;
  263. using TiledMma1 = typename Base::TiledMma1;
  264. using Params = typename Base::Params;
  265. static constexpr bool Varlen = CollectiveMainloop::Varlen;
  266. static constexpr uint32_t NumLoadWarpGroups = Base::NumLoadWarpGroups;
  267. static constexpr uint32_t NumMmaWarpGroups = Base::NumMmaWarpGroups;
  268. /// Register requirement for Load and Math WGs
  269. static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 2 ? 24 : 32;
  270. static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 240 : 160;
  271. // If you want to print from the producer warp, you'd need to increase the number of registers
  272. // Otherwise you'll get CUDA error.
  273. // static constexpr uint32_t LoadRegisterRequirement = 56;
  274. // static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 224 : 152;
  275. // Kernel level shared memory storage
  276. struct SharedStorage {
  277. struct TensorStorage : cute::aligned_struct<128> {
  278. union {
  279. typename CollectiveMainloop::TensorStorage mainloop;
  280. // We want smem_o to line up with the start of smem_v
  281. typename CollectiveEpilogue::TensorStorage epilogue;
  282. static_assert(cute::cosize_v<typename CollectiveEpilogue::SmemLayoutO> <= cute::cosize_v<typename CollectiveMainloop::SmemLayoutVt>);
  283. };
  284. } tensors;
  285. struct PipelineStorage : cute::aligned_struct<16> {
  286. alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_Q;
  287. alignas(16) cutlass::arch::ClusterBarrier barrier_O;
  288. alignas(16) typename CollectiveMainloop::MainloopPipelineK::SharedStorage pipeline_k;
  289. alignas(16) typename CollectiveMainloop::MainloopPipelineV::SharedStorage pipeline_v;
  290. alignas(16) typename CollectiveMainloop::MainloopPipelineVt::SharedStorage pipeline_vt;
  291. alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
  292. } pipelines;
  293. };
  294. static constexpr int SharedStorageSize = sizeof(SharedStorage);
  295. CUTLASS_DEVICE
  296. void
  297. operator()(Params const& params, char* smem_buf) {
  298. static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
  299. static constexpr int NumCopyThreads = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup;
  300. static constexpr int kBlockM = get<0>(TileShape_MNK{});
  301. using MainloopPipelineK = typename CollectiveMainloop::MainloopPipelineK;
  302. using MainloopPipelineV = typename CollectiveMainloop::MainloopPipelineV;
  303. using MainloopPipelineVt = typename CollectiveMainloop::MainloopPipelineVt;
  304. using PipelineStateK = typename MainloopPipelineK::PipelineState;
  305. using PipelineStateV = typename MainloopPipelineV::PipelineState;
  306. using PipelineParamsK = typename MainloopPipelineK::Params;
  307. using PipelineParamsV = typename MainloopPipelineV::Params;
  308. using PipelineParamsVt = typename MainloopPipelineVt::Params;
  309. SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
  310. int const lane_predicate = cute::elect_one_sync();
  311. int const warp_idx = cutlass::canonical_warp_idx_sync();
  312. // Issue Tma Descriptor Prefetch from a single thread
  313. if (warp_idx == 0 && lane_predicate) {
  314. CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
  315. CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
  316. }
  317. // Obtain warp index
  318. int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
  319. int warp_group_idx = cutlass::canonical_warp_group_idx();
  320. int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
  321. PipelineParamsK pipeline_params_k;
  322. pipeline_params_k.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
  323. pipeline_params_k.role = warp_group_idx == 0
  324. ? MainloopPipelineK::ThreadCategory::Producer
  325. : MainloopPipelineK::ThreadCategory::Consumer;
  326. pipeline_params_k.is_leader = warp_group_thread_idx == 0;
  327. pipeline_params_k.num_consumers = NumMmaThreads;
  328. // Technically for pipeline_params_vt, warp0 of WG0 is the producer and all of WG0 are consumers.
  329. // However, the thread role isn't used in the pipeline implementation.
  330. PipelineParamsV pipeline_params_v;
  331. pipeline_params_v.role = warp_group_idx == 0
  332. ? MainloopPipelineV::ThreadCategory::Producer
  333. : MainloopPipelineV::ThreadCategory::Consumer;
  334. pipeline_params_v.producer_arv_count = NumCopyThreads;
  335. pipeline_params_v.consumer_arv_count = NumMmaThreads;
  336. if (warp_idx == 0 && lane_predicate) {
  337. shared_storage.pipelines.barrier_Q.init(1 /*numThreads*/);
  338. shared_storage.pipelines.barrier_O.init(size(ClusterShape{}) /*numThreads*/);
  339. }
  340. // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init();
  341. MainloopPipelineK pipeline_k(shared_storage.pipelines.pipeline_k, pipeline_params_k, ClusterShape{});
  342. MainloopPipelineV pipeline_v(shared_storage.pipelines.pipeline_v, pipeline_params_v);
  343. static_assert(is_same_v<MainloopPipelineK, MainloopPipelineVt>);
  344. pipeline_params_k.num_consumers = NumCopyThreads; // TMA_V is only consumed by the producer WG
  345. MainloopPipelineVt pipeline_vt(shared_storage.pipelines.pipeline_vt, pipeline_params_k, ClusterShape{});
  346. CollectiveMainloop collective_mainloop;
  347. CollectiveEpilogue collective_epilogue;
  348. // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster
  349. if constexpr (size(ClusterShape{}) > 1) {
  350. cute::cluster_arrive_relaxed();
  351. cute::cluster_wait();
  352. } else {
  353. __syncthreads();
  354. }
  355. if (warp_group_idx == 0) { // Producer
  356. cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
  357. PipelineStateK smem_pipe_write = cutlass::make_producer_start_state<MainloopPipelineK>();
  358. int work_idx = 0;
  359. TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.pipelines.smem_scheduler));
  360. int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
  361. if (warp_idx_in_warpgroup != 0) { scheduler.init_consumer(); }
  362. for (auto work_tile_info = warp_idx_in_warpgroup == 0 ? scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler) : scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
  363. work_tile_info.is_valid(params.scheduler);
  364. work_tile_info = warp_idx_in_warpgroup == 0 ? scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info) : scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
  365. auto block_coord = work_tile_info.get_block_coord(params.scheduler);
  366. auto [m_block, bidh, bidb] = block_coord;
  367. // With Varlen it's possible to have n_block_max == 0. Loading K can cause illegal memory access.
  368. if constexpr (Is_causal || Is_local || Varlen) {
  369. int n_block_max = collective_mainloop.get_n_block_max(params.mainloop, m_block, bidb);
  370. int n_block_min = collective_mainloop.get_n_block_min(params.mainloop, m_block, bidb);
  371. if (n_block_max <= n_block_min) {
  372. scheduler.prefetch_next_work(params.scheduler, work_tile_info);
  373. continue;
  374. }
  375. }
  376. auto scheduler_prefetch = [&scheduler, &params, &work_tile_info]() {
  377. scheduler.prefetch_next_work(params.scheduler, work_tile_info);
  378. };
  379. collective_mainloop.load_fp8_transpose_V(params.mainloop, pipeline_k, pipeline_v, pipeline_vt,
  380. smem_pipe_write, shared_storage, scheduler_prefetch, block_coord, work_idx);
  381. ++work_idx;
  382. }
  383. collective_mainloop.load_tail(pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write);
  384. } else { // Consumer
  385. cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
  386. TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.pipelines.smem_scheduler));
  387. // Initialize matmul objects.
  388. TiledMma1 tiled_mma1;
  389. PipelineStateK smem_pipe_read_k;
  390. // We don't need separate variables smem_pipe_release_k and smem_pipe_release_v
  391. // (like in Cutlass's gemm) because the read and release pipeline states are always the same.
  392. collective_mainloop.mma_init();
  393. scheduler.init_consumer();
  394. int work_idx = 0;
  395. CUTLASS_PRAGMA_NO_UNROLL
  396. for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
  397. work_tile_info.is_valid(params.scheduler);
  398. work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
  399. // Attention output (GEMM-II) accumulator.
  400. Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{}));
  401. float const q_scale = params.mainloop.ptr_q_scale == nullptr ? 1.0f : *params.mainloop.ptr_q_scale;
  402. float const k_scale = params.mainloop.ptr_k_scale == nullptr ? 1.0f : *params.mainloop.ptr_k_scale;
  403. flash::Softmax<2 * (2 * kBlockM / NumMmaThreads), /*Max_offset=*/8> softmax(params.mainloop.softmax_scale_log2 * q_scale * k_scale);
  404. auto block_coord = work_tile_info.get_block_coord(params.scheduler);
  405. auto [m_block, bidh, bidb] = block_coord;
  406. if constexpr (Is_causal || Is_local || Varlen) {
  407. int n_block_max = collective_mainloop.get_n_block_max(params.mainloop, m_block, bidb);
  408. int n_block_min = collective_mainloop.get_n_block_min(params.mainloop, m_block, bidb);
  409. if (n_block_max <= n_block_min) { // We exit early and write 0 to gO and -inf to gLSE.
  410. collective_epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord);
  411. continue;
  412. }
  413. }
  414. collective_mainloop.mma(params.mainloop, pipeline_k, pipeline_v, smem_pipe_read_k,
  415. tOrO, softmax, threadIdx.x - NumCopyThreads, work_idx, block_coord, shared_storage);
  416. collective_epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma1,
  417. threadIdx.x - NumCopyThreads, block_coord);
  418. ++work_idx;
  419. }
  420. collective_epilogue.store_tail();
  421. }
  422. }
  423. };
  424. } // namespace flash