// flash_bwd_launch_template.h
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include "cute/tensor.hpp"
  6. #include "cutlass/device_kernel.h" // For device_kernel
  7. #include "cutlass/kernel_launch.h" // For kernel_launch
  8. #include "cutlass/cluster_launch.hpp" // For ClusterLauncher
  9. #include "static_switch.h"
  10. #include "flash.h"
  11. #include "flash_bwd_preprocess_kernel.h"
  12. #include "flash_bwd_postprocess_kernel.h"
  13. #include "tile_scheduler.hpp"
  14. #include "mainloop_bwd_sm90_tma_gmma_ws.hpp"
  15. #include "mainloop_bwd_sm80.hpp"
  16. #include "epilogue_bwd.hpp"
  17. #include "flash_bwd_kernel_sm90.h"
  18. #include "flash_bwd_kernel_sm80.h"
  19. using namespace cute;
  20. template <int Arch, int kHeadDim, int kBlockM, int kBlockN, typename Element,
  21. bool Is_causal, bool Is_local, bool Has_softcap, bool Varlen, bool Deterministic, bool GQA,
  22. int Stages_dO=2, int Stages_dS_or_QSm80=2,
  23. bool SdP_swapAB=true, bool dKV_swapAB=false, bool dQ_swapAB=false,
  24. int NumMmaWarpGroups=2, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
  25. bool V_in_regs=false>
  26. void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
  27. static_assert(!(Is_causal && Is_local), "Is_causal and Is_local cannot be true at the same time.");
  28. using ElementAccum = float;
  29. using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;
  30. int const total_q_padded_rounded = cute::round_up(params.total_q + params.b * kBlockM, kBlockM);
  31. int const total_k_padded_rounded = cute::round_up(params.total_k + params.b * kBlockN, kBlockN);
  32. bool const is_varlen_q = params.cu_seqlens_q;
  33. bool const is_varlen_k = params.cu_seqlens_k;
  34. int seqlen_q = !is_varlen_q ? params.seqlen_q : params.total_q;
  35. int seqlen_k = !is_varlen_k ? params.seqlen_k : params.total_k;
  36. int seqlen_q_rounded = !is_varlen_q ? params.seqlen_q_rounded : total_q_padded_rounded;
  37. int seqlen_k_rounded = !is_varlen_k ? params.seqlen_k_rounded : total_k_padded_rounded;
  38. int batch_q = !is_varlen_q ? params.b : 1;
  39. int batch_k = !is_varlen_k ? params.b : 1;
  40. using TileShape_MK = cute::Shape<Int<kBlockM>, Int<kHeadDim>>;
  41. using PreprocessKernel = flash::FlashAttnBwdPreprocess<TileShape_MK, Element, ElementAccum, ArchTag, /*Clear_dQaccum=*/true, Varlen>;
  42. typename PreprocessKernel::Arguments preprocess_args {
  43. static_cast<Element const*>(params.o_ptr),
  44. {seqlen_q, params.d, params.h, batch_q}, // shape_O
  45. {params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0}, // stride_O
  46. static_cast<Element const*>(params.do_ptr),
  47. {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO
  48. static_cast<float*>(params.dsoftmax_sum),
  49. {seqlen_q_rounded, params.h, batch_q}, // shape_dPsum
  50. {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_dPsum
  51. static_cast<float*>(params.softmax_lse_ptr),
  52. {_1{}, seqlen_q, !is_varlen_q ? params.h * params.seqlen_q : 0}, // stride_LSE
  53. static_cast<float*>(params.softmax_lse_log2_ptr),
  54. {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_LSE_log2
  55. static_cast<ElementAccum*>(params.dq_accum_ptr),
  56. {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum
  57. {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * seqlen_q_rounded * params.h : 0}, // stride_dQaccum
  58. params.b,
  59. params.dq_semaphore,
  60. params.cu_seqlens_q,
  61. params.seqused_q
  62. };
  63. typename PreprocessKernel::Params preprocess_params = PreprocessKernel::to_underlying_arguments(preprocess_args);
  64. int num_m_block = cute::ceil_div(params.seqlen_q, kBlockM);
  65. dim3 grid_m(num_m_block, params.h, params.b);
  66. cutlass::kernel_launch<PreprocessKernel>(grid_m, PreprocessKernel::MaxThreadsPerBlock, PreprocessKernel::SharedStorageSize, stream, preprocess_params, false /*launch_with_pdl*/);
  67. CHECK_CUDA_KERNEL_LAUNCH();
  68. using TileShape_MNK = cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
  69. using ClusterShape = cute::Shape<_1, Int<1>, _1>; // Currently doesn't not support cluster
  70. // Stages_dS_or_QSm80 is Stages_dS if Sm90 and Stages if Sm80
  71. static constexpr int Stages = Arch >= 90 ? 2 : Stages_dS_or_QSm80;
  72. static constexpr int Stages_dS = Arch >= 90 ? Stages_dS_or_QSm80 : 1;
  73. using CollectiveMainloop = std::conditional_t<
  74. Arch >= 90,
  75. flash::CollectiveMainloopBwdSm90<Stages, Stages_dO, Stages_dS, ClusterShape, TileShape_MNK, Element, ElementAccum, cutlass::arch::Sm90,
  76. Is_causal, Is_local, Has_softcap, Varlen, Deterministic,
  77. SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>,
  78. flash::CollectiveMainloopBwdSm80<Stages, Stages_dO, TileShape_MNK, Element, ElementAccum, cutlass::arch::Sm80,
  79. Is_causal, Is_local, Has_softcap, Varlen, Deterministic,
  80. SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>
  81. >;
  82. using CollectiveEpilogue = std::conditional_t<
  83. !GQA,
  84. flash::CollectiveEpilogueBwd<TileShape_MNK, Element, ArchTag, CollectiveMainloop::NumMmaThreads, Varlen, dKV_swapAB, NumMmaWarpGroups * (Arch >= 90 ? 1 : cutlass::NumWarpsPerWarpGroup) / AtomLayoutNdKV>,
  85. flash::CollectiveEpilogueBwdGQA<TileShape_MNK, ElementAccum, ArchTag, CollectiveMainloop::NumMmaThreads, Varlen, Deterministic>
  86. >;
  87. using Scheduler = flash::SingleTileScheduler<Varlen, false /*Split*/, false /*PackGQA*/, kBlockN>;
  88. using AttnKernel = std::conditional_t<
  89. Arch >= 90,
  90. flash::enable_sm90_or_later<flash::FlashAttnBwdSm90<CollectiveMainloop, CollectiveEpilogue, Scheduler>>,
  91. flash::enable_sm80_to_sm89<flash::FlashAttnBwdSm80<CollectiveMainloop, CollectiveEpilogue, Scheduler>>
  92. >;
  93. typename CollectiveMainloop::Arguments mainloop_args {
  94. static_cast<Element const*>(params.q_ptr),
  95. {seqlen_q, params.d, params.h, batch_q}, // shape_Q
  96. {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0}, // stride_Q
  97. static_cast<Element const*>(params.k_ptr),
  98. {seqlen_k, params.d, params.h_k, batch_k}, // shape_K
  99. {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0}, // stride_K
  100. static_cast<Element const*>(params.v_ptr),
  101. {params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0}, // stride_V
  102. static_cast<Element const*>(params.do_ptr),
  103. {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO
  104. static_cast<ElementAccum*>(params.dq_accum_ptr),
  105. {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum
  106. {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum
  107. static_cast<float*>(params.softmax_lse_log2_ptr),
  108. {seqlen_q_rounded, params.h, batch_q}, // shape_LSE
  109. {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_LSE_log2
  110. static_cast<float*>(params.dsoftmax_sum),
  111. {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_dPsum
  112. params.scale_softmax,
  113. params.window_size_left, params.window_size_right, params.sink_token_length,
  114. params.softcap,
  115. params.b,
  116. params.dq_semaphore,
  117. params.cu_seqlens_q, params.cu_seqlens_k,
  118. params.seqused_q, params.seqused_k
  119. };
  120. // The case work with GQA is ugly but idk how to fix it.
  121. typename CollectiveEpilogue::Arguments epilogue_args {
  122. static_cast<typename CollectiveEpilogue::Element*>(!GQA ? params.dk_ptr : params.dk_accum_ptr),
  123. [&] {
  124. if constexpr (!GQA) {
  125. return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.d, params.h, batch_k}; // shape_dK
  126. } else {
  127. return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}; // shape_dKaccum
  128. }
  129. }(),
  130. [&] {
  131. if constexpr (!GQA) {
  132. return typename CollectiveEpilogue::StridedKV {params.dk_row_stride, _1{}, params.dk_head_stride, !is_varlen_k ? params.dk_batch_stride : 0}; // stride_dK
  133. } else {
  134. return typename CollectiveEpilogue::StridedKV {_1{}, params.d_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.d_rounded * params.seqlen_k_rounded : 0}; // stride_dKaccum
  135. }
  136. }(),
  137. static_cast<typename CollectiveEpilogue::Element*>(!GQA ? params.dv_ptr : params.dv_accum_ptr),
  138. [&] {
  139. if constexpr (!GQA) {
  140. return typename CollectiveEpilogue::StridedKV {params.dv_row_stride, _1{}, params.dv_head_stride, !is_varlen_k ? params.dv_batch_stride : 0}; // stride_dV
  141. } else {
  142. return typename CollectiveEpilogue::StridedKV {_1{}, params.d_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.d_rounded * params.seqlen_k_rounded : 0}; // stride_dVaccum
  143. }
  144. }(),
  145. params.h,
  146. params.dk_semaphore,
  147. params.dv_semaphore,
  148. params.cu_seqlens_k,
  149. params.seqused_k,
  150. };
  151. int num_blocks_n = cutlass::ceil_div(params.seqlen_k, get<1>(TileShape_MNK{}));
  152. num_blocks_n = cutlass::round_up(num_blocks_n, size<1>(ClusterShape{}));
  153. typename flash::TileSchedulerArguments scheduler_args {
  154. num_blocks_n, params.h, params.b, 1 /*num_splits*/,
  155. params.h / params.h_k,
  156. params.seqlen_k,
  157. params.seqlen_q, params.d, sizeof(Element),
  158. params.tile_count_semaphore, params.cu_seqlens_k, params.seqused_k
  159. };
  160. int device;
  161. cudaGetDevice(&device);
  162. typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({
  163. mainloop_args, epilogue_args, {device, params.num_sm}, scheduler_args
  164. });
  165. dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params);
  166. dim3 block_dims = AttnKernel::get_block_shape();
  167. int smem_size = AttnKernel::SharedStorageSize;
  168. // int smem_size_q = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_q));
  169. // int smem_size_do = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_do));
  170. // int smem_size_ds = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_ds));
  171. // int smem_size_dqacc = [&] {
  172. // if constexpr (Arch >= 90) {
  173. // return sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_dqacc));
  174. // } else {
  175. // return 0;
  176. // }
  177. // }();
  178. // int smem_size_k = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k));
  179. // int smem_size_v = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v));
  180. // int smem_size_lse = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_lse));
  181. // int smem_size_dpsum = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_dpsum));
  182. // printf("smem_size = %d, q = %d, k = %d, v = %d, do = %d, ds = %d, dqacc = %d, lse = %d, dpsum = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v, smem_size_do, smem_size_ds, smem_size_dqacc, smem_size_lse, smem_size_dpsum);
  183. if constexpr (size(ClusterShape{}) > 1) {
  184. void const* kernel = (void const*) cutlass::device_kernel<AttnKernel>;
  185. if (smem_size >= 48 * 1024) {
  186. CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
  187. }
  188. dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
  189. cutlass::ClusterLauncher::launch(
  190. grid_dims, cluster_dims, block_dims, smem_size, stream, kernel, kernel_params, false /*launch_with_pdl*/);
  191. } else {
  192. if (smem_size >= 48 * 1024) {
  193. CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<AttnKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
  194. }
  195. cutlass::kernel_launch<AttnKernel>(grid_dims, block_dims, smem_size, stream, kernel_params, false /*launch_with_pdl*/);
  196. }
  197. CHECK_CUDA_KERNEL_LAUNCH();
  198. using PostprocessKernel = flash::FlashAttnBwdPostprocessConvertdQ<TileShape_MK, Element, ElementAccum, ArchTag,
  199. AttnKernel::CollectiveMainloop::NumMmaThreads,
  200. typename AttnKernel::CollectiveMainloop::TiledMmadQ,
  201. AttnKernel::CollectiveMainloop::dQ_swapAB
  202. >;
  203. typename PostprocessKernel::Arguments postprocess_args {
  204. static_cast<ElementAccum const*>(params.dq_accum_ptr),
  205. {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum
  206. {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum
  207. static_cast<Element*>(params.dq_ptr),
  208. {seqlen_q, params.d, params.h, batch_q}, // shape_dQ
  209. {params.dq_row_stride, _1{}, params.dq_head_stride, params.dq_batch_stride}, // stride_dQ
  210. params.scale_softmax,
  211. params.cu_seqlens_q,
  212. params.seqused_q
  213. };
  214. typename PostprocessKernel::Params postprocess_params = PostprocessKernel::to_underlying_arguments(postprocess_args);
  215. int num_m_block_postprocess = cute::ceil_div(params.seqlen_q, get<0>(TileShape_MK{}));
  216. dim3 grid_m_postprocess(num_m_block_postprocess, params.h, params.b);
  217. int smem_size_postprocess = PostprocessKernel::SharedStorageSize;
  218. if (smem_size_postprocess >= 48 * 1024) {
  219. CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<PostprocessKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess));
  220. }
  221. cutlass::kernel_launch<PostprocessKernel>(grid_m_postprocess, PostprocessKernel::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_params, false /*launch_with_pdl*/);
  222. CHECK_CUDA_KERNEL_LAUNCH();
  223. if constexpr (GQA) {
  224. using TileShape_NK = cute::Shape<Int<kBlockN>, Int<kHeadDim>>;
  225. using PostprocessKerneldKV = flash::FlashAttnBwdPostprocessConvertdQ<TileShape_NK, Element, ElementAccum, ArchTag,
  226. AttnKernel::CollectiveEpilogue::NumEpilogueThreads,
  227. typename AttnKernel::CollectiveMainloop::TiledMmadKV,
  228. AttnKernel::CollectiveMainloop::dKV_swapAB
  229. >;
  230. typename PostprocessKerneldKV::Arguments postprocess_dK_args {
  231. static_cast<ElementAccum const*>(params.dk_accum_ptr),
  232. {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}, // shape_dKaccum
  233. {_1{}, seqlen_k_rounded * params.d_rounded, !is_varlen_k ? params.d_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dKaccum
  234. static_cast<Element*>(params.dk_ptr),
  235. {seqlen_k, params.d, params.h_k, batch_k}, // shape_dK
  236. {params.dk_row_stride, _1{}, params.dk_head_stride, params.dk_batch_stride}, // stride_dK
  237. 1.f,
  238. params.cu_seqlens_k,
  239. params.seqused_k
  240. };
  241. typename PostprocessKerneldKV::Params postprocess_dK_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dK_args);
  242. typename PostprocessKerneldKV::Arguments postprocess_dV_args {
  243. static_cast<ElementAccum const*>(params.dv_accum_ptr),
  244. {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}, // shape_dVaccum
  245. {_1{}, seqlen_k_rounded * params.d_rounded, !is_varlen_k ? params.d_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dVaccum
  246. static_cast<Element*>(params.dv_ptr),
  247. {seqlen_k, params.d, params.h_k, batch_k}, // shape_dV
  248. {params.dv_row_stride, _1{}, params.dv_head_stride, params.dv_batch_stride}, // stride_dV
  249. 1.f,
  250. params.cu_seqlens_k,
  251. params.seqused_k
  252. };
  253. typename PostprocessKerneldKV::Params postprocess_dV_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dV_args);
  254. int num_n_block_postprocess = cute::ceil_div(params.seqlen_k, get<0>(TileShape_NK{}));
  255. dim3 grid_n_postprocess(num_n_block_postprocess, params.h_k, params.b);
  256. int smem_size_postprocess = PostprocessKerneldKV::SharedStorageSize;
  257. if (smem_size_postprocess >= 48 * 1024) {
  258. CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<PostprocessKerneldKV>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess));
  259. }
  260. cutlass::kernel_launch<PostprocessKerneldKV>(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dK_params, false /*launch_with_pdl*/);
  261. CHECK_CUDA_KERNEL_LAUNCH();
  262. cutlass::kernel_launch<PostprocessKerneldKV>(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dV_params, false /*launch_with_pdl*/);
  263. CHECK_CUDA_KERNEL_LAUNCH();
  264. }
  265. }
  266. template<int Arch, typename T, int kBlockM, int kBlockN, int kHeadDim, bool Is_causal, bool Is_local, bool Has_softcap,
  267. int Stages_dO=2, int Stages_dS_or_QSm80=2,
  268. bool SdP_swapAB=true, bool dKV_swapAB=false, bool dQ_swapAB=false,
  269. int NumMmaWarpGroups=2, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
  270. bool V_in_regs=false>
  271. void run_mha_bwd_dispatch(Flash_bwd_params &params, cudaStream_t stream) {
  272. VARLEN_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
  273. BOOL_SWITCH(params.h != params.h_k, GQA, [&] {
  274. // BOOL_SWITCH(params.deterministic, Deterministic, [&] {
  275. // run_flash_bwd<kHeadDim, kBlockM, kBlockN, T, Is_causal, Is_local, Has_softcap, Varlen, false, GQA, Stages_dO, Stages_dS_or_QSm80, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ>(params, stream);
  276. run_flash_bwd<Arch, kHeadDim, kBlockM, kBlockN, T, Is_causal, Is_local, Has_softcap, Varlen /*Varlen*/, false /*Deterministic*/, GQA, Stages_dO, Stages_dS_or_QSm80, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>(params, stream);
  277. // });
  278. });
  279. });
  280. }
  281. template<int Arch, typename T>
  282. void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
  283. CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
  284. SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] {
  285. if constexpr (Arch >= 90) {
  286. if constexpr (Is_causal && Has_softcap) {
  287. // register spill with 128 x 128
  288. run_mha_bwd_dispatch<Arch, T, 96, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, true, false, true, 2, 1, 2, 2, false>(params, stream);
  289. } else {
  290. // With ShuffleStats we no longer have register spilling when Has_softcap and using 128 x 128 block.
  291. run_mha_bwd_dispatch<Arch, T, 128, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 2, false>(params, stream);
  292. }
  293. } else {
  294. run_mha_bwd_dispatch<Arch, T, 128, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 4, 4, 4, false>(params, stream);
  295. // Sm86
  296. // run_mha_bwd_dispatch<Arch, T, 64, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 4, 2, true>(params, stream);
  297. // run_mha_bwd_dispatch<Arch, T, 96, 96, 64, Is_causal, Is_local, Has_softcap, 1, 2, false, true, true, 2, 2, 4, 4, false>(params, stream);
  298. // run_mha_bwd_dispatch<Arch, T, 80, 128, 64, Is_causal, Is_local, Has_softcap, 1, 2, true, false, true, 2, 2, 4, 2, true>(params, stream);
  299. // run_mha_bwd_dispatch<Arch, T, 96, 128, 64, Is_causal, Is_local, Has_softcap, 1, 2, true, false, true, 2, 1, 8, 4, false>(params, stream);
  300. }
  301. });
  302. });
  303. }
  304. template<int Arch, typename T>
  305. void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
  306. CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
  307. SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] {
  308. if constexpr (Arch >= 90) {
  309. run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 1, true>(params, stream);
  310. } else {
  311. run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 4, 2, false>(params, stream);
  312. // Sm86
  313. // run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 1, 2, false, false, false, 2, 2, 4, 2, true>(params, stream);
  314. }
  315. });
  316. });
  317. }
  318. template<int Arch, typename T>
  319. void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
  320. CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
  321. SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] {
  322. if constexpr (Arch >= 90) {
  323. if constexpr (Is_causal || Is_local || Has_softcap) {
  324. run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 1, false>(params, stream);
  325. } else {
  326. run_mha_bwd_dispatch<Arch, T, 80, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, true, false, true, 2, 1, 2, 1, false>(params, stream);
  327. }
  328. } else {
  329. run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 4, 2, false>(params, stream);
  330. // Sm86
  331. // run_mha_bwd_dispatch<Arch, T, 64, 96, 128, Is_causal, Is_local, Has_softcap, 1, 2, false, false, false, 2, 2, 2, 2, true>(params, stream);
  332. }
  333. });
  334. });
  335. }
  336. template<int Arch, typename T>
  337. void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
  338. CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
  339. SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] {
  340. if constexpr (Arch >= 90) {
  341. run_mha_bwd_dispatch<Arch, T, 64, 96, 192, Is_causal, Is_local, Has_softcap, 1, 1, false, true, false, 3, 1, 1, 1, false>(params, stream);
  342. } else {
  343. run_mha_bwd_dispatch<Arch, T, 64, 80, 192, Is_causal, Is_local, Has_softcap, 1, 2, false, true, false, 2, 4, 2, 2, false>(params, stream);
  344. // Sm86
  345. // run_mha_bwd_dispatch<Arch, T, 64, 64, 192, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 2, 2, 2, true>(params, stream);
  346. }
  347. });
  348. });
  349. }
  350. template<int Arch, typename T>
  351. void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
  352. CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
  353. SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] {
  354. if constexpr (Arch >= 90) {
  355. run_mha_bwd_dispatch<Arch, T, 64, 80, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, true, true, 2, 1, 1, 1, false>(params, stream);
  356. } else {
  357. run_mha_bwd_dispatch<Arch, T, 64, 64, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 4, 2, 2, false>(params, stream);
  358. // Sm86
  359. // run_mha_bwd_dispatch<Arch, T, 32, 64, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 2, 2, 1, true>(params, stream);
  360. // run_mha_bwd_dispatch<Arch, T, 64, 32, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 4, 1, 2, true>(params, stream);
  361. }
  362. });
  363. });
  364. }