// flash_bwd_launch_template.h

/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

#include "cutlass/device_kernel.h"  // For device_kernel
#include "cutlass/kernel_launch.h"  // For kernel_launch
#include "cutlass/cluster_launch.hpp"  // For ClusterLauncher

#include "static_switch.h"
#include "flash.h"
#include "flash_bwd_preprocess_kernel.h"
#include "flash_bwd_postprocess_kernel.h"
#include "tile_scheduler.hpp"
#include "mainloop_bwd_sm90_tma_gmma_ws.hpp"
#include "epilogue_bwd_sm90_tma.hpp"
#include "flash_bwd_kernel.h"

using namespace cute;
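
// run_flash_bwd launches the backward pass in three stages:
//   (1) a preprocess kernel that computes dP_sum = rowsum(dO * O), writes the LSE in log2 form,
//       and clears the fp32 dQ accumulator (Clear_dQaccum=true);
//   (2) the main backward kernel (collective mainloop + epilogue + tile scheduler), which accumulates
//       dQ in fp32 and writes dK/dV (or fp32 dK/dV accumulators in the GQA case);
//   (3) a postprocess kernel that converts the fp32 dQ accumulator to dQ in the output dtype
//       (and, for GQA, converts the dK/dV accumulators as well).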
template <int kHeadDim, int kBlockM, int kBlockN, typename Element,
          bool Is_causal, bool Is_local, bool Has_softcap, bool Varlen, bool Deterministic, bool GQA,
          int Stages_dO=2, int Stages_dS=2,
          bool SdP_swapAB=true, bool dKV_swapAB=false, bool dQ_swapAB=false,
          int NumMmaWarpGroups=2, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1>
void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
    static_assert(!(Is_causal && Is_local), "Is_causal and Is_local cannot be true at the same time.");
    using ElementAccum = float;

    int const total_q_padded_rounded = cute::round_up(params.total_q + params.b * kBlockM, kBlockM);
    int const total_k_padded_rounded = cute::round_up(params.total_k + params.b * kBlockN, kBlockN);
    bool const is_varlen_q = params.cu_seqlens_q;
    bool const is_varlen_k = params.cu_seqlens_k;
    int seqlen_q = !is_varlen_q ? params.seqlen_q : params.total_q;
    int seqlen_k = !is_varlen_k ? params.seqlen_k : params.total_k;
    int seqlen_q_rounded = !is_varlen_q ? params.seqlen_q_rounded : total_q_padded_rounded;
    int seqlen_k_rounded = !is_varlen_k ? params.seqlen_k_rounded : total_k_padded_rounded;
    int batch_q = !is_varlen_q ? params.b : 1;
    int batch_k = !is_varlen_k ? params.b : 1;

    using TileShape_MK = cute::Shape<Int<kBlockM>, Int<kHeadDim>>;
    using PreprocessKernel = flash::FlashAttnBwdPreprocess<TileShape_MK, Element, ElementAccum, cutlass::arch::Sm90, /*Clear_dQaccum=*/true, Varlen>;
    typename PreprocessKernel::Arguments preprocess_args {
        static_cast<Element const*>(params.o_ptr),
        {seqlen_q, params.d, params.h, batch_q},  // shape_O
        {params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0},  // stride_O
        static_cast<Element const*>(params.do_ptr),
        {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0},  // stride_dO
        static_cast<float*>(params.dsoftmax_sum),
        {seqlen_q_rounded, params.h, batch_q},  // shape_dPsum
        {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0},  // stride_dPsum
        static_cast<float*>(params.softmax_lse_ptr),
        {_1{}, seqlen_q, !is_varlen_q ? params.h * params.seqlen_q : 0},  // stride_LSE
        static_cast<float*>(params.softmax_lse_log2_ptr),
        {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0},  // stride_LSE_log2
        static_cast<ElementAccum*>(params.dq_accum_ptr),
        {seqlen_q_rounded * params.d_rounded, params.h, batch_q},  // shape_dQaccum
        {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * seqlen_q_rounded * params.h : 0},  // stride_dQaccum
        params.b,
        params.dq_semaphore,
        params.cu_seqlens_q,
        params.seqused_q
    };
    typename PreprocessKernel::Params preprocess_params = PreprocessKernel::to_underlying_arguments(preprocess_args);
    int num_m_block = cute::ceil_div(params.seqlen_q, kBlockM);
    dim3 grid_m(num_m_block, params.h, params.b);
    cutlass::kernel_launch<PreprocessKernel>(grid_m, PreprocessKernel::MaxThreadsPerBlock, PreprocessKernel::SharedStorageSize, stream, preprocess_params, false /*launch_with_pdl*/);
    CHECK_CUDA_KERNEL_LAUNCH();
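
    // Main backward kernel: the collective mainloop computes the S/dP tiles and accumulates dQ into the
    // fp32 dQaccum buffer; the collective epilogue writes dK/dV. With GQA, dK/dV are instead accumulated
    // in fp32 per KV head and converted to the output dtype by the extra postprocess pass below.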
    using TileShape_MNK = cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
    using ClusterShape = cute::Shape<_1, Int<1>, _1>;  // Currently doesn't support clusters
    static constexpr int Stages = 2;
    using CollectiveMainloop = flash::CollectiveMainloopBwd<Stages, Stages_dO, Stages_dS, ClusterShape, TileShape_MNK, Element, ElementAccum, cutlass::arch::Sm90,
            Is_causal, Is_local, Has_softcap, Varlen, Deterministic,
            SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ>;
    using CollectiveEpilogue = std::conditional_t<
        !GQA,
        flash::CollectiveEpilogueBwd<TileShape_MNK, Element, CollectiveMainloop::NumMmaThreads, Varlen, dKV_swapAB, NumMmaWarpGroups / AtomLayoutNdKV>,
        flash::CollectiveEpilogueBwdGQA<TileShape_MNK, ElementAccum, CollectiveMainloop::NumMmaThreads, Varlen, Deterministic>
    >;
    using Scheduler = flash::SingleTileScheduler<Varlen, false /*Split*/, false /*PackGQA*/, kBlockN>;
    // using Scheduler = flash::StaticPersistentTileScheduler;
    using AttnKernel = flash::FlashAttnBwd<CollectiveMainloop, CollectiveEpilogue, Scheduler>;

    typename CollectiveMainloop::Arguments mainloop_args {
        static_cast<Element const*>(params.q_ptr),
        {seqlen_q, params.d, params.h, batch_q},  // shape_Q
        {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0},  // stride_Q
        static_cast<Element const*>(params.k_ptr),
        {seqlen_k, params.d, params.h_k, batch_k},  // shape_K
        {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0},  // stride_K
        static_cast<Element const*>(params.v_ptr),
        {params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0},  // stride_V
        static_cast<Element const*>(params.do_ptr),
        {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0},  // stride_dO
        static_cast<ElementAccum*>(params.dq_accum_ptr),
        {seqlen_q_rounded * params.d_rounded, params.h, batch_q},  // shape_dQaccum
        {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0},  // stride_dQaccum
        static_cast<float*>(params.softmax_lse_log2_ptr),
        {seqlen_q_rounded, params.h, batch_q},  // shape_LSE
        {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0},  // stride_LSE_log2
        static_cast<float*>(params.dsoftmax_sum),
        {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0},  // stride_dPsum
        params.scale_softmax,
        params.window_size_left, params.window_size_right, params.sink_token_length,
        params.softcap,
        params.b,
        params.dq_semaphore,
        params.cu_seqlens_q, params.cu_seqlens_k,
        params.seqused_q, params.seqused_k
    };
    // The GQA case is ugly, but I don't know how to fix it.
    typename CollectiveEpilogue::Arguments epilogue_args {
        static_cast<typename CollectiveEpilogue::Element*>(!GQA ? params.dk_ptr : params.dk_accum_ptr),
        [&] {
            if constexpr (!GQA) {
                return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.d, params.h, batch_k};  // shape_dK
            } else {
                return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k};  // shape_dKaccum
            }
        }(),
        [&] {
            if constexpr (!GQA) {
                return typename CollectiveEpilogue::StridedKV {params.dk_row_stride, _1{}, params.dk_head_stride, !is_varlen_k ? params.dk_batch_stride : 0};  // stride_dK
            } else {
                return typename CollectiveEpilogue::StridedKV {_1{}, params.d_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.d_rounded * params.seqlen_k_rounded : 0};  // stride_dKaccum
            }
        }(),
        static_cast<typename CollectiveEpilogue::Element*>(!GQA ? params.dv_ptr : params.dv_accum_ptr),
        [&] {
            if constexpr (!GQA) {
                return typename CollectiveEpilogue::StridedKV {params.dv_row_stride, _1{}, params.dv_head_stride, !is_varlen_k ? params.dv_batch_stride : 0};  // stride_dV
            } else {
                return typename CollectiveEpilogue::StridedKV {_1{}, params.d_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.d_rounded * params.seqlen_k_rounded : 0};  // stride_dVaccum
            }
        }(),
        params.h,
        params.dk_semaphore,
        params.dv_semaphore,
        params.cu_seqlens_k,
        params.seqused_k,
    };
    int num_blocks_n = cutlass::ceil_div(params.seqlen_k, get<1>(TileShape_MNK{}));
    num_blocks_n = cutlass::round_up(num_blocks_n, size<1>(ClusterShape{}));
    typename flash::TileSchedulerArguments scheduler_args {
        num_blocks_n, params.h, params.b, 1 /*num_splits*/,
        params.h / params.h_k,
        params.seqlen_k,
        params.seqlen_q, params.d, sizeof(Element),
        params.tile_count_semaphore, params.cu_seqlens_k, params.seqused_k
    };

    int device;
    cudaGetDevice(&device);
    typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({
        mainloop_args, epilogue_args, {device}, scheduler_args
    });

    dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params);
    dim3 block_dims = AttnKernel::get_block_shape();
    int smem_size = AttnKernel::SharedStorageSize;
    // int smem_size_q = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_q));
    // int smem_size_do = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_do));
    // int smem_size_ds = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_ds));
    // int smem_size_dqacc = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_dqacc));
    // int smem_size_k = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k));
    // int smem_size_v = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v));
    // int smem_size_lse = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_lse));
    // int smem_size_dpsum = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_dpsum));
    // printf("smem_size = %d, q = %d, k = %d, v = %d, do = %d, ds = %d, dqacc = %d, lse = %d, dpsum = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v, smem_size_do, smem_size_ds, smem_size_dqacc, smem_size_lse, smem_size_dpsum);
    if constexpr (size(ClusterShape{}) > 1) {
        void const* kernel = (void const*) cutlass::device_kernel<AttnKernel>;
        if (smem_size >= 48 * 1024) {
            CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        }
        dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
        cutlass::ClusterLauncher::launch(
            grid_dims, cluster_dims, block_dims, smem_size, stream, kernel, kernel_params, false /*launch_with_pdl*/);
    } else {
        if (smem_size >= 48 * 1024) {
            CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<AttnKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        }
        cutlass::kernel_launch<AttnKernel>(grid_dims, block_dims, smem_size, stream, kernel_params, false /*launch_with_pdl*/);
    }
    CHECK_CUDA_KERNEL_LAUNCH();
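
    // Postprocess: convert the fp32 dQ accumulator into dQ in the output dtype, applying the softmax scale.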
    using PostprocessKernel = flash::FlashAttnBwdPostprocessConvertdQ<TileShape_MK, Element, ElementAccum, cutlass::arch::Sm90,
        AttnKernel::CollectiveMainloop::NumMmaThreads,
        typename AttnKernel::CollectiveMainloop::TiledMmadQ,
        AttnKernel::CollectiveMainloop::dQ_swapAB
        >;
    typename PostprocessKernel::Arguments postprocess_args {
        static_cast<ElementAccum const*>(params.dq_accum_ptr),
        {seqlen_q_rounded * params.d_rounded, params.h, batch_q},  // shape_dQaccum
        {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0},  // stride_dQaccum
        static_cast<Element*>(params.dq_ptr),
        {seqlen_q, params.d, params.h, batch_q},  // shape_dQ
        {params.dq_row_stride, _1{}, params.dq_head_stride, params.dq_batch_stride},  // stride_dQ
        params.scale_softmax,
        params.cu_seqlens_q,
        params.seqused_q
    };
    typename PostprocessKernel::Params postprocess_params = PostprocessKernel::to_underlying_arguments(postprocess_args);
    int num_m_block_postprocess = cute::ceil_div(params.seqlen_q, get<0>(TileShape_MK{}));
    dim3 grid_m_postprocess(num_m_block_postprocess, params.h, params.b);
    int smem_size_postprocess = PostprocessKernel::SharedStorageSize;
    if (smem_size_postprocess >= 48 * 1024) {
        CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<PostprocessKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess));
    }
    cutlass::kernel_launch<PostprocessKernel>(grid_m_postprocess, PostprocessKernel::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_params, false /*launch_with_pdl*/);
    CHECK_CUDA_KERNEL_LAUNCH();
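
    // With GQA, dK/dV were accumulated in fp32 per KV head; run the same convert kernel twice to produce
    // dK and dV in the output dtype (no extra scaling, hence the 1.f scale).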
    if constexpr (GQA) {
        using TileShape_NK = cute::Shape<Int<kBlockN>, Int<kHeadDim>>;
        using PostprocessKerneldKV = flash::FlashAttnBwdPostprocessConvertdQ<TileShape_NK, Element, ElementAccum, cutlass::arch::Sm90,
            AttnKernel::CollectiveEpilogue::NumEpilogueThreads,
            typename AttnKernel::CollectiveMainloop::TiledMmadKV,
            AttnKernel::CollectiveMainloop::dKV_swapAB
            >;
        typename PostprocessKerneldKV::Arguments postprocess_dK_args {
            static_cast<ElementAccum const*>(params.dk_accum_ptr),
            {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k},  // shape_dKaccum
            {_1{}, seqlen_k_rounded * params.d_rounded, !is_varlen_k ? params.d_rounded * params.seqlen_k_rounded * params.h_k : 0},  // stride_dKaccum
            static_cast<Element*>(params.dk_ptr),
            {seqlen_k, params.d, params.h_k, batch_k},  // shape_dK
            {params.dk_row_stride, _1{}, params.dk_head_stride, params.dk_batch_stride},  // stride_dK
            1.f,
            params.cu_seqlens_k,
            params.seqused_k
        };
        typename PostprocessKerneldKV::Params postprocess_dK_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dK_args);
        typename PostprocessKerneldKV::Arguments postprocess_dV_args {
            static_cast<ElementAccum const*>(params.dv_accum_ptr),
            {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k},  // shape_dVaccum
            {_1{}, seqlen_k_rounded * params.d_rounded, !is_varlen_k ? params.d_rounded * params.seqlen_k_rounded * params.h_k : 0},  // stride_dVaccum
            static_cast<Element*>(params.dv_ptr),
            {seqlen_k, params.d, params.h_k, batch_k},  // shape_dV
            {params.dv_row_stride, _1{}, params.dv_head_stride, params.dv_batch_stride},  // stride_dV
            1.f,
            params.cu_seqlens_k,
            params.seqused_k
        };
        typename PostprocessKerneldKV::Params postprocess_dV_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dV_args);
        int num_n_block_postprocess = cute::ceil_div(params.seqlen_k, get<0>(TileShape_NK{}));
        dim3 grid_n_postprocess(num_n_block_postprocess, params.h_k, params.b);
        int smem_size_postprocess = PostprocessKerneldKV::SharedStorageSize;
        if (smem_size_postprocess >= 48 * 1024) {
            CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<PostprocessKerneldKV>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess));
        }
        cutlass::kernel_launch<PostprocessKerneldKV>(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dK_params, false /*launch_with_pdl*/);
        CHECK_CUDA_KERNEL_LAUNCH();
        cutlass::kernel_launch<PostprocessKerneldKV>(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dV_params, false /*launch_with_pdl*/);
        CHECK_CUDA_KERNEL_LAUNCH();
    }
}
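
// run_mha_bwd_dispatch resolves the runtime flags (varlen, GQA) into template parameters via the
// switch macros from static_switch.h and instantiates run_flash_bwd with the tile/configuration
// knobs chosen by the per-head-dim functions below.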
template<typename T, int kBlockM, int kBlockN, int kHeadDim, bool Is_causal, bool Is_local, bool Has_softcap,
         int Stages_dO=2, int Stages_dS=2,
         bool SdP_swapAB=true, bool dKV_swapAB=false, bool dQ_swapAB=false,
         int NumMmaWarpGroups=2, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1>
void run_mha_bwd_dispatch(Flash_bwd_params &params, cudaStream_t stream) {
    VARLEN_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
        BOOL_SWITCH(params.h != params.h_k, GQA, [&] {
            // BOOL_SWITCH(params.deterministic, Deterministic, [&] {
            // run_flash_bwd<kHeadDim, kBlockM, kBlockN, T, Is_causal, Is_local, Has_softcap, Varlen, false, GQA, Stages_dO, Stages_dS, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ>(params, stream);
            run_flash_bwd<kHeadDim, kBlockM, kBlockN, T, Is_causal, Is_local, Has_softcap, Varlen /*Varlen*/, false /*Deterministic*/, GQA, Stages_dO, Stages_dS, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ>(params, stream);
            // });
        });
    });
}
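
// The run_mha_bwd_hdim* functions pick tile sizes and swap/atom-layout settings tuned per head dimension;
// their template arguments map onto run_mha_bwd_dispatch<T, kBlockM, kBlockN, kHeadDim, Is_causal, Is_local,
// Has_softcap, Stages_dO, Stages_dS, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP,
// AtomLayoutNdKV, AtomLayoutMdQ>.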
template<typename T>
void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] {
            // run_mha_bwd_dispatch<T, 128, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 2>(params, stream);
            if constexpr (Is_causal && Has_softcap) {
                // register spill with 128 x 128
                run_mha_bwd_dispatch<T, 96, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, true, false, true, 2, 1, 2, 2>(params, stream);
            } else {
                // With ShuffleStats we no longer have register spilling when Has_softcap and using 128 x 128 block.
                run_mha_bwd_dispatch<T, 128, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 2>(params, stream);
            }
        });
    });
}

template<typename T>
void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] {
            run_mha_bwd_dispatch<T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 1>(params, stream);
        });
    });
}

template<typename T>
void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] {
            if constexpr (Is_causal || Is_local || Has_softcap) {
                run_mha_bwd_dispatch<T, 64, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 1>(params, stream);
            } else {
                run_mha_bwd_dispatch<T, 80, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, true, false, true, 2, 1, 2, 1>(params, stream);
            }
        });
    });
}

template<typename T>
void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] {
            run_mha_bwd_dispatch<T, 64, 96, 192, Is_causal, Is_local, Has_softcap, 1, 1, false, true, false, 3, 1, 1, 1>(params, stream);
        });
    });
}

template<typename T>
void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] {
            run_mha_bwd_dispatch<T, 64, 80, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, true, true, 2, 1, 1, 1>(params, stream);
        });
    });
}