// flash_fwd_launch_template.h
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/cluster_launch.hpp"

#include "static_switch.h"
#include "flash.h"
#include "tile_scheduler.hpp"
#include "flash_fwd_kernel.h"
#include "kernel_traits.h"
#include "seq_len.h"
#include "utils.h"
#include "combine.h"
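
// run_flash_fwd instantiates and launches the forward attention kernel for one
// concrete (Kernel_traits, Is_causal, Is_local, Seqlen_traits) combination: it
// builds the mainloop, epilogue, and scheduler parameter structs, configures
// dynamic shared memory, launches the warp-specialized kernel (optionally as a
// thread-block cluster), and, for split-KV, runs a second combine kernel to
// reduce the per-split partial results.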
template<typename Kernel_traits, bool Is_causal, bool Is_local, typename Seqlen_traits, typename Seqlen_traits_Q = Seqlen_traits>
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    static_assert(!(Is_causal && Is_local), "Is_causal and Is_local cannot be true at the same time.");
    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using OutputType = typename Kernel_traits::OutputType;
    using TileShape_MNK = typename Kernel_traits::TileShape_MNK;
    using ClusterShape = typename Kernel_traits::ClusterShape_MNK;

    constexpr static bool Is_split = Kernel_traits::Is_split;
    static_assert(Seqlen_traits_Q::UseGQAPacking == (Kernel_traits::kBlockH > 1), "If kBlockH > 1, use gqa packed layouts");
    static_assert(!(Is_split && Seqlen_traits::UseVarSeqLen), "Split KV not yet supported for variable seqlen.");

    using CollectiveMainloop = flash::CollectiveMainloopFwd<Kernel_traits, Is_causal, Is_local, Seqlen_traits, Seqlen_traits_Q>;
    using CollectiveEpilogue = flash::CollectiveEpilogueFwd<Kernel_traits, Seqlen_traits_Q>;
    using Scheduler = std::conditional_t<
        Seqlen_traits::UseVarSeqLen || Seqlen_traits_Q::UseVarSeqLen,
        flash::SingleTileScheduler,
        std::conditional_t<!Is_causal && !Is_local && !Is_split,
            flash::StaticPersistentTileScheduler<Is_split>,
            flash::DynamicPersistentTileScheduler<
                Kernel_traits::kNThreads - cutlass::NumThreadsPerWarpGroup,
                Kernel_traits::NumProducerThreads,
                Is_split
            >
    >>;
    // using Scheduler = flash::SingleTileScheduler;
    // using Scheduler = std::conditional_t<
    //     !Is_causal && !Is_local && !Is_split,
    //     flash::StaticPersistentTileScheduler<Is_split>,
    //     flash::DynamicPersistentTileScheduler<
    //         Kernel_traits::kNThreads - cutlass::NumThreadsPerWarpGroup,
    //         Kernel_traits::NumProducerThreads,
    //         Is_split
    //     >
    // >;
    // using Scheduler = flash::DynamicPersistentTileScheduler<
    //     Kernel_traits::kNThreads - cutlass::NumThreadsPerWarpGroup,
    //     Kernel_traits::NumProducerThreads,
    //     Is_split>;
    Seqlen_traits_Q seqlen_traits_q(
        params.total_q, params.seqlen_q, params.cu_seqlens_q, params.seqused_q);
    Seqlen_traits seqlen_traits_k(
        params.total_k, params.seqlen_k, params.cu_seqlens_k, params.seqused_k);

    typename CollectiveMainloop::Params mainloop_params =
        CollectiveMainloop::to_underlying_arguments({
            static_cast<Element const*>(params.q_ptr),
            seqlen_traits_q.get_gmem_layout(
                params.seqlen_q, params.d, params.h_k, params.b, params.h_h_k_ratio,
                params.q_row_stride, params.q_head_stride, params.q_batch_stride
            ),  // layout_Q
            static_cast<Element const*>(params.k_ptr),
            seqlen_traits_k.get_gmem_layout(
                params.seqlen_k, params.d, params.h_k, params.b_k,
                params.k_row_stride, params.k_head_stride, params.k_batch_stride
            ),  // layout_K
            static_cast<Element const*>(params.v_ptr),
            seqlen_traits_k.get_gmem_layout(
                params.seqlen_k, params.d, params.h_k, params.b_k,
                params.v_row_stride, params.v_head_stride, params.v_batch_stride
            ),  // layout_V
            params.scale_softmax_log2,
            params.descale_q_ptr,
            params.descale_k_ptr,
            params.descale_v_ptr,
            params.window_size_left,
            params.window_size_right,
            ceil_div(params.h_h_k_ratio, Kernel_traits::kBlockH),
            params.cache_batch_idx,
            Is_split ? params.num_splits : 1
        });
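
    // The epilogue writes either the final output O and its log-sum-exp (LSE),
    // or, for split-KV, per-split partial accumulators (Oaccum / LSEaccum)
    // that the combine kernel below reduces afterwards.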
    typename CollectiveEpilogue::Params epilogue_params = [&] {
        if constexpr (!Is_split) {
            return CollectiveEpilogue::to_underlying_arguments({
                static_cast<OutputType*>(params.o_ptr),
                seqlen_traits_q.get_gmem_layout(
                    params.seqlen_q, params.d, params.h_k, params.b, params.h_h_k_ratio,
                    params.o_row_stride, params.o_head_stride, params.o_batch_stride
                ),  // layout_O
                static_cast<float*>(params.softmax_lse_ptr),
                seqlen_traits_q.get_lse_gmem_layout(
                    params.seqlen_q, params.h, params.b
                )  // layout_LSE
            });
        } else {
            return CollectiveEpilogue::to_underlying_arguments({
                static_cast<OutputType*>(params.oaccum_ptr),
                seqlen_traits_q.get_oaccum_gmem_layout(
                    params.seqlen_q, params.d, params.h_k, params.b, params.h_h_k_ratio, params.num_splits,
                    params.oaccum_row_stride, params.oaccum_head_stride, params.oaccum_batch_stride,
                    params.oaccum_split_stride
                ),  // layout_O
                static_cast<float*>(params.softmax_lseaccum_ptr),
                seqlen_traits_q.get_lseaccum_gmem_layout(
                    params.seqlen_q, params.h, params.b, params.num_splits
                )  // layout_LSE
            });
        }
    }();
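
    // Tile counts for the scheduler: round the number of M-blocks up to a
    // multiple of the cluster size so every cluster is fully populated, and
    // fold the GQA head packing (kBlockH query heads per tile) into the
    // head-block count.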
    int num_blocks_m = cutlass::ceil_div(params.seqlen_q, Kernel_traits::kBlockM / Kernel_traits::kBlockH);
    num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{});
    int num_blocks_h = params.h_k * ceil_div(params.h_h_k_ratio, Kernel_traits::kBlockH);
    typename Scheduler::Arguments scheduler_args =
        {num_blocks_m, Is_split ? params.num_splits : 1, num_blocks_h, params.b, params.tile_count_semaphore};
    typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args);

    // Get the ptr to kernel function.
    void *kernel;
    if constexpr (cutlass::sizeof_bits_v<Element> == 8)
        kernel = (void *)flash::compute_attn_ws_fp8<Kernel_traits, Is_causal, Is_local, Scheduler, Seqlen_traits, Seqlen_traits_Q>;
    else
        kernel = (void *)flash::compute_attn_ws<Kernel_traits, Is_causal, Is_local, Scheduler, Seqlen_traits, Seqlen_traits_Q>;
    int smem_size = sizeof(typename Kernel_traits::SharedStorage);
    // int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q));
    // int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k));
    // int smem_size_v = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_v));
    // int smem_size_o = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_o));
    // printf("smem_size = %d, q = %d, k = %d, v = %d, o = %d.\n", smem_size, smem_size_q, smem_size_k, smem_size_v, smem_size_o);
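    // Kernels that request more than 48 KB of dynamic shared memory must opt
    // in explicitly via cudaFuncSetAttribute (sm_90 allows up to 227 KB of
    // dynamic shared memory per thread block once the attribute is set).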
    if (smem_size >= 48 * 1024) {
        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    }

    int device;
    cudaGetDevice(&device);
    int multiprocessor_count;
    CHECK_CUDA(cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device));
    dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count);
    static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
    dim3 block_dims(ctaSize);
    if constexpr (size(ClusterShape{}) > 1) {
        dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
        cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
        cutlass::launch_kernel_on_cluster(
            launch_params, kernel, mainloop_params, epilogue_params,
            scheduler_params, seqlen_traits_q, seqlen_traits_k);
    } else {
        if constexpr (cutlass::sizeof_bits_v<Element> == 8) {
            flash::compute_attn_ws_fp8<Kernel_traits, Is_causal, Is_local, Scheduler, Seqlen_traits, Seqlen_traits_Q>
                <<<grid_dims, block_dims, smem_size, stream>>>
                (mainloop_params, epilogue_params, scheduler_params, seqlen_traits_q, seqlen_traits_k);
        } else {
            flash::compute_attn_ws<Kernel_traits, Is_causal, Is_local, Scheduler, Seqlen_traits, Seqlen_traits_Q>
                <<<grid_dims, block_dims, smem_size, stream>>>
                (mainloop_params, epilogue_params, scheduler_params, seqlen_traits_q, seqlen_traits_k);
        }
    }
    CHECK_CUDA_KERNEL_LAUNCH();
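
    // Split-KV second pass: the main kernel produced per-split partial
    // accumulators in float; the combine kernel reduces across splits
    // (up to kMaxSplits of them per row) and writes the final output.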
    if constexpr (Is_split) {
        using FinalOutputType = typename Kernel_traits::FinalOutputType;
        static_assert(is_same_v<OutputType, float>, "Assume OutputType of main kernel is float.");
        static_assert(is_same_v<ElementAccum, float>, "ElementAccum must be float.");
        // We want kBlockM to be as small as possible for more parallelism.
        // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
        // If headdim is divisible by 64, then we set kBlockM = 8, etc.
        constexpr static int kHeadDim = Kernel_traits::kHeadDim;
        constexpr static int kBlockM = kHeadDim % 128 == 0 ? 4 : (kHeadDim % 64 == 0 ? 8 : 16);
        constexpr static bool Is_even_K = true;  // always true for our current setting
        void *kernel_combine;
        int smem_size_combine;
        NUM_SPLITS_SWITCH(params.num_splits, kLogMaxSplits, [&] {
            constexpr static int kMaxSplits = 1 << kLogMaxSplits;
            kernel_combine = (void *) flash::combine_attn_seqk_parallel<
                FinalOutputType, ElementAccum, kHeadDim, kBlockM, kLogMaxSplits, Is_even_K, Flash_fwd_params>;
            smem_size_combine = sizeof(
                flash::SharedStorageLSE<float, Shape<Int<kMaxSplits>, Int<kBlockM+1>>, Shape<Int<kMaxSplits>>>);
        });
        if (smem_size_combine >= 48 * 1024) {
            CHECK_CUDA(cudaFuncSetAttribute(kernel_combine, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_combine));
        }
        dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
        dim3 block_dims_combine(128);
        dim3 cluster_dims_combine(1, 1, 1);
        cutlass::ClusterLaunchParams launch_params_combine{
            grid_combine, block_dims_combine, cluster_dims_combine, smem_size_combine, stream};
        cutlass::launch_kernel_on_cluster(launch_params_combine, kernel_combine, params);
        CHECK_CUDA_KERNEL_LAUNCH();
    }
}
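
// The launchers below instantiate run_flash_fwd for each supported head
// dimension (64 / 128 / 256), element type (16-bit / fp8), and GQA packing.
// Runtime flags are converted into compile-time template parameters through
// the switch macros from static_switch.h (BOOL_SWITCH, SEQLEN_SWITCH,
// MMA_2WG_SWITCH / MMA_3WG_SWITCH, NUM_SPLITS_SWITCH), so each combination
// gets its own fully specialized kernel. A minimal sketch of a call site
// (hypothetical; the real dispatch on params.d lives in the API layer, not
// in this header):
//
//   Flash_fwd_params params;  // assumed already populated by the front end
//   cudaStream_t stream = 0;
//   run_mha_fwd_hdim128<cutlass::half_t>(params, stream);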
template<typename T>
void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 64;
    constexpr static bool UseCluster = false;
    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        BOOL_SWITCH(params.is_local, Is_local, [&] {
            MMA_3WG_SWITCH(params.seqlen_q, kNumMmaWGs, [&] {
                SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
                    BOOL_SWITCH(params.num_splits > 1, Is_split, [&] {
                        // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 192) % 2 == 0 && !Is_causal && !Is_local && !Is_split
                        //     && kNumMmaWGs == 3 && !Seqlen_traits::UseVarSeqLen, UseCluster, [&] {
                        run_flash_fwd<
                            Flash_fwd_kernel_traits<Headdim, kNumMmaWGs * 64, 128, 4 + kNumMmaWGs * 4,
                                                    2, false, UseCluster ? 2 : 1, T, !Seqlen_traits::UseVarSeqLen && Is_split>,
                            Is_causal,
                            Is_local && !Is_causal,
                            Seqlen_traits
                        >(params, stream);
                        // });
                    });
                });
            });
        });
    });
}

template<typename T>
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 128;
    MMA_2WG_SWITCH(params.seqlen_q, kNumMmaWGs, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            BOOL_SWITCH(params.is_local, Is_local, [&] {
                SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
                    BOOL_SWITCH(params.num_splits > 1, Is_split, [&] {
                        // Only use Cluster if the number of tiles along seqlen_q is even
                        // and not Is_causal, Is_local, Is_split, or varseqlen.
                        BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Is_split
                                        && kNumMmaWGs == 2 && !Seqlen_traits::UseVarSeqLen, UseCluster, [&] {
                            run_flash_fwd<
                                Flash_fwd_kernel_traits<Headdim, kNumMmaWGs * 64, (Is_causal || Is_local) ? 128 : 176,
                                                        4 + kNumMmaWGs * 4, 2, false, UseCluster ? 2 : 1,
                                                        T, !Seqlen_traits::UseVarSeqLen && Is_split>,
                                Is_causal,
                                Is_local && !Is_causal,
                                Seqlen_traits
                            >(params, stream);
                        });
                    });
                });
            });
        });
    });
}

template<typename T>
void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 256;
    MMA_2WG_SWITCH(params.seqlen_q, kNumMmaWGs, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            BOOL_SWITCH(params.is_local, Is_local, [&] {
                SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
                    BOOL_SWITCH(params.num_splits > 1, Is_split, [&] {
                        // Only use Cluster if the number of tiles along seqlen_q is even
                        // and not Is_causal, Is_local, Is_split, or varseqlen.
                        BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Is_split
                                        && kNumMmaWGs == 2 && !Seqlen_traits::UseVarSeqLen, UseCluster, [&] {
                            run_flash_fwd<
                                Flash_fwd_kernel_traits<Headdim, kNumMmaWGs * 64, kNumMmaWGs == 1 ? 96 : 80,
                                                        4 + kNumMmaWGs * 4, 2, false, UseCluster ? 2 : 1,
                                                        T, !Seqlen_traits::UseVarSeqLen && Is_split>,
                                Is_causal,
                                Is_local && !Is_causal,
                                Seqlen_traits
                            >(params, stream);
                        });
                    });
                });
            });
        });
    });
}
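
// FP8 variants: same dispatch structure as above, but built on
// Flash_fwd_kernel_traits_fp8, with per-head-dim kBlockN / kStages choices
// and thread-block clusters currently disabled.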
template<typename T>
void run_mha_fwd_hdim64_fp8(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 64;
    constexpr static int kBlockN = 128;
    constexpr static int kStages = 4;
    // constexpr static bool UseCluster = false;
    // constexpr static int kBlockM = 192;
    // constexpr static int kNWarps = 4 + kBlockM/16;
    using Seqlen_traits = flash::FixedSeqLenTraits;
    constexpr static bool UseCluster = false;
    MMA_3WG_SWITCH(params.seqlen_q, kNumMmaWGs, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            BOOL_SWITCH(params.is_local, Is_local, [&] {
                SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
                    BOOL_SWITCH(params.num_splits > 1, Is_split, [&] {
                        // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 192) % 2 == 0 && !Is_causal && !Is_local && !Is_split
                        //     && kNumMmaWGs == 3 && !Seqlen_traits::UseVarSeqLen, UseCluster, [&] {
                        run_flash_fwd<
                            Flash_fwd_kernel_traits_fp8<Headdim, kNumMmaWGs * 64, kBlockN, 4 + kNumMmaWGs * 4,
                                                        kStages, false, UseCluster ? 2 : 1, T, !Seqlen_traits::UseVarSeqLen && Is_split>,
                            Is_causal,
                            Is_local && !Is_causal,
                            Seqlen_traits
                        >(params, stream);
                        // });
                    });
                });
            });
        });
    });
}

template<typename T>
void run_mha_fwd_hdim128_fp8(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 128;
    constexpr static int kBlockN = 256;
    // constexpr static int kBlockN = 192;
    constexpr static int kStages = 2;
    // constexpr static int kBlockM = 128;
    // constexpr static int kNWarps = 4 + kBlockM/16;
    using Seqlen_traits = flash::FixedSeqLenTraits;
    constexpr static bool UseCluster = false;
    // constexpr static bool Is_split = false;
    // constexpr static bool Is_local = false;
    // constexpr static int kNumMmaWGs = 2;
    MMA_2WG_SWITCH(params.seqlen_q, kNumMmaWGs, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            BOOL_SWITCH(params.is_local, Is_local, [&] {
                SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
                    BOOL_SWITCH(params.num_splits > 1, Is_split, [&] {
                        // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Is_split
                        //     && kNumMmaWGs == 2 && !Seqlen_traits::UseVarSeqLen, UseCluster, [&] {
                        run_flash_fwd<
                            Flash_fwd_kernel_traits_fp8<Headdim, kNumMmaWGs * 64, kBlockN, 4 + kNumMmaWGs * 4,
                                                        kStages, false, UseCluster ? 2 : 1, T, !Seqlen_traits::UseVarSeqLen && Is_split>,
                            Is_causal,
                            Is_local && !Is_causal,
                            Seqlen_traits
                        >(params, stream);
                        // });
                    });
                });
            });
        });
    });
}

template<typename T>
void run_mha_fwd_hdim256_fp8(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 256;
    constexpr static int kBlockN = 128;
    constexpr static int kStages = 2;
    // constexpr static int kBlockM = 128;
    // constexpr static int kNWarps = 4 + kBlockM/16;
    using Seqlen_traits = flash::FixedSeqLenTraits;
    constexpr static bool UseCluster = false;
    MMA_2WG_SWITCH(params.seqlen_q, kNumMmaWGs, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            BOOL_SWITCH(params.is_local, Is_local, [&] {
                SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
                    BOOL_SWITCH(params.num_splits > 1, Is_split, [&] {
                        // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Is_split
                        //     && kNumMmaWGs == 2 && !Seqlen_traits::UseVarSeqLen, UseCluster, [&] {
                        run_flash_fwd<
                            Flash_fwd_kernel_traits_fp8<Headdim, kNumMmaWGs * 64, kBlockN, 4 + kNumMmaWGs * 4,
                                                        kStages, false, UseCluster ? 2 : 1, T, !Seqlen_traits::UseVarSeqLen && Is_split>,
                            Is_causal,
                            Is_local && !Is_causal,
                            Seqlen_traits
                        >(params, stream);
                        // });
                    });
                });
            });
        });
    });
}
/*
** GQA methods
*/
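
// GQA (grouped-query attention) launchers: kBlockH query heads that share a
// KV head are packed into the same tile (FixedGQASeqLenTraits for Q), so the
// MMA-warpgroup switch is keyed on kBlockH * params.seqlen_q rather than on
// seqlen_q alone.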
template<typename T, int kBlockH>
void run_mha_fwd_hdim64_gqa(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 64;
    constexpr static bool UseCluster = false;
    using Seqlen_traits = flash::FixedSeqLenTraits;
    using Seqlen_traits_Q = flash::FixedGQASeqLenTraits;
    MMA_3WG_SWITCH(kBlockH * params.seqlen_q, kNumMmaWGs, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            BOOL_SWITCH(params.is_local, Is_local, [&] {
                BOOL_SWITCH(params.num_splits > 1, Is_split, [&] {
                    // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 192/kBlockH) % 2 == 0 && !Is_causal && !Is_local && !Is_split
                    //     && kNumMmaWGs == 3, UseCluster, [&] {
                    run_flash_fwd<
                        Flash_fwd_kernel_traits<Headdim, kNumMmaWGs * 64, 128, 4 + kNumMmaWGs * 4,
                                                2, false, UseCluster ? 2 : 1, T, !Seqlen_traits::UseVarSeqLen && Is_split, kBlockH>,
                        Is_causal,
                        Is_local && !Is_causal,
                        Seqlen_traits,
                        Seqlen_traits_Q
                    >(params, stream);
                    // });
                });
            });
        });
    });
}

template<typename T, int kBlockH>
void run_mha_fwd_hdim128_gqa(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 128;
    constexpr static bool UseCluster = false;
    using Seqlen_traits = flash::FixedSeqLenTraits;
    using Seqlen_traits_Q = flash::FixedGQASeqLenTraits;
    MMA_2WG_SWITCH(kBlockH * params.seqlen_q, kNumMmaWGs, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            BOOL_SWITCH(params.is_local, Is_local, [&] {
                BOOL_SWITCH(params.num_splits > 1, Is_split, [&] {
                    // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128/kBlockH) % 2 == 0 && !Is_causal && !Is_local && !Is_split
                    //     && kNumMmaWGs == 2, UseCluster, [&] {
                    run_flash_fwd<
                        Flash_fwd_kernel_traits<Headdim, kNumMmaWGs * 64, (Is_causal || Is_local) ? 128 : 176,
                                                4 + kNumMmaWGs * 4, 2, false, UseCluster ? 2 : 1, T, Is_split, kBlockH>,
                        Is_causal,
                        Is_local && !Is_causal,
                        Seqlen_traits,
                        Seqlen_traits_Q
                    >(params, stream);
                    // });
                });
            });
        });
    });
}

template<typename T, int kBlockH>
void run_mha_fwd_hdim256_gqa(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 256;
    constexpr static bool UseCluster = false;
    using Seqlen_traits = flash::FixedSeqLenTraits;
    using Seqlen_traits_Q = flash::FixedGQASeqLenTraits;
    MMA_2WG_SWITCH(kBlockH * params.seqlen_q, kNumMmaWGs, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            BOOL_SWITCH(params.is_local, Is_local, [&] {
                BOOL_SWITCH(params.num_splits > 1, Is_split, [&] {
                    // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128/kBlockH) % 2 == 0 && !Is_causal && !Is_local && !Is_split
                    //     && kNumMmaWGs == 2, UseCluster, [&] {
                    run_flash_fwd<
                        Flash_fwd_kernel_traits<Headdim, kNumMmaWGs * 64, kNumMmaWGs == 1 ? 96 : 80,
                                                4 + kNumMmaWGs * 4, 2, false, UseCluster ? 2 : 1, T, Is_split, kBlockH>,
                        Is_causal,
                        Is_local && !Is_causal,
                        Seqlen_traits,
                        Seqlen_traits_Q
                    >(params, stream);
                    // });
                });
            });
        });
    });
}
template<typename T, int kBlockH>
void run_mha_fwd_hdim64_fp8_gqa(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 64;
    constexpr static int kBlockN = 128;
    constexpr static int kStages = 4;
    constexpr static bool UseCluster = false;
    using Seqlen_traits = flash::FixedSeqLenTraits;
    using Seqlen_traits_Q = flash::FixedGQASeqLenTraits;
    MMA_3WG_SWITCH(kBlockH * params.seqlen_q, kNumMmaWGs, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            BOOL_SWITCH(params.is_local, Is_local, [&] {
                BOOL_SWITCH(params.num_splits > 1, Is_split, [&] {
                    // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 192/kBlockH) % 2 == 0 && !Is_causal && !Is_local && !Is_split
                    //     && kNumMmaWGs == 3, UseCluster, [&] {
                    run_flash_fwd<
                        Flash_fwd_kernel_traits_fp8<Headdim, kNumMmaWGs * 64, kBlockN, 4 + kNumMmaWGs * 4,
                                                    kStages, false, UseCluster ? 2 : 1, T, Is_split, kBlockH>,
                        Is_causal,
                        Is_local && !Is_causal,
                        Seqlen_traits,
                        Seqlen_traits_Q
                    >(params, stream);
                    // });
                });
            });
        });
    });
}

template<typename T, int kBlockH>
void run_mha_fwd_hdim128_fp8_gqa(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 128;
    constexpr static int kBlockN = 256;
    constexpr static int kStages = 2;
    constexpr static bool UseCluster = false;
    using Seqlen_traits = flash::FixedSeqLenTraits;
    using Seqlen_traits_Q = flash::FixedGQASeqLenTraits;
    MMA_2WG_SWITCH(kBlockH * params.seqlen_q, kNumMmaWGs, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            BOOL_SWITCH(params.is_local, Is_local, [&] {
                BOOL_SWITCH(params.num_splits > 1, Is_split, [&] {
                    // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128/kBlockH) % 2 == 0 && !Is_causal && !Is_local && !Is_split
                    //     && kNumMmaWGs == 2, UseCluster, [&] {
                    run_flash_fwd<
                        Flash_fwd_kernel_traits_fp8<Headdim, kNumMmaWGs * 64, kBlockN, 4 + kNumMmaWGs * 4,
                                                    kStages, false, UseCluster ? 2 : 1, T, Is_split, kBlockH>,
                        Is_causal,
                        Is_local && !Is_causal,
                        Seqlen_traits,
                        Seqlen_traits_Q
                    >(params, stream);
                    // });
                });
            });
        });
    });
}

template<typename T, int kBlockH>
void run_mha_fwd_hdim256_fp8_gqa(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 256;
    constexpr static int kBlockN = 128;
    constexpr static int kStages = 2;
    constexpr static bool UseCluster = false;
    using Seqlen_traits = flash::FixedSeqLenTraits;
    using Seqlen_traits_Q = flash::FixedGQASeqLenTraits;
    MMA_2WG_SWITCH(kBlockH * params.seqlen_q, kNumMmaWGs, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            BOOL_SWITCH(params.is_local, Is_local, [&] {
                BOOL_SWITCH(params.num_splits > 1, Is_split, [&] {
                    // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128/kBlockH) % 2 == 0 && !Is_causal && !Is_local && !Is_split
                    //     && kNumMmaWGs == 2, UseCluster, [&] {
                    run_flash_fwd<
                        Flash_fwd_kernel_traits_fp8<Headdim, kNumMmaWGs * 64, kBlockN, 4 + kNumMmaWGs * 4,
                                                    kStages, false, UseCluster ? 2 : 1, T, Is_split, kBlockH>,
                        Is_causal,
                        Is_local && !Is_causal,
                        Seqlen_traits,
                        Seqlen_traits_Q
                    >(params, stream);
                    // });
                });
            });
        });
    });
}