// flash_fwd_launch_template.h
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include <ATen/cuda/CUDAContext.h>

#include "static_switch.h"
#include "flash.h"
#include "flash_fwd_kernel.h"

// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define ARCH_SUPPORTS_FLASH
#define KERNEL_PARAM_MODIFIER __grid_constant__
#else
#define KERNEL_PARAM_MODIFIER
#endif

// Define a macro for unsupported architecture handling to centralize the error message
#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");

// Use a macro to clean up kernel definitions
#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \
template<typename Kernel_traits, __VA_ARGS__> \
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)

DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) {
    #if defined(ARCH_SUPPORTS_FLASH)
        static_assert(!(Is_causal && Is_local));  // Enforce constraints
        flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params);
    #else
        FLASH_UNSUPPORTED_ARCH
    #endif
}

DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) {
    #if defined(ARCH_SUPPORTS_FLASH)
        flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params);
    #else
        FLASH_UNSUPPORTED_ARCH
    #endif
}

DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) {
    static_assert(Log_max_splits >= 1);
    flash::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params);
}
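
// The launchers below turn runtime flags (dropout, causal vs. local masking, ALiBi slopes, softcap,
// uneven sequence lengths / head dim, split-KV) into the compile-time template parameters of the
// kernels above via the *_SWITCH macros from static_switch.h. In a typical build each switch
// compiles both branches, so several flags are deliberately collapsed below to limit the number of
// template instantiations.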

template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr size_t smem_size = Kernel_traits::kSmemSize;
    // printf("smem_size = %d\n", smem_size);

    // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
    // https://github.com/kokkos/kokkos-kernels/issues/349
    // https://github.com/HazyResearch/flash-attention/issues/21

    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid(num_m_block, params.b, params.h);
    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    const bool return_softmax = params.p_ptr != nullptr;
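    // Illustrative sizing (assumed numbers): seqlen_q = 512 with kBlockM = 128 gives num_m_block = 4,
    // so the grid is (4, batch, heads) and each CTA processes one 128-row query block of one head.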
    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
            LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
                BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
                    ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                        SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
                            // Will only return softmax if dropout, to reduce compilation time.
                            // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                            // If return_softmax, set IsEvenMNConst to false to reduce number of templates
                            // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                            // If Is_local, set Is_causal to false
                            auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
                            // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
                            // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
                            // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
                            if (smem_size >= 48 * 1024) {
                                C10_CUDA_CHECK(cudaFuncSetAttribute(
                                    kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
                            }
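                            // Standard CUDA behavior: a kernel may only use more than 48 KB of dynamic
                            // shared memory after opting in via cudaFuncSetAttribute, hence the check
                            // above before launching with smem_size bytes of dynamic shared memory.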
                            // int ctas_per_sm;
                            // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                            //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
                            // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
                            kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
                            C10_CUDA_KERNEL_LAUNCH_CHECK();
                        });
                    });
                });
            });
        });
    });
}

template<typename Kernel_traits, bool Is_causal>
void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs");
    static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem");
    constexpr size_t smem_size = Kernel_traits::kSmemSize;
    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h);
    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
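    // With split-KV, the split index takes over grid.y and batch * heads moves to grid.z.
    // Illustrative example (assumed numbers): num_splits = 4, b = 2, h = 16 gives
    // grid = (num_m_block, 4, 32); without splitting it would be (num_m_block, 2, 16).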
    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
            LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
                BOOL_SWITCH(params.num_splits > 1, Split, [&] {
                    BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
                        ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                            SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
                                // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
                                // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                                // If Is_local, set Is_causal to false
                                auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV>;
                                // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
                                // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
                                if (smem_size >= 48 * 1024) {
                                    C10_CUDA_CHECK(cudaFuncSetAttribute(
                                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
                                }
                                kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
                                C10_CUDA_KERNEL_LAUNCH_CHECK();
                            });
                        });
                    });
                });
            });
        });
    });
    if (params.num_splits > 1) {
        // We want kBlockM to be as small as possible for more parallelism.
        // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
        // If headdim is divisible by 64, then we set kBlockM = 8, etc.
        constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
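        // Spelled out (illustrative arithmetic): 128 threads loading 4 elements each cover 512 values
        // per pass, so headdim 128 gives 512 / 128 = 4 rows per pass (kBlockM = 4), headdim 64 gives
        // 8 rows, and head dims not divisible by 64 (e.g. 96) fall back to kBlockM = 16.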
        dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
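            // The ladder below picks the smallest Log_max_splits with 2^Log_max_splits >= num_splits
            // for the combine kernel (e.g. num_splits = 12 lands in the <= 16 branch, Log_max_splits = 4).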
            if (params.num_splits <= 2) {
                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
            } else if (params.num_splits <= 4) {
                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
            } else if (params.num_splits <= 8) {
                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
            } else if (params.num_splits <= 16) {
                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
            } else if (params.num_splits <= 32) {
                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
            } else if (params.num_splits <= 64) {
                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
            } else if (params.num_splits <= 128) {
                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
            }
            C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
    }
}

template<typename T, int Headdim, bool Is_causal>
void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int kBlockM = 64;  // Fixed for all head dimensions
    // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
    // and for headdim 192 with block size 64 x 128.
    // Also for headdim 160 with block size 64 x 128 after the rotary addition.
    constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
    run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>, Is_causal>(params, stream);
}
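
// Each per-head-dim launcher below picks a tile shape for Flash_fwd_kernel_traits; as used here the
// arguments read as <Headdim, kBlockM, kBlockN, number of warps, Is_Q_in_regs, Share_Q_K_smem,
// element type> (the two bools match the static_asserts in run_flash_splitkv_fwd above). The choices
// trade shared-memory footprint, and thus occupancy, against tile size; benchmarked alternatives are
// kept as commented-out calls.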

template<typename T, bool Is_causal>
void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 32;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
    });
}

template<typename T, bool Is_causal>
void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 64;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if constexpr(!Is_dropout) {
            // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
            // Using block size (64 x 256) is 27% slower for seqlen=2k
            // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
        } else {
            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
        }
    });
}

template<typename T, bool Is_causal>
void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 96;
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
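    // is_sm8x singles out sm86/sm89 parts, which have less shared memory per SM than sm80 (A100),
    // so the tile choices below prefer smaller, squarer blocks on those GPUs.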
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
        if (is_sm8x) {
            if constexpr(!Is_causal) {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
        } else {
            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
        }
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
        // These two are always slower
        // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, T>>(params, stream);
        // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, T>>(params, stream);
    });
}

template<typename T, bool Is_causal>
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 128;
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if constexpr(!Is_dropout) {
            // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
            // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
            if (is_sm8x) {
                if constexpr(!Is_causal) {
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                } else {
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                }
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // 1st ones are good for H100, A100
            // 2nd one is good for A6000 bc we get slightly better occupancy
        } else {
            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
        }
    });
}

template<typename T, bool Is_causal>
void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 160;
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        // For A100, H100, 128 x 32 is the fastest.
        // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
        // and 128 x 64 with 8 warps is the fastest for non-causal.
        if (is_sm8x) {
            if constexpr(!Is_causal) {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
        } else {
            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
        }
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
    });
}

template<typename T, bool Is_causal>
void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 192;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if constexpr(!Is_dropout) {
            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
        } else {
            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
        }
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
    });
}

template<typename T, bool Is_causal>
void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 224;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
        C10_CUDA_CHECK(status_);
    }
    // printf("max_smem_per_block = %d\n", max_smem_per_block);
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
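        // The threshold below corresponds to the smem footprint of a 128 x 64 tile in 2-byte elements:
        // a 128 x 224 Q tile plus 64 x 224 K and V tiles, i.e. 2 * 224 * (128 + 2 * 64) = 114688 bytes
        // = 112 KB, matching the inline comment.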
        if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) {  // 112 KB
            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
        } else {
            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
        }
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
        // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32.
        // If we have N = 32, there are only 1024 elements to load at once, where each load
        // is 8 elements. This means we can only use 128 threads and not 256 threads.
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
    });
}

template<typename T, bool Is_causal>
void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 256;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_sm, max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);
    status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
        C10_CUDA_CHECK(status_);
    }
    // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        // For A100, we want to run with 128 x 64 (128KB smem).
        // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
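        // Spelled out: 2 * 256 * (128 + 2 * 64) = 128 KB is one 128 x 64 tile, while
        // 4 * 256 * (64 + 2 * 64) = 192 KB is two 96 KB 64 x 64 CTAs. The first branch is taken only
        // when a 128 x 64 block fits but two 64 x 64 blocks per SM would not (A100-like limits);
        // otherwise (e.g. H100) the smaller tile wins by getting an extra CTA per SM.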
        if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
        } else {
            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
        }
        // 64 KB
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
        // 96 KB
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
    });
}