flash_fwd_launch_template.h

/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"  // For device_kernel
#include <cutlass/kernel_hardware_info.h>
#include "cutlass/cluster_launch.hpp"

#include "static_switch.h"
#include "flash.h"
#include "tile_size.h"
#include "tile_scheduler.hpp"
#include "flash_fwd_kernel_sm90.h"
#include "flash_fwd_kernel_sm80.h"
#include "mainloop_fwd_sm90_tma_gmma_ws.hpp"
#include "mainloop_fwd_sm80.hpp"
#include "epilogue_fwd.hpp"

using namespace cute;
template <int Arch, int kHeadDim, int kHeadDimV, int ClusterM, typename Element, typename ElementOut,
          bool Is_causal, bool Is_local, bool Has_softcap, bool Varlen, bool PagedKV, bool AppendKV, bool HasQv,
          bool PackGQA, bool Split, bool V_colmajor>
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time");
    static_assert(!(AppendKV && V_colmajor), "AppendKV and V_colmajor cannot be enabled at the same time");
    static_assert(!(AppendKV && !Varlen), "AppendKV requires Varlen");
    static constexpr bool Is_FP8 = cute::is_same_v<Element, cutlass::float_e4m3_t> || cute::is_same_v<Element, cutlass::float_e5m2_t>;
    static constexpr bool FP8_TransposeV = Is_FP8 && !V_colmajor;
    using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;

    // Can't use structured binding since it's not compatible with constexpr
    static constexpr std::tuple<int, int, bool, bool> kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKV, Has_softcap);
    static constexpr std::tuple<int, int, int, int, bool> kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKV, Varlen && Split, Has_softcap, AppendKV);
    static constexpr int kBlockM = Arch >= 90 ? std::get<0>(kBlockMN_RS_IntraWGOverlap) : std::get<0>(kBlockMN_kNWarps_Stages_RS);
    static constexpr int kBlockN = Arch >= 90 ? std::get<1>(kBlockMN_RS_IntraWGOverlap) : std::get<1>(kBlockMN_kNWarps_Stages_RS);
    static constexpr bool MmaPV_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap);
    static constexpr bool IntraWGOverlap = std::get<3>(kBlockMN_RS_IntraWGOverlap);
    static constexpr int kNWarps = std::get<2>(kBlockMN_kNWarps_Stages_RS);
    static constexpr int kStages = Arch >= 90 ? 2 : std::get<3>(kBlockMN_kNWarps_Stages_RS);
    static constexpr bool Q_in_regs = Arch >= 90 ? false : std::get<4>(kBlockMN_kNWarps_Stages_RS);
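    // Note: the Sm90 tile-size heuristic returns (kBlockM, kBlockN, MmaPV_is_RS, IntraWGOverlap),
    // while the Sm8x heuristic returns (kBlockM, kBlockN, kNWarps, kStages, Q_in_regs). As encoded
    // above, the Sm90 path always uses a 2-stage pipeline and never keeps Q in registers.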
    using TileShape_MNK = cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
    using TileShape_MNK_PV = cute::Shape<Int<kBlockM>, Int<kHeadDimV>, Int<kBlockN>>;
    using ClusterShape = cute::Shape<Int<ClusterM>, _1, _1>;
    using CollectiveMainloop = std::conditional_t<
        Arch >= 90,
        flash::CollectiveMainloopFwdSm90<kStages, ClusterShape, TileShape_MNK, kHeadDimV, Element, float, cutlass::arch::Sm90, Is_causal, Is_local, Has_softcap, Varlen, PagedKV, AppendKV, HasQv, MmaPV_is_RS, IntraWGOverlap, PackGQA, Split, V_colmajor>,
        flash::CollectiveMainloopFwdSm80<kNWarps, kStages, Q_in_regs, TileShape_MNK, kHeadDimV, Element, float, cutlass::arch::Sm80, Is_causal, Is_local, Has_softcap, Varlen, PagedKV, AppendKV, PackGQA, Split>
    >;
    using CollectiveEpilogue = flash::CollectiveEpilogueFwd<TileShape_MNK_PV, ClusterShape, ElementOut, ArchTag, CollectiveMainloop::NumMmaThreads, Varlen, PackGQA, FP8_TransposeV>;

    static constexpr int NumProducerThreads = Arch >= 90 ? CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads;
    using SchedulerPersistent = std::conditional_t<Varlen,
        flash::VarlenDynamicPersistentTileScheduler<kBlockM, CollectiveMainloop::NumMmaThreads, NumProducerThreads, Split, PackGQA, Arch >= 90 /*WarpSpecialized*/>,
        std::conditional_t<!Is_causal && !Is_local,
            flash::StaticPersistentTileScheduler<Split>,
            flash::DynamicPersistentTileScheduler<CollectiveMainloop::NumMmaThreads, NumProducerThreads, Split, PackGQA, Arch >= 90 /*WarpSpecialized*/>
        >
    >;
    using SchedulerSingleTile = flash::SingleTileScheduler<Varlen, Split, PackGQA, kBlockM>;
    // If Split then we probably don't have enough work for PersistentScheduler to be useful.
    // However, if Varlen (e.g., during decode where we have max_seqlens), using PersistentScheduler is better
    // since we'll avoid launching a bunch of thread blocks that immediately exit.
    // On Sm80, noncausal persistent seems a bit slower.
    static constexpr bool UsePersistentScheduler = Arch >= 90 ? !(Split && !Varlen) : ((Is_causal && !Varlen) || (Varlen && Split));
    using Scheduler = std::conditional_t<!UsePersistentScheduler, SchedulerSingleTile, SchedulerPersistent>;
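    // In other words: on Sm90 a persistent scheduler is used except for split-KV on fixed-length
    // sequences; on Sm8x it is used only for causal fixed-length or varlen split-KV workloads.
    // Within the persistent family, varlen batches get the dynamic varlen scheduler, non-causal
    // non-local workloads get the static scheduler, and causal/local ones get the dynamic scheduler.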
    using AttnKernel = std::conditional_t<
        Arch >= 90,
        flash::enable_sm90_or_later<flash::FlashAttnFwdSm90<CollectiveMainloop, CollectiveEpilogue, Scheduler>>,
        flash::enable_sm80_to_sm89<flash::FlashAttnFwdSm80<CollectiveMainloop, CollectiveEpilogue, Scheduler>>
    >;

    bool const is_varlen_q = params.cu_seqlens_q;
    bool const is_varlen_k = params.cu_seqlens_k;
    bool const is_varlen_k_new = params.cu_seqlens_knew;
    int seqlen_q = !is_varlen_q ? params.seqlen_q : params.total_q;
    int batch_q = !is_varlen_q ? params.b : 1;
    int batch_k = !is_varlen_k ? (params.kv_batch_idx ? params.b_k : params.b) : 1;
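    // For varlen inputs the batch dimension is folded into the sequence dimension: the tensors are
    // viewed with batch size 1, the total token count (total_q / total_k) as the sequence length,
    // and a batch stride of 0 in the shapes and strides built below.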
    typename CollectiveMainloop::StrideV v_strides =
        cute::conditional_return<!V_colmajor>(
            make_stride(params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0),
            make_stride(_1{}, params.v_dim_stride, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0));
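    // The first layout treats V as (seqlen, headdim_v) with contiguous head dimension; the V_colmajor
    // alternative is headdim-major (contiguous along seqlen), which is only exercised for 8-bit (FP8)
    // element types (see V_colmajor in run_mha_fwd_ below and FP8_TransposeV above).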
    typename CollectiveMainloop::Arguments mainloop_args {
        static_cast<Element const*>(params.q_ptr),
        {seqlen_q, params.d, params.h, batch_q},  // shape_Q
        {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0},  // stride_Q
        static_cast<Element*>(params.k_ptr),
        {!PagedKV ? (!is_varlen_k ? params.seqlen_k : params.total_k) : params.page_size,
         params.d, params.h_k, !PagedKV ? batch_k : params.num_pages},  // shape_K
        {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0},  // stride_K
        static_cast<Element*>(params.v_ptr),
        params.dv,  // headdim_v
        v_strides,  // stride_V
        static_cast<Element const*>(params.knew_ptr),
        {!is_varlen_k_new ? params.seqlen_knew : params.total_knew, params.d, params.h_k, !is_varlen_k_new ? params.b : 1},  // shape_K_new
        {params.knew_row_stride, _1{}, params.knew_head_stride, !is_varlen_k_new ? params.knew_batch_stride : 0},  // stride_K_new
        static_cast<Element const*>(params.vnew_ptr),
        {params.vnew_row_stride, _1{}, params.vnew_head_stride, !is_varlen_k_new ? params.vnew_batch_stride : 0},  // stride_V_new
        static_cast<Element const*>(params.qv_ptr),
        {params.qv_row_stride, _1{}, params.qv_head_stride, !is_varlen_q ? params.qv_batch_stride : 0},  // stride_Qv
        static_cast<Element const*>(params.rotary_cos_ptr),
        {params.seqlen_k, params.rotary_dim / 2},  // shape_rotary, the seqlen shape doesn't matter
        {params.rotary_dim / 2, _1{}},  // stride_rotary_cos
        static_cast<Element const*>(params.rotary_sin_ptr),
        {params.rotary_dim / 2, _1{}},  // stride_rotary_sin
        params.is_rotary_interleaved,
        params.page_table,
        // if page_size is not set, avoid dividing by zero
        {params.kv_batch_idx ? params.b_k : params.b, !PagedKV ? 0 : params.seqlen_k / params.page_size},  // shape_page_table
        {params.page_table_batch_stride, _1{}},  // stride_page_table
        params.scale_softmax,
        params.q_descale_ptr, params.k_descale_ptr, params.v_descale_ptr,
        {params.q_descale_batch_stride, params.q_descale_head_stride},
        {params.k_descale_batch_stride, params.k_descale_head_stride},
        {params.v_descale_batch_stride, params.v_descale_head_stride},
        params.window_size_left, params.window_size_right,
        params.softcap,
        params.num_splits,
        params.kv_batch_idx,
        params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew,
        params.seqused_q, params.seqused_k,
        params.leftpad_k,
    };
    typename CollectiveEpilogue::Arguments epilogue_args {
        static_cast<ElementOut*>(!Split ? params.o_ptr : params.oaccum_ptr),
        {seqlen_q, params.dv, params.h, batch_q, params.num_splits},  // shape_O
        {!Split ? params.o_row_stride : params.oaccum_row_stride,
         _1{},
         !Split ? params.o_head_stride : params.oaccum_head_stride,
         !is_varlen_q ? (!Split ? params.o_batch_stride : params.oaccum_batch_stride) : 0,
         !Split ? 0 : params.oaccum_split_stride},  // stride_O
        static_cast<float*>(!Split ? params.softmax_lse_ptr : params.softmax_lseaccum_ptr),
        {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, !Split ? 0 : params.h * seqlen_q * batch_q},  // stride_LSE
        params.h_k,
        params.cu_seqlens_q, params.seqused_q
    };
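    // With Split, partial outputs and LSE values go to the *accum buffers (in float; see the T_out
    // choice in run_mha_fwd_ below) and carry an extra num_splits dimension; they are reduced across
    // splits afterwards, outside this launch.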
    int qhead_per_khead = !PackGQA ? 1 : cutlass::ceil_div(params.h, params.h_k);
    int num_blocks_m = cutlass::ceil_div(params.seqlen_q * qhead_per_khead, get<0>(TileShape_MNK{}));
    num_blocks_m = cutlass::round_up(num_blocks_m, size<0>(ClusterShape{}));
    typename flash::TileSchedulerArguments scheduler_args {
        num_blocks_m, !PackGQA ? params.h : params.h_k, params.b, params.num_splits,
        params.h / params.h_k,
        params.seqlen_q,
        params.seqlen_k, params.d, sizeof(Element),
        params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q,
        // params.num_m_blocks_ptr, params.num_splits_dynamic_ptr,
        params.num_splits_dynamic_ptr,
    };
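    // With PackGQA, the qhead_per_khead query heads sharing a KV head are packed into the M
    // dimension, so the M-block count is computed on seqlen_q * qhead_per_khead and the scheduler
    // iterates over KV heads (h_k) rather than query heads. num_blocks_m is rounded up to a multiple
    // of the cluster size so the grid's M extent is divisible by ClusterM.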
    if constexpr (Varlen && UsePersistentScheduler) {
        prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN);
        CHECK_CUDA_KERNEL_LAUNCH();
    }
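    // The prepare_varlen_num_blocks call above is defined outside this header; it launches a small
    // kernel that precomputes per-batch scheduling metadata, presumably filling
    // params.num_splits_dynamic_ptr, which the varlen persistent scheduler consumes.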
    int device;
    CHECK_CUDA(cudaGetDevice(&device));
    typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({
        mainloop_args, epilogue_args, {device, params.num_sm}, scheduler_args
    });

    dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params);
    dim3 block_dims = AttnKernel::get_block_shape();
    int smem_size = AttnKernel::SharedStorageSize;
    // int smem_size_q = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_q));
    // int smem_size_k = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k));
    // int smem_size_v = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v));
    // printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v);
    // Get the ptr to kernel function.
    if constexpr (size(ClusterShape{}) > 1) {
        void const* kernel = (void const*) cutlass::device_kernel<AttnKernel>;
        if (smem_size >= 48 * 1024) {
            // Kernels requesting more than 48 KB of dynamic shared memory must opt in explicitly.
            CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        }
        dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
        cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
        cutlass::launch_kernel_on_cluster(launch_params, kernel, kernel_params);
    } else {
        auto kernel = cutlass::device_kernel<AttnKernel>;
        if (smem_size >= 48 * 1024) {
            CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        }
        kernel<<<grid_dims, block_dims, smem_size, stream>>>(kernel_params);
    }
    CHECK_CUDA_KERNEL_LAUNCH();
}

template<int Arch, typename T, int kHeadDim, int kHeadDimV, bool Split, bool PagedKV, bool Has_softcap, bool PackGQA>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
    static_assert(sizeof(T) == 2 || sizeof(T) == 1, "Only 16bit and 8bit are supported");
    static constexpr bool Is_FP8 = cute::is_same_v<T, cutlass::float_e4m3_t> || cute::is_same_v<T, cutlass::float_e5m2_t>;
    using T_out = std::conditional_t<!Split, std::conditional_t<!Is_FP8, T, cutlass::bfloat16_t>, float>;
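    // The *_SWITCH macros below (from static_switch.h) dispatch runtime flags in params to
    // compile-time boolean template parameters, so exactly one fully specialized run_flash_fwd
    // instantiation is selected per call.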
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        VCOLMAJOR_SWITCH(params.v_dim_stride != 1, V_colmajor_, [&] {
            static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1;
            VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] {
                // Only needed here to decide if we should use cluster
                static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKV, Has_softcap)) : 128;
                // On nvcc 12.8, hdim 128, without cluster is faster (730 vs 700 TFLOPS)
                static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 192) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKV && !Varlen;
                BOOL_SWITCH(params.qv_ptr, HasQV_, [&] {
                    static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV == 512;
                    APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] {
                        // Only use Cluster if number of tiles along seqlen_q is even and not varlen
                        CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] {
                            static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1;
                            run_flash_fwd<Arch, kHeadDim, kHeadDimV, ClusterM, T, T_out, Is_causal, Is_local, Has_softcap, Varlen, PagedKV, AppendKV && Varlen, HasQv, PackGQA, Split, V_colmajor>(params, stream);
                        });
                    });
                });
            });
        });
    });
}
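
// Usage sketch (not part of this header): downstream .cu translation units typically instantiate
// run_mha_fwd_ explicitly, one combination of arch / dtype / head dim / feature flags per file, to
// keep compile times manageable. The parameter values here are illustrative assumptions only, e.g.:
//
//   template void run_mha_fwd_<90 /*Arch*/, cutlass::bfloat16_t, 128 /*kHeadDim*/, 128 /*kHeadDimV*/,
//                              false /*Split*/, false /*PagedKV*/, false /*Has_softcap*/, false /*PackGQA*/>
//       (Flash_fwd_params &params, cudaStream_t stream);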