flash_fwd_launch_template.h 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include "cute/tensor.hpp"
  6. #include "cutlass/cutlass.h"
  7. #include "cutlass/cluster_launch.hpp"
  8. #include "static_switch.h"
  9. #include "flash.h"
  10. #include "tile_scheduler.hpp"
  11. #include "flash_fwd_kernel.h"
  12. #include "kernel_traits.h"
  13. #include "seq_len.h"
  14. #include "utils.h"
  // Host-side launcher for the warp-specialized forward attention kernel.
  //
  // Template parameters:
  //   Kernel_traits  - supplies Element/OutputType, the cluster shape, kBlockM,
  //                    kNWarps, kNThreads, NumProducerThreads and SharedStorage.
  //   Is_causal      - forwarded to the mainloop and the kernel; also selects the
  //                    tile scheduler (see below).
  //   Seqlen_traits  - builds the global-memory layouts for Q/K/V/O/LSE and is
  //                    passed to the kernel by value.
  //
  // Arguments:
  //   params - tensor pointers, strides, sizes, softmax scale, descale pointers
  //            and the tile-count semaphore consumed below.
  //   stream - CUDA stream the kernel is launched on.
  //
  // CUDA runtime failures and launch failures are reported through CHECK_CUDA /
  // CHECK_CUDA_KERNEL_LAUNCH.
  template<typename Kernel_traits, bool Is_causal, typename Seqlen_traits>
  void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
      using Element = typename Kernel_traits::Element;
      using OutputType = typename Kernel_traits::OutputType;
      using TileShape_MNK = typename Kernel_traits::TileShape_MNK;
      using ClusterShape = typename Kernel_traits::ClusterShape_MNK;

      // print(typename Kernel_traits::SmemLayoutVt{}); printf("\n"); print(typename Kernel_traits::SmemLayoutVt_tmp{});
      using CollectiveMainloop = flash::CollectiveMainloopFwd<Kernel_traits, Is_causal, Seqlen_traits>;
      using CollectiveEpilogue = flash::CollectiveEpilogueFwd<Kernel_traits, Seqlen_traits>;
      // Scheduler choice:
      //   variable sequence lengths -> SingleTileScheduler
      //   fixed length, non-causal  -> StaticPersistentTileScheduler
      //   fixed length, causal      -> DynamicPersistentTileScheduler
      using Scheduler = std::conditional_t<
          Seqlen_traits::kUseVarSeqLen,
          flash::SingleTileScheduler,
          std::conditional_t<!Is_causal,
              flash::StaticPersistentTileScheduler,
              flash::DynamicPersistentTileScheduler<Kernel_traits::kNThreads - cutlass::NumThreadsPerWarpGroup, Kernel_traits::NumProducerThreads>
      >>;
      // using Scheduler = flash::SingleTileScheduler;

      // One traits object per side: Q (also used for O and LSE) and K (also used for V).
      Seqlen_traits seqlen_traits_q(
          params.total_q, params.seqlen_q, params.cu_seqlens_q, params.seqused_q);
      Seqlen_traits seqlen_traits_k(
          params.total_k, params.seqlen_k, params.cu_seqlens_k, params.seqused_k);

      typename CollectiveMainloop::Params mainloop_params =
          CollectiveMainloop::to_underlying_arguments({
              static_cast<Element const*>(params.q_ptr),
              seqlen_traits_q.get_gmem_layout(
                  params.seqlen_q, params.d, params.h, params.b,
                  params.q_row_stride, params.q_head_stride, params.q_batch_stride
              ),  // layout_Q
              static_cast<Element const*>(params.k_ptr),
              seqlen_traits_k.get_gmem_layout(
                  params.seqlen_k, params.d, params.h_k, params.b,
                  params.k_row_stride, params.k_head_stride, params.k_batch_stride
              ),  // layout_K
              static_cast<Element const*>(params.v_ptr),
              seqlen_traits_k.get_gmem_layout(
                  params.seqlen_k, params.d, params.h_k, params.b,
                  params.v_row_stride, params.v_head_stride, params.v_batch_stride
              ),  // layout_V
              params.scale_softmax_log2,
              params.descale_q_ptr,
              params.descale_k_ptr,
              params.descale_v_ptr
          });
      typename CollectiveEpilogue::Params epilogue_params =
          CollectiveEpilogue::to_underlying_arguments({
              static_cast<OutputType*>(params.o_ptr),
              seqlen_traits_q.get_gmem_layout(
                  params.seqlen_q, params.d, params.h, params.b,
                  params.o_row_stride, params.o_head_stride, params.o_batch_stride
              ),  // layout_O
              static_cast<float*>(params.softmax_lse_ptr),
              seqlen_traits_q.get_lse_gmem_layout(
                  params.seqlen_q, params.h, params.b
              )  // layout_LSE
          });

      // Number of M tiles, rounded up to a multiple of the cluster extent in M.
      int num_blocks_m = cutlass::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
      num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{});
      typename Scheduler::Arguments scheduler_args = {num_blocks_m, params.h, params.b, params.tile_count_semaphore};
      typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args);

      // Get the ptr to kernel function: 8-bit element types use the fp8 entry point.
      void *kernel;
      if constexpr(cutlass::sizeof_bits_v<Element> == 8)
          kernel = (void *)flash::compute_attn_ws_fp8<Kernel_traits, Is_causal, Scheduler, Seqlen_traits>;
      else
          kernel = (void *)flash::compute_attn_ws<Kernel_traits, Is_causal, Scheduler, Seqlen_traits>;

      int smem_size = sizeof(typename Kernel_traits::SharedStorage);
      // int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q));
      // int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k));
      // int smem_size_v = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_v));
      // int smem_size_o = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_o));
      // printf("smem_size = %d, q = %d, k = %d, v = %d, o = %d.\n", smem_size, smem_size_q, smem_size_k, smem_size_v, smem_size_o);
      // Large dynamic shared-memory requests must be opted into per kernel.
      if (smem_size >= 48 * 1024) {
          CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
      }

      // Query the SM count of the current device; the scheduler uses it to size the grid.
      // The cudaGetDevice result is checked so that a failure cannot leave `device`
      // uninitialized for the attribute query below.
      int device;
      CHECK_CUDA(cudaGetDevice(&device));
      int multiprocessor_count;
      CHECK_CUDA(cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device));

      dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count);
      static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
      dim3 block_dims(ctaSize);
      dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
      cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
      cutlass::launch_kernel_on_cluster(
          launch_params, kernel, mainloop_params, epilogue_params,
          scheduler_params, seqlen_traits_q, seqlen_traits_k);
      CHECK_CUDA_KERNEL_LAUNCH();
  }
  // Forward launcher for head dimension 64 with element type T.
  // Dispatches at runtime on params.is_causal and on params.cu_seqlens_q
  // (via SEQLEN_SWITCH) and always uses a cluster size of 1.
  template<typename T>
  void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
      constexpr static int kHeadDim = 64;
      // The traits do not depend on either runtime switch, so name them once.
      using Traits = Flash_fwd_kernel_traits<kHeadDim, 192, 128, 16, 2, false, 1, T>;
      BOOL_SWITCH(params.is_causal, Is_causal, [&] {
          SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
              run_flash_fwd<Traits, Is_causal, Seqlen_traits>(params, stream);
          });
      });
  }
  // Forward launcher for head dimension 128 with element type T.
  // Dispatches at runtime on params.is_causal, on params.cu_seqlens_q
  // (via SEQLEN_SWITCH) and on whether a 2-wide cluster may be used.
  template<typename T>
  void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
      constexpr static int kHeadDim = 128;
      BOOL_SWITCH(params.is_causal, Is_causal, [&] {
          SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
              // A cluster is used only for non-causal, fixed-length runs whose number
              // of 128-row tiles along seqlen_q is even.
              const bool cluster_ok = !Is_causal
                  && !Seqlen_traits::kUseVarSeqLen
                  && cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0;
              BOOL_SWITCH(cluster_ok, UseCluster, [&] {
                  // Third traits argument is 128 when causal and 176 otherwise.
                  using Traits = Flash_fwd_kernel_traits<
                      kHeadDim, 128, Is_causal ? 128 : 176, 12, 2, false, UseCluster ? 2 : 1, T>;
                  run_flash_fwd<Traits, Is_causal, Seqlen_traits>(params, stream);
              });
          });
      });
  }
  // Forward launcher for head dimension 256 with element type T.
  // Dispatches at runtime on params.is_causal, on params.cu_seqlens_q
  // (via SEQLEN_SWITCH) and on whether a 2-wide cluster may be used.
  template<typename T>
  void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
      constexpr static int kHeadDim = 256;
      BOOL_SWITCH(params.is_causal, Is_causal, [&] {
          SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
              // A cluster is used only for non-causal, fixed-length runs whose number
              // of 128-row tiles along seqlen_q is even.
              const bool cluster_ok = !Is_causal
                  && !Seqlen_traits::kUseVarSeqLen
                  && cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0;
              BOOL_SWITCH(cluster_ok, UseCluster, [&] {
                  using Traits = Flash_fwd_kernel_traits<
                      kHeadDim, 128, 80, 12, 2, false, UseCluster ? 2 : 1, T>;
                  run_flash_fwd<Traits, Is_causal, Seqlen_traits>(params, stream);
              });
          });
      });
  }
  // FP8 forward launcher for head dimension 64 with element type T.
  // Sequence handling is pinned to flash::FixedSeqLenTraits; params.cu_seqlens_q
  // is not consulted here, so no variable-length dispatch takes place.
  // Causal runs always use cluster size 1; non-causal runs use a 2-wide cluster
  // when the number of kBlockM tiles along seqlen_q is even.
  template<typename T>
  void run_mha_fwd_hdim64_fp8(Flash_fwd_params &params, cudaStream_t stream) {
      constexpr static int Headdim = 64;
      constexpr static int kBlockM = 192;
      constexpr static int kBlockN = 128;
      constexpr static int kNWarps = 4 + kBlockM/16;
      constexpr static int kStages = 4;
      using Seqlen_traits = flash::FixedSeqLenTraits;

      if (!params.is_causal) {
          const bool even_m_tiles = cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0;
          BOOL_SWITCH(even_m_tiles, UseCluster, [&] {
              using Traits = Flash_fwd_kernel_traits_fp8<
                  Headdim, kBlockM, kBlockN, kNWarps, kStages, false, UseCluster ? 2 : 1, T>;
              run_flash_fwd<Traits, /*Is_causal=*/false, Seqlen_traits>(params, stream);
          });
          return;
      }

      // Causal path: only the cluster-size-1 variant is instantiated.
      using CausalTraits = Flash_fwd_kernel_traits_fp8<
          Headdim, kBlockM, kBlockN, kNWarps, kStages, false, 1, T>;
      run_flash_fwd<CausalTraits, /*Is_causal=*/true, Seqlen_traits>(params, stream);
  }
  // FP8 forward launcher for head dimension 128 with element type T.
  // Sequence handling is pinned to flash::FixedSeqLenTraits; params.cu_seqlens_q
  // is not consulted here, so no variable-length dispatch takes place.
  // Causal runs always use cluster size 1; non-causal runs use a 2-wide cluster
  // when the number of kBlockM tiles along seqlen_q is even.
  template<typename T>
  void run_mha_fwd_hdim128_fp8(Flash_fwd_params &params, cudaStream_t stream) {
      constexpr static int Headdim = 128;
      constexpr static int kBlockM = 128;
      constexpr static int kBlockN = 256;
      constexpr static int kNWarps = 4 + kBlockM/16;
      constexpr static int kStages = 2;
      using Seqlen_traits = flash::FixedSeqLenTraits;

      if (!params.is_causal) {
          const bool even_m_tiles = cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0;
          BOOL_SWITCH(even_m_tiles, UseCluster, [&] {
              using Traits = Flash_fwd_kernel_traits_fp8<
                  Headdim, kBlockM, kBlockN, kNWarps, kStages, false, UseCluster ? 2 : 1, T>;
              run_flash_fwd<Traits, /*Is_causal=*/false, Seqlen_traits>(params, stream);
          });
          return;
      }

      // Causal path: only the cluster-size-1 variant is instantiated.
      using CausalTraits = Flash_fwd_kernel_traits_fp8<
          Headdim, kBlockM, kBlockN, kNWarps, kStages, false, 1, T>;
      run_flash_fwd<CausalTraits, /*Is_causal=*/true, Seqlen_traits>(params, stream);
  }
  // FP8 forward launcher for head dimension 256 with element type T.
  // Sequence handling is pinned to flash::FixedSeqLenTraits; params.cu_seqlens_q
  // is not consulted here, so no variable-length dispatch takes place.
  // Causal runs always use cluster size 1; non-causal runs use a 2-wide cluster
  // when the number of kBlockM tiles along seqlen_q is even.
  template<typename T>
  void run_mha_fwd_hdim256_fp8(Flash_fwd_params &params, cudaStream_t stream) {
      constexpr static int Headdim = 256;
      constexpr static int kBlockM = 128;
      constexpr static int kBlockN = 128;
      constexpr static int kNWarps = 4 + kBlockM/16;
      constexpr static int kStages = 2;
      using Seqlen_traits = flash::FixedSeqLenTraits;

      if (!params.is_causal) {
          const bool even_m_tiles = cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0;
          BOOL_SWITCH(even_m_tiles, UseCluster, [&] {
              using Traits = Flash_fwd_kernel_traits_fp8<
                  Headdim, kBlockM, kBlockN, kNWarps, kStages, false, UseCluster ? 2 : 1, T>;
              run_flash_fwd<Traits, /*Is_causal=*/false, Seqlen_traits>(params, stream);
          });
          return;
      }

      // Causal path: only the cluster-size-1 variant is instantiated.
      using CausalTraits = Flash_fwd_kernel_traits_fp8<
          Headdim, kBlockM, kBlockN, kNWarps, kStages, false, 1, T>;
      run_flash_fwd<CausalTraits, /*Is_causal=*/true, Seqlen_traits>(params, stream);
  }