// flash_bwd_kernel_sm90.h

/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

#include <cutlass/cutlass.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/kernel_hardware_info.h>
#include "cutlass/pipeline/pipeline.hpp"

#include "utils.h"

namespace flash {

using namespace cute;

template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
class FlashAttnBwdSm90 {

public:

    // Type Aliases
    static constexpr bool Is_causal = CollectiveMainloop_::Is_causal;
    static constexpr bool Is_local = CollectiveMainloop_::Is_local;
    static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen);
    static constexpr bool Varlen = CollectiveMainloop_::Varlen;

    // Mainloop derived types
    using CollectiveMainloop = CollectiveMainloop_;
    using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;
    using TiledMmaSdP = typename CollectiveMainloop::TiledMmaSdP;
    using TiledMmadKV = typename CollectiveMainloop::TiledMmadKV;
    using ArchTag = typename CollectiveMainloop::ArchTag;
    using ClusterShape = typename CollectiveMainloop::ClusterShape;
    using MainloopArguments = typename CollectiveMainloop::Arguments;
    using MainloopParams = typename CollectiveMainloop::Params;
    static constexpr bool dKV_swapAB = CollectiveMainloop::dKV_swapAB;

    // Epilogue derived types
    using CollectiveEpilogue = CollectiveEpilogue_;
    using EpilogueArguments = typename CollectiveEpilogue::Arguments;
    using EpilogueParams = typename CollectiveEpilogue::Params;

    static_assert(ArchTag::kMinComputeCapability >= 90);

    using TileScheduler = TileScheduler_;
    using TileSchedulerArguments = typename flash::TileSchedulerArguments;
    using TileSchedulerParams = typename TileScheduler::Params;

    static constexpr uint32_t NumLoadWarpGroups = 1;
    static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaSdP{})) / cutlass::NumThreadsPerWarpGroup;
    static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaSdP{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup);
    static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
    static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3);
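    // Warp-specialized layout: one producer (load) warp group issues the TMA loads and the dQ
    // writeback, while the remaining 2 or 3 consumer (MMA) warp groups run the matmuls.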
    /// Register requirement for Load and Math WGs
    static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 2 ? 24 : 32;
    static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 240 : 160;
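    // setmaxnreg works in multiples of 8 registers, and these budgets keep the block within the
    // 64K registers per SM: 128 * 24 + 2 * 128 * 240 = 64512 for 2 MMA warp groups, and
    // 128 * 32 + 3 * 128 * 160 = 65536 for 3.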
    // If you want to print from the producer warp, you'd need to increase the number of registers
    // Otherwise you'll get a CUDA error.
    // static constexpr uint32_t LoadRegisterRequirement = 40;
    // static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 232 : 152;

    // Kernel level shared memory storage
    struct SharedStorage {
        struct TensorStorage : cute::aligned_struct<128> {
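            // Mainloop and epilogue smem allocations are overlapped via a union; the epilogue
            // reuses this space once the mainloop is done with a tile.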
            union {
                typename CollectiveMainloop::TensorStorage mainloop;
                typename CollectiveEpilogue::TensorStorage epilogue;
            };
        } tensors;

        struct PipelineStorage : cute::aligned_struct<16> {
            alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_KV;
            alignas(16) typename CollectiveMainloop::MainloopPipeline::SharedStorage pipeline_q;
            alignas(16) typename CollectiveMainloop::MainloopPipeline_dO::SharedStorage pipeline_do;
            alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
        } pipelines;
    };

    static constexpr int SharedStorageSize = sizeof(SharedStorage);

    // Device side arguments
    struct Arguments {
        MainloopArguments mainloop{};
        EpilogueArguments epilogue{};
        cutlass::KernelHardwareInfo hw_info{};
        TileSchedulerArguments scheduler{};
    };

    // Kernel entry point API
    struct Params {
        MainloopParams mainloop{};
        EpilogueParams epilogue{};
        cutlass::KernelHardwareInfo hw_info{};
        TileSchedulerParams scheduler{};
    };

    //
    // Methods
    //

    // Convert to underlying arguments. In this case, a simple copy for the aliased type.
    static
    Params
    to_underlying_arguments(Arguments const& args) {
        CUTLASS_TRACE_HOST("to_underlying_arguments():");

        // Get SM count if needed, otherwise use user supplied SM count
        int sm_count = args.hw_info.sm_count;
        if (sm_count <= 0) {
            CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
                " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
            sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
        }

        CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);

        cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
        return {
            CollectiveMainloop::to_underlying_arguments(args.mainloop),
            CollectiveEpilogue::to_underlying_arguments(args.epilogue),
            hw_info,
            TileScheduler::to_underlying_arguments(args.scheduler)
        };
    }

    // Computes the kernel launch grid shape based on runtime parameters
    static dim3
    get_grid_shape(Params const& params) {
        return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);
    }

    static dim3
    get_block_shape() {
        return dim3(MaxThreadsPerBlock, 1, 1);
    }
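
    // Rough host-side usage sketch (illustrative only; the real launch path lives in the
    // flash-attention host code, and the template arguments are abbreviated here):
    //   Arguments args{mainloop_args, epilogue_args, hw_info, scheduler_args};
    //   Params params = FlashAttnBwdSm90<...>::to_underlying_arguments(args);
    //   dim3 grid  = FlashAttnBwdSm90<...>::get_grid_shape(params);
    //   dim3 block = FlashAttnBwdSm90<...>::get_block_shape();
    //   // Then launch, e.g. via cutlass::device_kernel<FlashAttnBwdSm90<...>>, with
    //   // SharedStorageSize bytes of dynamic shared memory and the kernel's ClusterShape.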
    CUTLASS_DEVICE
    void
    operator()(Params const& params, char* smem_buf) {

        static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
        static constexpr int NumCopyThreads = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup;
        static constexpr int kBlockM = get<0>(TileShape_MNK{});
        static constexpr int kBlockN = get<1>(TileShape_MNK{});

        using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
        using PipelineParams = typename MainloopPipeline::Params;
        using PipelineState = typename MainloopPipeline::PipelineState;
        using MainloopPipeline_dO = typename CollectiveMainloop::MainloopPipeline_dO;
        using PipelineParams_dO = typename MainloopPipeline_dO::Params;
        using PipelineState_dO = typename MainloopPipeline_dO::PipelineState;
        static constexpr bool Q_dO_same_stages = std::is_same_v<MainloopPipeline, MainloopPipeline_dO>;

        SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);

        int const lane_predicate = cute::elect_one_sync();
        int const warp_idx = cutlass::canonical_warp_idx_sync();

        // Issue TMA descriptor prefetch from a single thread
        if (warp_idx == 0 && lane_predicate) {
            CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
            CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
        }

        // Obtain the thread index within the warp group
        int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;

        PipelineParams pipeline_params;
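        // One Q pipeline stage carries both the Q tile and its LSE slice, so the expected TMA
        // transaction size is the sum of the two byte counts.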
        pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesQ + CollectiveMainloop::TmaTransactionBytesLSE;
        int warp_group_idx = cutlass::canonical_warp_group_idx();
        pipeline_params.role = warp_group_idx == 0
            ? MainloopPipeline::ThreadCategory::Producer
            : MainloopPipeline::ThreadCategory::Consumer;
        pipeline_params.is_leader = warp_group_thread_idx == 0;
        pipeline_params.num_consumers = NumMmaThreads;

        if (warp_idx == 0 && lane_predicate) {
            shared_storage.pipelines.barrier_KV.init(1 /*numThreads*/);
        }

        // We're counting on pipeline_q to call cutlass::arch::fence_barrier_init();
        MainloopPipeline pipeline_q(shared_storage.pipelines.pipeline_q, pipeline_params, ClusterShape{});
        auto role_dO = warp_group_idx == 0
            ? MainloopPipeline_dO::ThreadCategory::Producer
            : MainloopPipeline_dO::ThreadCategory::Consumer;
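        // When the Q and dO pipelines are the same type (same number of stages), reuse the Q
        // pipeline parameters for dO; otherwise build dO-specific ones with the same settings.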
        PipelineParams_dO pipeline_params_dO {pipeline_params.transaction_bytes, role_dO, pipeline_params.is_leader, pipeline_params.num_consumers};
        MainloopPipeline_dO pipeline_do(shared_storage.pipelines.pipeline_do, cute::conditional_return<Q_dO_same_stages>(pipeline_params, pipeline_params_dO), ClusterShape{});

        CollectiveMainloop collective_mainloop;
        CollectiveEpilogue collective_epilogue;

        // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster
        if constexpr (size(ClusterShape{}) > 1) {
            cute::cluster_arrive_relaxed();
            cute::cluster_wait();
        } else {
            __syncthreads();
        }

        if (warp_group_idx == 0) {  // Producer
            cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
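            // Broadcast the warp index from lane 0 so the value is warp-uniform for the branches below.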
            int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
            if (warp_idx_in_warpgroup == 0) {  // Load K, V, and do TMA on Q and dO
                PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipeline>();
                PipelineState_dO smem_pipe_write_do = cutlass::make_producer_start_state<MainloopPipeline_dO>();

                TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.pipelines.smem_scheduler));
                for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler);
                        work_tile_info.is_valid(params.scheduler);
                        work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info)) {
                    auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
                    auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
                    cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};
                    auto scheduler_prefetch = [&scheduler, &params, &work_tile_info]() {
                        scheduler.prefetch_next_work(params.scheduler, work_tile_info);
                    };
                    collective_mainloop.load(params.mainloop, pipeline_q, pipeline_do, smem_pipe_write,
                                             smem_pipe_write_do, shared_storage, scheduler_prefetch, block_coord);
                }
                collective_mainloop.load_tail(pipeline_q, pipeline_do, smem_pipe_write, smem_pipe_write_do);
            } else if (warp_idx_in_warpgroup == 1) {
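                // Warp 1 of the producer warp group handles the dQ writeback: store_dq writes the
                // accumulated dQ for each scheduled tile out to global memory.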
                TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.pipelines.smem_scheduler));
                for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
                        work_tile_info.is_valid(params.scheduler);
                        work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
                    auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
                    auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
                    cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};
                    collective_mainloop.store_dq(params.mainloop, shared_storage, block_coord);
                }
            }
        } else {  // Consumer
            cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();

            TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.pipelines.smem_scheduler));
            // Initialize matmul objects.
            TiledMmadKV tiled_mma_dKV;

            PipelineState smem_pipe_read;
            PipelineState_dO smem_pipe_read_do;

            collective_mainloop.mma_init();
            scheduler.init_consumer();

            int work_idx = 0;
            CUTLASS_PRAGMA_NO_UNROLL
            for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
                    work_tile_info.is_valid(params.scheduler);
                    work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
                auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
                auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
                cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};

                // dK and dV output accumulator.
                Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB ? 2 : 1>(TileShape_MNK{}));
                Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB ? 2 : 1>(TileShape_MNK{}));
                bool tile_valid = collective_mainloop.mma(
                    params.mainloop, pipeline_q, pipeline_do, smem_pipe_read, smem_pipe_read_do,
                    tdKrdK, tdVrdV, threadIdx.x - NumCopyThreads, work_idx, block_coord, shared_storage);
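                // mma() returns false when the tile has no valid work (e.g. it is entirely masked
                // out); in that case the epilogue just zero-fills dK and dV for this tile.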
                if (tile_valid) {
                    collective_epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV,
                                              threadIdx.x - NumCopyThreads, block_coord);
                } else {
                    collective_epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord);
                }
            }
            collective_epilogue.store_tail();
        }
    }

};

} // namespace flash