// flash_fwd_launch_template.h
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include "cute/tensor.hpp"
  6. #include "cutlass/cutlass.h"
  7. #include "cutlass/cluster_launch.hpp"
  8. #include "static_switch.h"
  9. #include "flash.h"
  10. #include "tile_scheduler.hpp"
  11. #include "flash_fwd_kernel.h"
  12. #include "kernel_traits.h"
  13. #include "seq_len.h"
  14. #include "utils.h"
  15. #include "combine.h"
  16. template<typename Kernel_traits, bool Is_causal, bool Is_local, typename Seqlen_traits, typename Seqlen_traits_Q = Seqlen_traits>
  17. void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
  18. static_assert(!(Is_causal && Is_local), "Is_causal and Is_local cannot be true at the same time.");
  19. using Element = typename Kernel_traits::Element;
  20. using ElementAccum = typename Kernel_traits::ElementAccum;
  21. using OutputType = typename Kernel_traits::OutputType;
  22. using TileShape_MNK = typename Kernel_traits::TileShape_MNK;
  23. using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
  24. constexpr static bool Is_split = Kernel_traits::Is_split;
  25. static_assert(Seqlen_traits_Q::UseGQAPacking == (Kernel_traits::kBlockH > 1), "If kBlockH > 1, use gqa packed layouts");
  26. static_assert(!(Is_split && Seqlen_traits::UseVarSeqLen), "Split KV not yet supported for variable seqlen.");
  27. using CollectiveMainloop = flash::CollectiveMainloopFwd<Kernel_traits, Is_causal, Is_local, Seqlen_traits, Seqlen_traits_Q>;
  28. using CollectiveEpilogue = flash::CollectiveEpilogueFwd<Kernel_traits, Seqlen_traits_Q>;
  29. using Scheduler = std::conditional_t<
  30. Seqlen_traits::UseVarSeqLen,
  31. flash::SingleTileScheduler,
  32. std::conditional_t<!Is_causal && !Is_local && !Is_split,
  33. flash::StaticPersistentTileScheduler<Is_split>,
  34. flash::DynamicPersistentTileScheduler<
  35. Kernel_traits::kNThreads - cutlass::NumThreadsPerWarpGroup,
  36. Kernel_traits::NumProducerThreads,
  37. Is_split
  38. >
  39. >>;
  40. // using Scheduler = flash::SingleTileScheduler;
  41. Seqlen_traits_Q seqlen_traits_q(
  42. params.total_q, params.seqlen_q, params.cu_seqlens_q, params.seqused_q);
  43. Seqlen_traits seqlen_traits_k(
  44. params.total_k, params.seqlen_k, params.cu_seqlens_k, params.seqused_k);
  45. typename CollectiveMainloop::Params mainloop_params =
  46. CollectiveMainloop::to_underlying_arguments({
  47. static_cast<Element const*>(params.q_ptr),
  48. seqlen_traits_q.get_gmem_layout(
  49. params.seqlen_q, params.d, params.h_k, params.b, params.h_h_k_ratio,
  50. params.q_row_stride, params.q_head_stride, params.q_batch_stride
  51. ), // layout_Q
  52. static_cast<Element const*>(params.k_ptr),
  53. seqlen_traits_k.get_gmem_layout(
  54. params.seqlen_k, params.d, params.h_k, params.b_k,
  55. params.k_row_stride, params.k_head_stride, params.k_batch_stride,
  56. params.page_block_size, params.page_num_blocks
  57. ), // layout_K
  58. static_cast<Element const*>(params.v_ptr),
  59. seqlen_traits_k.get_gmem_layout(
  60. params.seqlen_k, params.d, params.h_k, params.b_k,
  61. params.v_row_stride, params.v_head_stride, params.v_batch_stride,
  62. params.page_block_size, params.page_num_blocks
  63. ), // layout_V
  64. seqlen_traits_k.get_virtual_shape(params.seqlen_k, params.d, params.h_k, params.b, params.h_h_k_ratio, false),
  65. params.scale_softmax_log2,
  66. params.descale_q_ptr,
  67. params.descale_k_ptr,
  68. params.descale_v_ptr,
  69. params.window_size_left,
  70. params.window_size_right,
  71. ceil_div(params.h_h_k_ratio, Kernel_traits::kBlockH),
  72. params.cache_batch_idx,
  73. Is_split ? params.num_splits : 1,
  74. params.block_table,
  75. params.block_table_batch_stride,
  76. params.page_block_size,
  77. (params.page_block_size > 0) ? params.b*params.seqlen_k/params.page_block_size : 0
  78. });
  79. typename CollectiveEpilogue::Params epilogue_params = [&] {
  80. if constexpr(!Is_split) {
  81. return CollectiveEpilogue::to_underlying_arguments({
  82. static_cast<OutputType*>(params.o_ptr),
  83. seqlen_traits_q.get_gmem_layout(
  84. params.seqlen_q, params.d, params.h_k, params.b, params.h_h_k_ratio,
  85. params.o_row_stride, params.o_head_stride, params.o_batch_stride
  86. ), // layout_O
  87. static_cast<float*>(params.softmax_lse_ptr),
  88. seqlen_traits_q.get_lse_gmem_layout(
  89. params.seqlen_q, params.h, params.b
  90. ) // layout_LSE
  91. });
  92. } else {
  93. return CollectiveEpilogue::to_underlying_arguments({
  94. static_cast<OutputType*>(params.oaccum_ptr),
  95. seqlen_traits_q.get_oaccum_gmem_layout(
  96. params.seqlen_q, params.d, params.h_k, params.b, params.h_h_k_ratio, params.num_splits,
  97. params.oaccum_row_stride, params.oaccum_head_stride, params.oaccum_batch_stride,
  98. params.oaccum_split_stride
  99. ), // layout_O
  100. static_cast<float*>(params.softmax_lseaccum_ptr),
  101. seqlen_traits_q.get_lseaccum_gmem_layout(
  102. params.seqlen_q, params.h, params.b, params.num_splits
  103. ), // layout_LSE
  104. });
  105. }
  106. }();
  107. int num_blocks_m = cutlass::ceil_div(params.seqlen_q, Kernel_traits::kBlockM/Kernel_traits::kBlockH);
  108. num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{});
  109. int num_blocks_h = params.h_k * ceil_div(params.h_h_k_ratio, Kernel_traits::kBlockH);
  110. typename Scheduler::Arguments scheduler_args =
  111. {num_blocks_m, Is_split ? params.num_splits : 1, num_blocks_h, params.b, params.tile_count_semaphore};
  112. typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args);
  113. // Get the ptr to kernel function.
  114. void *kernel;
  115. if constexpr(cutlass::sizeof_bits_v<Element> == 8)
  116. kernel = (void *)flash::compute_attn_ws_fp8<Kernel_traits, Is_causal, Is_local, Scheduler, Seqlen_traits, Seqlen_traits_Q>;
  117. else
  118. kernel = (void *)flash::compute_attn_ws<Kernel_traits, Is_causal, Is_local, Scheduler, Seqlen_traits, Seqlen_traits_Q>;
  119. if (params.block_table != nullptr) {
  120. if ((params.page_block_size % Kernel_traits::kBlockN) != 0) {
  121. fprintf(stderr, "Sequence length in N (%d) dimension must divide page block size (%d) if block table is used\n", (int) Kernel_traits::kBlockN, (int) params.page_block_size);
  122. exit(1);
  123. }
  124. }
  125. int smem_size = sizeof(typename Kernel_traits::SharedStorage);
  126. // int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q));
  127. // int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k));
  128. // int smem_size_v = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_v));
  129. // int smem_size_o = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_o));
  130. // printf("smem_size = %d, q = %d, k = %d, v = %d, o = %d.\n", smem_size, smem_size_q, smem_size_k, smem_size_v, smem_size_o);
  131. if (smem_size >= 48 * 1024) {
  132. CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
  133. }
  134. int device;
  135. cudaGetDevice(&device);
  136. int multiprocessor_count;
  137. CHECK_CUDA(cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device));
  138. dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count);
  139. static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
  140. dim3 block_dims(ctaSize);
  141. if constexpr(size(ClusterShape{}) > 1) {
  142. dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
  143. cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
  144. cutlass::launch_kernel_on_cluster(
  145. launch_params, kernel, mainloop_params, epilogue_params,
  146. scheduler_params, seqlen_traits_q, seqlen_traits_k);
  147. } else {
  148. if constexpr(cutlass::sizeof_bits_v<Element> == 8) {
  149. flash::compute_attn_ws_fp8<Kernel_traits, Is_causal, Is_local, Scheduler, Seqlen_traits, Seqlen_traits_Q>
  150. <<<grid_dims, block_dims, smem_size, stream>>>
  151. (mainloop_params, epilogue_params, scheduler_params, seqlen_traits_q, seqlen_traits_k);
  152. } else {
  153. flash::compute_attn_ws<Kernel_traits, Is_causal, Is_local, Scheduler, Seqlen_traits, Seqlen_traits_Q>
  154. <<<grid_dims, block_dims, smem_size, stream>>>
  155. (mainloop_params, epilogue_params, scheduler_params, seqlen_traits_q, seqlen_traits_k);
  156. }
  157. }
  158. CHECK_CUDA_KERNEL_LAUNCH();
  159. if constexpr (Is_split) {
  160. using FinalOutputType = typename Kernel_traits::FinalOutputType;
  161. static_assert(is_same_v<OutputType, float>, "Assume OutputType of main kernel is float.");
  162. static_assert(is_same_v<ElementAccum, float>, "ElementAccum must be float.");
  163. // We want kBlockM to be as small as possible for more parallelism.
  164. // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
  165. // If headdim is divisible by 64, then we set kBlockM = 8, etc.
  166. constexpr static int kHeadDim = Kernel_traits::kHeadDim;
  167. constexpr static int kBlockM = kHeadDim % 128 == 0 ? 4 : (kHeadDim % 64 == 0 ? 8 : 16);
  168. constexpr static bool Is_even_K = true; // always true for our current setting
  169. void *kernel_combine;
  170. int smem_size_combine;
  171. NUM_SPLITS_SWITCH(params.num_splits, kLogMaxSplits, [&] {
  172. constexpr static int kMaxSplits = 1 << kLogMaxSplits;
  173. kernel_combine = (void *) flash::combine_attn_seqk_parallel<
  174. FinalOutputType, ElementAccum, kHeadDim, kBlockM, kLogMaxSplits, Is_even_K, Flash_fwd_params>;
  175. smem_size_combine = sizeof(
  176. flash::SharedStorageLSE<float, Shape<Int<kMaxSplits>, Int<kBlockM+1>>, Shape<Int<kMaxSplits>>>);
  177. });
  178. if (smem_size_combine >= 48 * 1024) {
  179. CHECK_CUDA(cudaFuncSetAttribute(kernel_combine, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_combine));
  180. }
  181. dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
  182. dim3 block_dims_combine(128);
  183. dim3 cluster_dims_combine(1, 1, 1);
  184. cutlass::ClusterLaunchParams launch_params_combine{
  185. grid_combine, block_dims_combine, cluster_dims_combine, smem_size_combine, stream};
  186. cutlass::launch_kernel_on_cluster(launch_params_combine, kernel_combine, params);
  187. CHECK_CUDA_KERNEL_LAUNCH();
  188. }
  189. }
  190. template<typename T>
  191. void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
  192. constexpr static int Headdim = 64;
  193. constexpr static bool UseCluster = false;
  194. BOOL_SWITCH(params.is_causal, Is_causal, [&] {
  195. BOOL_SWITCH(params.is_local, Is_local, [&] {
  196. MMA_3WG_SWITCH(params.seqlen_q, kNumMmaWGs, [&] {
  197. SEQLEN_SWITCH(params, Seqlen_traits, Seqlen_traits_Q, [&] {
  198. BOOL_SWITCH(params.num_splits > 1, Is_split, [&] {
  199. // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 192) % 2 == 0 && !Is_causal && !Is_local && !Is_split
  200. // && kNumMmaWGs == 3 && !Seqlen_traits::UseVarSeqLen, UseCluster, [&] {
  201. run_flash_fwd<
  202. Flash_fwd_kernel_traits<Headdim, kNumMmaWGs * 64, 128, 4 + kNumMmaWGs * 4,
  203. 2, false, UseCluster ? 2 : 1, T, !Seqlen_traits::UseVarSeqLen && Is_split>,
  204. Is_causal,
  205. Is_local && !Is_causal,
  206. Seqlen_traits,
  207. Seqlen_traits_Q
  208. >(params, stream);
  209. // });
  210. });
  211. });
  212. });
  213. });
  214. });
  215. }
  216. template<typename T>
  217. void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
  218. constexpr static int Headdim = 128;
  219. BOOL_SWITCH(params.block_table!=nullptr, UseBlockTable, [&] {
  220. MMA_2WG_SWITCH(params.seqlen_q, kNumMmaWGs, [&] {
  221. BOOL_SWITCH(params.is_causal, Is_causal, [&] {
  222. BOOL_SWITCH(params.is_local, Is_local, [&] {
  223. SEQLEN_SWITCH(params, Seqlen_traits, Seqlen_traits_Q, [&] {
  224. BOOL_SWITCH(params.num_splits > 1, Is_split, [&] {
  225. // Only use Cluster if number of tiles along seqlen_q is even
  226. // and not Is_causal, Is_split, or varseqlen
  227. BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Is_split
  228. && kNumMmaWGs == 2 && !Seqlen_traits::UseVarSeqLen, UseCluster, [&] {
  229. run_flash_fwd<
  230. Flash_fwd_kernel_traits<Headdim, kNumMmaWGs * 64, (Is_causal || Is_local || UseBlockTable) ? 128 : 176,
  231. 4 + kNumMmaWGs * 4, 2, false, UseCluster ? 2 : 1,
  232. T, !Seqlen_traits::UseVarSeqLen && Is_split>,
  233. Is_causal,
  234. Is_local && !Is_causal,
  235. Seqlen_traits,
  236. Seqlen_traits_Q
  237. >(params, stream);
  238. });
  239. });
  240. });
  241. });
  242. });
  243. });
  244. });
  245. }
  246. template<typename T>
  247. void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
  248. constexpr static int Headdim = 256;
  249. BOOL_SWITCH(params.block_table!=nullptr, UseBlockTable, [&] {
  250. MMA_2WG_SWITCH(params.seqlen_q, kNumMmaWGs, [&] {
  251. BOOL_SWITCH(params.is_causal, Is_causal, [&] {
  252. BOOL_SWITCH(params.is_local, Is_local, [&] {
  253. SEQLEN_SWITCH(params, Seqlen_traits, Seqlen_traits_Q, [&] {
  254. BOOL_SWITCH(params.num_splits > 1, Is_split, [&] {
  255. // Only use Cluster if number of tiles along seqlen_q is even
  256. // and not Is_causal, Is_split, or varseqlen
  257. BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Is_split
  258. && kNumMmaWGs == 2 && !Seqlen_traits::UseVarSeqLen, UseCluster, [&] {
  259. run_flash_fwd<
  260. Flash_fwd_kernel_traits<Headdim, kNumMmaWGs * 64, UseBlockTable ? 64 : (kNumMmaWGs == 1 ? 96 : 80),
  261. 4 + kNumMmaWGs * 4, 2, false, UseCluster ? 2 : 1,
  262. T, !Seqlen_traits::UseVarSeqLen && Is_split>,
  263. Is_causal,
  264. Is_local && !Is_causal,
  265. Seqlen_traits,
  266. Seqlen_traits_Q
  267. >(params, stream);
  268. });
  269. });
  270. });
  271. });
  272. });
  273. });
  274. });
  275. }
  276. template<typename T>
  277. void run_mha_fwd_hdim64_fp8(Flash_fwd_params &params, cudaStream_t stream) {
  278. constexpr static int Headdim = 64;
  279. constexpr static int kBlockN = 128;
  280. constexpr static int kStages = 4;
  281. // constexpr static bool UseCluster = false;
  282. // constexpr static int kBlockM = 192;
  283. // constexpr static int kNWarps = 4 + kBlockM/16;
  284. using Seqlen_traits = flash::FixedSeqLenTraits;
  285. MMA_3WG_SWITCH(params.seqlen_q, kNumMmaWGs, [&] {
  286. BOOL_SWITCH(params.is_causal, Is_causal, [&] {
  287. BOOL_SWITCH(params.is_local, Is_local, [&] {
  288. BOOL_SWITCH(params.num_splits > 1, Is_split, [&] {
  289. BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 192) % 2 == 0 && !Is_causal && !Is_local && !Is_split
  290. && kNumMmaWGs == 3, UseCluster, [&] {
  291. run_flash_fwd<
  292. Flash_fwd_kernel_traits_fp8<Headdim, kNumMmaWGs * 64, kBlockN, 4 + kNumMmaWGs * 4,
  293. kStages, false, UseCluster ? 2 : 1, T, Is_split>,
  294. Is_causal,
  295. Is_local && !Is_causal,
  296. Seqlen_traits
  297. >(params, stream);
  298. });
  299. });
  300. });
  301. });
  302. });
  303. }
  304. template<typename T>
  305. void run_mha_fwd_hdim128_fp8(Flash_fwd_params &params, cudaStream_t stream) {
  306. constexpr static int Headdim = 128;
  307. constexpr static int kBlockN = 256;
  308. constexpr static int kStages = 2;
  309. // constexpr static int kBlockM = 128;
  310. // constexpr static int kNWarps = 4 + kBlockM/16;
  311. using Seqlen_traits = flash::FixedSeqLenTraits;
  312. MMA_2WG_SWITCH(params.seqlen_q, kNumMmaWGs, [&] {
  313. BOOL_SWITCH(params.is_causal, Is_causal, [&] {
  314. BOOL_SWITCH(params.is_local, Is_local, [&] {
  315. BOOL_SWITCH(params.num_splits > 1, Is_split, [&] {
  316. BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Is_split
  317. && kNumMmaWGs == 2, UseCluster, [&] {
  318. run_flash_fwd<
  319. Flash_fwd_kernel_traits_fp8<Headdim, kNumMmaWGs * 64, kBlockN, 4 + kNumMmaWGs * 4,
  320. kStages, false, UseCluster ? 2 : 1, T, Is_split>,
  321. Is_causal,
  322. Is_local && !Is_causal,
  323. Seqlen_traits
  324. >(params, stream);
  325. });
  326. });
  327. });
  328. });
  329. });
  330. }
  331. template<typename T>
  332. void run_mha_fwd_hdim256_fp8(Flash_fwd_params &params, cudaStream_t stream) {
  333. constexpr static int Headdim = 256;
  334. constexpr static int kBlockN = 128;
  335. constexpr static int kStages = 2;
  336. // constexpr static int kBlockM = 128;
  337. // constexpr static int kNWarps = 4 + kBlockM/16;
  338. using Seqlen_traits = flash::FixedSeqLenTraits;
  339. MMA_2WG_SWITCH(params.seqlen_q, kNumMmaWGs, [&] {
  340. BOOL_SWITCH(params.is_causal, Is_causal, [&] {
  341. BOOL_SWITCH(params.is_local, Is_local, [&] {
  342. BOOL_SWITCH(params.num_splits > 1, Is_split, [&] {
  343. BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Is_split
  344. && kNumMmaWGs == 2, UseCluster, [&] {
  345. run_flash_fwd<
  346. Flash_fwd_kernel_traits_fp8<Headdim, kNumMmaWGs * 64, kBlockN, 4 + kNumMmaWGs * 4,
  347. kStages, false, UseCluster ? 2 : 1, T, Is_split>,
  348. Is_causal,
  349. Is_local && !Is_causal,
  350. Seqlen_traits
  351. >(params, stream);
  352. });
  353. });
  354. });
  355. });
  356. });
  357. }
  358. /*
  359. ** GQA methods
  360. */
  361. template<typename T, int kBlockH>
  362. void run_mha_fwd_hdim64_gqa(Flash_fwd_params &params, cudaStream_t stream) {
  363. constexpr static int Headdim = 64;
  364. constexpr static bool UseCluster = false;
  365. using Seqlen_traits = flash::FixedSeqLenTraits;
  366. using Seqlen_traits_Q = flash::FixedGQASeqLenTraits;
  367. MMA_3WG_SWITCH(kBlockH * params.seqlen_q, kNumMmaWGs, [&] {
  368. BOOL_SWITCH(params.is_causal, Is_causal, [&] {
  369. BOOL_SWITCH(params.is_local, Is_local, [&] {
  370. BOOL_SWITCH(params.num_splits > 1, Is_split, [&] {
  371. // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 192/kBlockH) % 2 == 0 && !Is_causal && !Is_local && !Is_split
  372. // && kNumMmaWGs == 3, UseCluster, [&] {
  373. run_flash_fwd<
  374. Flash_fwd_kernel_traits<Headdim, kNumMmaWGs * 64, 128, 4 + kNumMmaWGs * 4,
  375. 2, false, UseCluster ? 2 : 1, T, !Seqlen_traits::UseVarSeqLen && Is_split, kBlockH>,
  376. Is_causal,
  377. Is_local && !Is_causal,
  378. Seqlen_traits,
  379. Seqlen_traits_Q
  380. >(params, stream);
  381. // });
  382. });
  383. });
  384. });
  385. });
  386. }
  387. template<typename T, int kBlockH>
  388. void run_mha_fwd_hdim128_gqa(Flash_fwd_params &params, cudaStream_t stream) {
  389. constexpr static int Headdim = 128;
  390. constexpr static bool UseCluster = false;
  391. using Seqlen_traits = flash::FixedSeqLenTraits;
  392. using Seqlen_traits_Q = flash::FixedGQASeqLenTraits;
  393. MMA_2WG_SWITCH(kBlockH * params.seqlen_q, kNumMmaWGs, [&] {
  394. BOOL_SWITCH(params.is_causal, Is_causal, [&] {
  395. BOOL_SWITCH(params.is_local, Is_local, [&] {
  396. BOOL_SWITCH(params.num_splits > 1, Is_split, [&] {
  397. // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128/kBlockH) % 2 == 0 && !Is_causal && !Is_local && !Is_split
  398. // && kNumMmaWGs == 2, UseCluster, [&] {
  399. run_flash_fwd<
  400. Flash_fwd_kernel_traits<Headdim, kNumMmaWGs * 64, (Is_causal || Is_local) ? 128 : 176,
  401. 4 + kNumMmaWGs * 4, 2, false, UseCluster ? 2 : 1, T, Is_split, kBlockH>,
  402. Is_causal,
  403. Is_local && !Is_causal,
  404. Seqlen_traits,
  405. Seqlen_traits_Q
  406. >(params, stream);
  407. // });
  408. });
  409. });
  410. });
  411. });
  412. }
  413. template<typename T, int kBlockH>
  414. void run_mha_fwd_hdim256_gqa(Flash_fwd_params &params, cudaStream_t stream) {
  415. constexpr static int Headdim = 256;
  416. constexpr static bool UseCluster = false;
  417. using Seqlen_traits = flash::FixedSeqLenTraits;
  418. using Seqlen_traits_Q = flash::FixedGQASeqLenTraits;
  419. MMA_2WG_SWITCH(kBlockH * params.seqlen_q, kNumMmaWGs, [&] {
  420. BOOL_SWITCH(params.is_causal, Is_causal, [&] {
  421. BOOL_SWITCH(params.is_local, Is_local, [&] {
  422. BOOL_SWITCH(params.num_splits > 1, Is_split, [&] {
  423. // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128/kBlockH) % 2 == 0 && !Is_causal && !Is_local && !Is_split
  424. // && kNumMmaWGs == 2, UseCluster, [&] {
  425. run_flash_fwd<
  426. Flash_fwd_kernel_traits<Headdim, kNumMmaWGs * 64, kNumMmaWGs == 1 ? 96 : 80,
  427. 4 + kNumMmaWGs * 4, 2, false, UseCluster ? 2 : 1, T, Is_split, kBlockH>,
  428. Is_causal,
  429. Is_local && !Is_causal,
  430. Seqlen_traits,
  431. Seqlen_traits_Q
  432. >(params, stream);
  433. // });
  434. });
  435. });
  436. });
  437. });
  438. }
  439. template<typename T, int kBlockH>
  440. void run_mha_fwd_hdim64_fp8_gqa(Flash_fwd_params &params, cudaStream_t stream) {
  441. constexpr static int Headdim = 64;
  442. constexpr static int kBlockN = 128;
  443. constexpr static int kStages = 4;
  444. constexpr static bool UseCluster = false;
  445. using Seqlen_traits = flash::FixedSeqLenTraits;
  446. using Seqlen_traits_Q = flash::FixedGQASeqLenTraits;
  447. MMA_3WG_SWITCH(kBlockH * params.seqlen_q, kNumMmaWGs, [&] {
  448. BOOL_SWITCH(params.is_causal, Is_causal, [&] {
  449. BOOL_SWITCH(params.is_local, Is_local, [&] {
  450. BOOL_SWITCH(params.num_splits > 1, Is_split, [&] {
  451. // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 192/kBlockH) % 2 == 0 && !Is_causal && !Is_local && !Is_split
  452. // && kNumMmaWGs == 3, UseCluster, [&] {
  453. run_flash_fwd<
  454. Flash_fwd_kernel_traits_fp8<Headdim, kNumMmaWGs * 64, kBlockN, 4 + kNumMmaWGs * 4,
  455. kStages, false, UseCluster ? 2 : 1, T, Is_split, kBlockH>,
  456. Is_causal,
  457. Is_local && !Is_causal,
  458. Seqlen_traits,
  459. Seqlen_traits_Q
  460. >(params, stream);
  461. // });
  462. });
  463. });
  464. });
  465. });
  466. }
  467. template<typename T, int kBlockH>
  468. void run_mha_fwd_hdim128_fp8_gqa(Flash_fwd_params &params, cudaStream_t stream) {
  469. constexpr static int Headdim = 128;
  470. constexpr static int kBlockN = 256;
  471. constexpr static int kStages = 2;
  472. constexpr static bool UseCluster = false;
  473. using Seqlen_traits = flash::FixedSeqLenTraits;
  474. using Seqlen_traits_Q = flash::FixedGQASeqLenTraits;
  475. MMA_2WG_SWITCH(kBlockH * params.seqlen_q, kNumMmaWGs, [&] {
  476. BOOL_SWITCH(params.is_causal, Is_causal, [&] {
  477. BOOL_SWITCH(params.is_local, Is_local, [&] {
  478. BOOL_SWITCH(params.num_splits > 1, Is_split, [&] {
  479. // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128/kBlockH) % 2 == 0 && !Is_causal && !Is_local && !Is_split
  480. // && kNumMmaWGs == 2, UseCluster, [&] {
  481. run_flash_fwd<
  482. Flash_fwd_kernel_traits_fp8<Headdim, kNumMmaWGs * 64, kBlockN, 4 + kNumMmaWGs * 4,
  483. kStages, false, UseCluster ? 2 : 1, T, Is_split, kBlockH>,
  484. Is_causal,
  485. Is_local && !Is_causal,
  486. Seqlen_traits,
  487. Seqlen_traits_Q
  488. >(params, stream);
  489. // });
  490. });
  491. });
  492. });
  493. });
  494. }
  495. template<typename T, int kBlockH>
  496. void run_mha_fwd_hdim256_fp8_gqa(Flash_fwd_params &params, cudaStream_t stream) {
  497. constexpr static int Headdim = 256;
  498. constexpr static int kBlockN = 128;
  499. constexpr static int kStages = 2;
  500. constexpr static bool UseCluster = false;
  501. using Seqlen_traits = flash::FixedSeqLenTraits;
  502. using Seqlen_traits_Q = flash::FixedGQASeqLenTraits;
  503. MMA_2WG_SWITCH(kBlockH * params.seqlen_q, kNumMmaWGs, [&] {
  504. BOOL_SWITCH(params.is_causal, Is_causal, [&] {
  505. BOOL_SWITCH(params.is_local, Is_local, [&] {
  506. BOOL_SWITCH(params.num_splits > 1, Is_split, [&] {
  507. // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128/kBlockH) % 2 == 0 && !Is_causal && !Is_local && !Is_split
  508. // && kNumMmaWGs == 2, UseCluster, [&] {
  509. run_flash_fwd<
  510. Flash_fwd_kernel_traits_fp8<Headdim, kNumMmaWGs * 64, kBlockN, 4 + kNumMmaWGs * 4,
  511. kStages, false, UseCluster ? 2 : 1, T, Is_split, kBlockH>,
  512. Is_causal,
  513. Is_local && !Is_causal,
  514. Seqlen_traits,
  515. Seqlen_traits_Q
  516. >(params, stream);
  517. // });
  518. });
  519. });
  520. });
  521. });
  522. }