flash_bwd_kernel.h

/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

#include <cutlass/cutlass.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/kernel_hardware_info.h>
#include "cutlass/pipeline/pipeline.hpp"

#include "utils.h"
#include "tile_scheduler_bwd.hpp"
#include "mainloop_bwd_sm90_tma_gmma_ws.hpp"
#include "epilogue_bwd_sm90_tma.hpp"
namespace flash {

using namespace cute;

template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
class FlashAttnBwd {

public:

    // Type Aliases
    static constexpr bool Is_causal = CollectiveMainloop_::Is_causal;
    static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen);
    static constexpr bool Varlen = CollectiveMainloop_::Varlen;

    // Mainloop derived types
    using CollectiveMainloop = CollectiveMainloop_;
    using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;
    using TiledMmaSdP = typename CollectiveMainloop::TiledMmaSdP;
    using TiledMmadKV = typename CollectiveMainloop::TiledMmadKV;
    using ArchTag = typename CollectiveMainloop::ArchTag;
    using ClusterShape = typename CollectiveMainloop::ClusterShape;
    using MainloopArguments = typename CollectiveMainloop::Arguments;
    using MainloopParams = typename CollectiveMainloop::Params;
    static constexpr bool dKV_swapAB = CollectiveMainloop::dKV_swapAB;

    // Epilogue derived types
    using CollectiveEpilogue = CollectiveEpilogue_;
    using EpilogueArguments = typename CollectiveEpilogue::Arguments;
    using EpilogueParams = typename CollectiveEpilogue::Params;

    static_assert(ArchTag::kMinComputeCapability >= 90);

    using TileScheduler = TileScheduler_;
    using TileSchedulerArguments = typename TileScheduler::Arguments;
    using TileSchedulerParams = typename TileScheduler::Params;

    static constexpr uint32_t NumLoadWarpGroups = 1;
    static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaSdP{})) / cutlass::NumThreadsPerWarpGroup;
    static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaSdP{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup);
    static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
    static_assert(NumMmaWarpGroups == 2);
    /// Register requirement for Load and Math WGs
    static constexpr uint32_t LoadRegisterRequirement = 24;
    static constexpr uint32_t MmaRegisterRequirement = 240;
    // If you want to print from the producer warp, you'd need to increase the number of registers,
    // otherwise you'll get a CUDA error.
    // static constexpr uint32_t LoadRegisterRequirement = 56;
    // static constexpr uint32_t MmaRegisterRequirement = 224;
    // Kernel level shared memory storage
    struct SharedStorage {
        struct {
            union {
                typename CollectiveMainloop::TensorStorage mainloop;
                typename CollectiveEpilogue::TensorStorage epilogue;
            };
        };
        struct {
            alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_KV;
            alignas(16) cutlass::arch::ClusterBarrier barrier_dKV;
            alignas(16) typename CollectiveMainloop::MainloopPipeline::SharedStorage pipeline_q;
            alignas(16) typename CollectiveMainloop::MainloopPipeline::SharedStorage pipeline_do;
            alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
        };
    };

    static constexpr int SharedStorageSize = sizeof(SharedStorage);
    // Device side arguments
    struct Arguments {
        MainloopArguments mainloop{};
        EpilogueArguments epilogue{};
        cutlass::KernelHardwareInfo hw_info{};
        TileSchedulerArguments scheduler{};
    };

    // Kernel entry point API
    struct Params {
        MainloopParams mainloop{};
        EpilogueParams epilogue{};
        cutlass::KernelHardwareInfo hw_info{};
        TileSchedulerParams scheduler{};
    };

    //
    // Methods
    //

    // Convert to underlying arguments. In this case, a simple copy for the aliased type.
    static
    Params
    to_underlying_arguments(Arguments const& args) {
        CUTLASS_TRACE_HOST("to_underlying_arguments():");

        // Get SM count if needed, otherwise use user supplied SM count
        int sm_count = args.hw_info.sm_count;
        if (sm_count <= 0) {
            CUTLASS_TRACE_HOST("  WARNING: Arguments do not include a valid SM count.\n"
                "  For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
            sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
        }

        CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);

        cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
        return {
            CollectiveMainloop::to_underlying_arguments(args.mainloop),
            CollectiveEpilogue::to_underlying_arguments(args.epilogue),
            hw_info,
            TileScheduler::to_underlying_arguments(args.scheduler)
        };
    }

    // Computes the kernel launch grid shape based on runtime parameters
    static dim3
    get_grid_shape(Params const& params) {
        return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);
    }

    static dim3
    get_block_shape() {
        return dim3(MaxThreadsPerBlock, 1, 1);
    }
    CUTLASS_DEVICE
    void
    operator()(Params const& params, char* smem_buf) {

        static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
        static constexpr int NumCopyThreads = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup;
        static constexpr int kBlockM = get<0>(TileShape_MNK{});

        using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
        using PipelineParams = typename MainloopPipeline::Params;
        using PipelineState = typename MainloopPipeline::PipelineState;

        SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);

        int const lane_predicate = cute::elect_one_sync();
        int const warp_idx = cutlass::canonical_warp_idx_sync();

        // Issue Tma Descriptor Prefetch from a single thread
        if (warp_idx == 0 && lane_predicate) {
            CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
            CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
        }

        // Obtain warp index
        int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;

        PipelineParams pipeline_params;
        pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesQ + CollectiveMainloop::TmaTransactionBytesLSE;
        int warp_group_idx = cutlass::canonical_warp_group_idx();
        pipeline_params.role = warp_group_idx == 0
            ? MainloopPipeline::ThreadCategory::Producer
            : MainloopPipeline::ThreadCategory::Consumer;
        pipeline_params.is_leader = warp_group_thread_idx == 0;
        pipeline_params.num_consumers = NumMmaThreads;

        if (warp_idx == 0 && lane_predicate) {
            shared_storage.barrier_KV.init(1 /*numThreads*/);
            // shared_storage.barrier_dKV.init(size(ClusterShape{}) /*numThreads*/);
        }

        // We're counting on pipeline_q to call cutlass::arch::fence_barrier_init();
        MainloopPipeline pipeline_q(shared_storage.pipeline_q, pipeline_params, ClusterShape{});
        MainloopPipeline pipeline_do(shared_storage.pipeline_do, pipeline_params, ClusterShape{});

        CollectiveMainloop collective_mainloop;
        CollectiveEpilogue collective_epilogue;
        // We need this to guarantee that the Pipeline init is visible to all producer and consumer blocks in the Cluster
        if constexpr (size(ClusterShape{}) > 1) {
            cute::cluster_arrive_relaxed();
            cute::cluster_wait();
        } else {
            __syncthreads();
        }
        if (warp_group_idx == 0) {  // Producer
            cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();

            int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
            if (warp_idx_in_warpgroup == 0) {  // Load K, V, and do TMA on Q and dO
                PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipeline>();

                int work_idx = 0;

                TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.smem_scheduler));
                for (auto work_tile_info = scheduler.template get_initial_work</*IsProducer=*/true>(params.scheduler);
                     work_tile_info.is_valid(params.scheduler);
                     work_tile_info = scheduler.template get_next_work</*IsProducer=*/true>(params.scheduler, work_tile_info)) {
                    auto block_coord = work_tile_info.get_block_coord(params.scheduler);
                    auto [n_block, bidh, bidb] = block_coord;
                    if constexpr (Varlen) {
                        if (n_block * kBlockM >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) {
                            scheduler.prefetch_next_work(params.scheduler, work_tile_info);
                            continue;
                        }
                    }
                    if constexpr (Is_causal) {
                        int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb);
                        int const m_block_max = cute::ceil_div(collective_mainloop.get_seqlen_q(params.mainloop, bidb), kBlockM);
                        if (m_block_min >= m_block_max) {
                            scheduler.prefetch_next_work(params.scheduler, work_tile_info);
                            continue;
                        }
                    }
                    auto scheduler_prefetch = [&scheduler, &params, &work_tile_info]() {
                        scheduler.prefetch_next_work(params.scheduler, work_tile_info);
                    };
                    collective_mainloop.load(params.mainloop, pipeline_q, pipeline_do, smem_pipe_write,
                                             shared_storage, scheduler_prefetch, block_coord, work_idx);
                    ++work_idx;
                }
                collective_mainloop.load_tail(pipeline_q, pipeline_do, smem_pipe_write);
            } else if (warp_idx_in_warpgroup == 1) {
                TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.smem_scheduler));
                for (auto work_tile_info = scheduler.template get_initial_work</*IsProducer=*/false>(params.scheduler);
                     work_tile_info.is_valid(params.scheduler);
                     work_tile_info = scheduler.template get_next_work</*IsProducer=*/false>(params.scheduler, work_tile_info)) {
                    auto block_coord = work_tile_info.get_block_coord(params.scheduler);
                    auto [n_block, bidh, bidb] = block_coord;
                    if constexpr (Varlen) {
                        if (n_block * kBlockM >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) { continue; }
                    }
                    if constexpr (Is_causal) {
                        int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb);
                        int const m_block_max = cute::ceil_div(collective_mainloop.get_seqlen_q(params.mainloop, bidb), kBlockM);
                        if (m_block_min >= m_block_max) { continue; }
                    }
                    collective_mainloop.store_dq(params.mainloop, shared_storage, block_coord);
                }
            }
        } else {  // Consumer
            cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();

            TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.smem_scheduler));
            // Initialize matmul objects.
            TiledMmadKV tiled_mma_dKV;

            PipelineState smem_pipe_read;

            collective_mainloop.mma_init();
            scheduler.init_consumer();

            int work_idx = 0;
            CUTLASS_PRAGMA_NO_UNROLL
            for (auto work_tile_info = scheduler.template get_initial_work</*IsProducer=*/false>(params.scheduler);
                 work_tile_info.is_valid(params.scheduler);
                 work_tile_info = scheduler.template get_next_work</*IsProducer=*/false>(params.scheduler, work_tile_info)) {
                auto block_coord = work_tile_info.get_block_coord(params.scheduler);
                auto [n_block, bidh, bidb] = block_coord;
                if constexpr (Varlen) {
                    if (n_block * kBlockM >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) { continue; }
                }
                if constexpr (Is_causal) {
                    int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb);
                    int const m_block_max = cute::ceil_div(collective_mainloop.get_seqlen_q(params.mainloop, bidb), kBlockM);
                    if (m_block_min >= m_block_max) {  // We exit early and write 0 to dK and dV
                        collective_epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord);
                        continue;
                    }
                }

                // dK and dV output accumulator.
                Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB ? 2 : 1>(TileShape_MNK{}));
                Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB ? 2 : 1>(TileShape_MNK{}));
                collective_mainloop.mma(params.mainloop, pipeline_q, pipeline_do, smem_pipe_read,
                                        tdKrdK, tdVrdV, threadIdx.x - NumCopyThreads, work_idx, block_coord, shared_storage);
                collective_epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV,
                                          threadIdx.x - NumCopyThreads, block_coord);
                ++work_idx;
            }
            collective_epilogue.store_tail();
        }

    }

};

} // namespace flash
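
// ----------------------------------------------------------------------------
// Host-side launch sketch (illustrative). A minimal example of how a
// FlashAttnBwd instantiation could be configured and launched, assuming the
// CollectiveMainloop / CollectiveEpilogue / TileScheduler types are chosen
// elsewhere. `launch_flash_bwd` is a hypothetical helper name used only for
// illustration; the library's actual launch path may differ (e.g. cluster
// launch and error checking).
// ----------------------------------------------------------------------------

#include <cutlass/device_kernel.h>  // cutlass::device_kernel

template <class Kernel>
void launch_flash_bwd(typename Kernel::Arguments const& args, cudaStream_t stream) {
    // Convert user-facing Arguments into the kernel's Params on the host.
    typename Kernel::Params params = Kernel::to_underlying_arguments(args);

    // The tile scheduler picks a persistent grid sized to the SM count;
    // the block size is fixed by the load + MMA warp-group configuration.
    dim3 grid = Kernel::get_grid_shape(params);
    dim3 block = Kernel::get_block_shape();
    int smem_size = Kernel::SharedStorageSize;

    // cutlass::device_kernel<Kernel> allocates dynamic shared memory and
    // forwards (params, smem) to Kernel::operator().
    auto kernel = cutlass::device_kernel<Kernel>;

    // Dynamic shared memory above 48 KB requires an explicit opt-in.
    if (smem_size >= 48 * 1024) {
        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    }
    kernel<<<grid, block, smem_size, stream>>>(params);
}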