flash_fwd_launch_template.h

/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"  // For device_kernel
#include <cutlass/kernel_hardware_info.h>
#include "cutlass/cluster_launch.hpp"

#include "static_switch.h"
#include "flash.h"
#include "tile_size.h"
#include "tile_scheduler.hpp"
#include "flash_fwd_kernel_sm90.h"
#include "flash_fwd_kernel_sm80.h"
#include "mainloop_fwd_sm90_tma_gmma_ws.hpp"
#include "mainloop_fwd_sm80.hpp"
#include "epilogue_fwd.hpp"

using namespace cute;
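
// run_flash_fwd instantiates the collective mainloop, epilogue, tile scheduler and kernel type for one
// compile-time configuration, translates Flash_fwd_params into the collectives' Arguments structs,
// and launches the kernel on the given stream.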
template <int Arch, int kHeadDim, int ClusterM, typename Element, typename ElementOut,
          bool Is_causal, bool Is_local, bool Has_softcap, bool Varlen, bool PagedKV, bool AppendKV,
          bool PackGQA, bool Split, bool V_colmajor>
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time");
    static_assert(!(AppendKV && V_colmajor), "AppendKV and V_colmajor cannot be enabled at the same time");
    static_assert(!(AppendKV && !Varlen), "AppendKV requires Varlen");
    static constexpr bool Is_FP8 = cute::is_same_v<Element, cutlass::float_e4m3_t> || cute::is_same_v<Element, cutlass::float_e5m2_t>;
    static constexpr bool FP8_TransposeV = Is_FP8 && !V_colmajor;
    using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;
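
    // Tile sizes (kBlockM x kBlockN), warp count, pipeline stages, and the RS / intra-warpgroup-overlap
    // options come from the heuristics in tile_size.h, selected by head dim, element size, and the
    // enabled features. SM90 uses the tile_size_fwd_sm90 tuple, older archs the tile_size_fwd_sm8x tuple.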
    // Can't use structured binding since it's not compatible with constexpr
    static constexpr std::tuple<int, int, bool, bool> kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKV, Has_softcap);
    static constexpr std::tuple<int, int, int, int, bool> kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKV, Varlen && Split, Has_softcap, AppendKV);
    static constexpr int kBlockM = Arch >= 90 ? std::get<0>(kBlockMN_RS_IntraWGOverlap) : std::get<0>(kBlockMN_kNWarps_Stages_RS);
    static constexpr int kBlockN = Arch >= 90 ? std::get<1>(kBlockMN_RS_IntraWGOverlap) : std::get<1>(kBlockMN_kNWarps_Stages_RS);
    static constexpr bool Mma1_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap);
    static constexpr bool IntraWGOverlap = std::get<3>(kBlockMN_RS_IntraWGOverlap);
    static constexpr int kNWarps = std::get<2>(kBlockMN_kNWarps_Stages_RS);
    static constexpr int kStages = Arch >= 90 ? 2 : std::get<3>(kBlockMN_kNWarps_Stages_RS);
    static constexpr bool Q_in_regs = Arch >= 90 ? false : std::get<4>(kBlockMN_kNWarps_Stages_RS);

    using TileShape_MNK = cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
    using ClusterShape = cute::Shape<Int<ClusterM>, _1, _1>;
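
    // SM90 selects the warp-specialized TMA/WGMMA mainloop, SM80-SM89 the SM80 mainloop
    // (see the two mainloop_fwd_*.hpp headers included above); both share the same epilogue.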
    using CollectiveMainloop = std::conditional_t<
        Arch >= 90,
        flash::CollectiveMainloopFwdSm90<kStages, ClusterShape, TileShape_MNK, Element, float, cutlass::arch::Sm90, Is_causal, Is_local, Has_softcap, Varlen, PagedKV, AppendKV, Mma1_is_RS, IntraWGOverlap, PackGQA, Split, V_colmajor>,
        flash::CollectiveMainloopFwdSm80<kNWarps, kStages, Q_in_regs, TileShape_MNK, Element, float, cutlass::arch::Sm80, Is_causal, Is_local, Has_softcap, Varlen, PagedKV, AppendKV, PackGQA, Split>
    >;
    using CollectiveEpilogue = flash::CollectiveEpilogueFwd<TileShape_MNK, ClusterShape, ElementOut, ArchTag, CollectiveMainloop::NumMmaThreads, Varlen, PackGQA, FP8_TransposeV>;
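
    // On SM90 dedicated producer warps feed the MMA warpgroups, so the schedulers distinguish producer
    // and consumer thread counts; on SM80 the MMA threads issue their own loads, so NumMmaThreads is
    // reused as the producer count.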
    static constexpr int NumProducerThreads = Arch >= 90 ? CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads;
    using SchedulerPersistent = std::conditional_t<Varlen,
        flash::VarlenDynamicPersistentTileScheduler<kBlockM, CollectiveMainloop::NumMmaThreads, NumProducerThreads, Split, PackGQA, Arch >= 90 /*WarpSpecialized*/>,
        std::conditional_t<!Is_causal && !Is_local,
            flash::StaticPersistentTileScheduler<Split>,
            flash::DynamicPersistentTileScheduler<CollectiveMainloop::NumMmaThreads, NumProducerThreads, Split, PackGQA, Arch >= 90 /*WarpSpecialized*/>
        >
    >;
    using SchedulerSingleTile = flash::SingleTileScheduler<Varlen, Split, PackGQA, kBlockM>;
    // If Split then we probably don't have enough work for PersistentScheduler to be useful.
    // However, if Varlen (e.g., during decode where we have max_seqlens), using PersistentScheduler is better
    // since we'll avoid launching a bunch of thread blocks that immediately exit.
    // On Sm80, noncausal persistent seems a bit slower.
    using Scheduler = std::conditional_t<Arch >= 90 ? (Split && !Varlen) : !((Is_causal && !Varlen) || (Varlen && Split)), SchedulerSingleTile, SchedulerPersistent>;
    using AttnKernel = std::conditional_t<
        Arch >= 90,
        flash::enable_sm90_or_later<flash::FlashAttnFwdSm90<CollectiveMainloop, CollectiveEpilogue, Scheduler>>,
        flash::enable_sm80_to_sm89<flash::FlashAttnFwdSm80<CollectiveMainloop, CollectiveEpilogue, Scheduler>>
    >;
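
    // Runtime arguments. In varlen mode the sequences are concatenated and addressed through cu_seqlens,
    // so the logical batch size becomes 1 and the batch strides below are set to 0.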
    bool const is_varlen_q = params.cu_seqlens_q;
    bool const is_varlen_k = params.cu_seqlens_k;
    bool const is_varlen_k_new = params.cu_seqlens_knew;
    int seqlen_q = !is_varlen_q ? params.seqlen_q : params.total_q;
    int batch_q = !is_varlen_q ? params.b : 1;
    int batch_k = !is_varlen_k ? (params.kv_batch_idx ? params.b_k : params.b) : 1;
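
    // V strides: by default the head dimension of V is contiguous; with V_colmajor (only allowed for
    // 8-bit element types, see run_mha_fwd_ below) V is instead contiguous along the sequence dimension.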
    typename CollectiveMainloop::StrideV v_strides =
        cute::conditional_return<!V_colmajor>(
            make_stride(params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0),
            make_stride(_1{}, params.v_dim_stride, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0));
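
    // Mainloop arguments: Q/K/V tensors (plus the appended K/V and rotary tables used by AppendKV), the
    // page table for paged KV, softmax scale and softcap, FP8 descale factors, local-attention window
    // sizes, and the varlen metadata (cu_seqlens / seqused / leftpad).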
    typename CollectiveMainloop::Arguments mainloop_args {
        static_cast<Element const*>(params.q_ptr),
        {seqlen_q, params.d, params.h, batch_q},  // shape_Q
        {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0},  // stride_Q
        static_cast<Element*>(params.k_ptr),
        {!PagedKV ? (!is_varlen_k ? params.seqlen_k : params.total_k) : params.page_size,
         params.d, params.h_k, !PagedKV ? batch_k : params.num_pages},  // shape_K
        {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0},  // stride_K
        static_cast<Element*>(params.v_ptr),
        v_strides,  // stride_V
        static_cast<Element const*>(params.knew_ptr),
        {!is_varlen_k_new ? params.seqlen_knew : params.total_knew, params.d, params.h_k, !is_varlen_k_new ? params.b : 1},  // shape_K_new
        {params.knew_row_stride, _1{}, params.knew_head_stride, !is_varlen_k_new ? params.knew_batch_stride : 0},  // stride_K_new
        static_cast<Element const*>(params.vnew_ptr),
        {params.vnew_row_stride, _1{}, params.vnew_head_stride, !is_varlen_k_new ? params.vnew_batch_stride : 0},  // stride_V_new
        static_cast<Element const*>(params.rotary_cos_ptr),
        {params.seqlen_k, params.rotary_dim / 2},  // shape_rotary, the seqlen shape doesn't matter
        {params.rotary_dim / 2, _1{}},  // stride_rotary_cos
        static_cast<Element const*>(params.rotary_sin_ptr),
        {params.rotary_dim / 2, _1{}},  // stride_rotary_sin
        params.is_rotary_interleaved,
        params.page_table,
        // if page_size is not set, avoid dividing by zero
        {params.kv_batch_idx ? params.b_k : params.b, !PagedKV ? 0 : params.seqlen_k / params.page_size},  // shape_page_table
        {params.page_table_batch_stride, _1{}},  // stride_page_table
        params.scale_softmax,
        params.q_descale_ptr, params.k_descale_ptr, params.v_descale_ptr,
        {params.q_descale_batch_stride, params.q_descale_head_stride},
        {params.k_descale_batch_stride, params.k_descale_head_stride},
        {params.v_descale_batch_stride, params.v_descale_head_stride},
        params.window_size_left, params.window_size_right, params.sink_token_length,
        params.softcap,
        params.num_splits,
        params.kv_batch_idx,
        params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew,
        params.seqused_q, params.seqused_k,
        params.leftpad_k,
    };
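
    // Epilogue arguments: when Split is enabled the kernel writes per-split partial results
    // (oaccum / softmax_lseaccum) instead of the final O and LSE.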
    typename CollectiveEpilogue::Arguments epilogue_args {
        static_cast<ElementOut*>(!Split ? params.o_ptr : params.oaccum_ptr),
        {seqlen_q, params.d, params.h, batch_q, params.num_splits},  // shape_O
        {!Split ? params.o_row_stride : params.oaccum_row_stride,
         _1{},
         !Split ? params.o_head_stride : params.oaccum_head_stride,
         !is_varlen_q ? (!Split ? params.o_batch_stride : params.oaccum_batch_stride) : 0,
         !Split ? 0 : params.oaccum_split_stride},  // stride_O
        static_cast<float*>(!Split ? params.softmax_lse_ptr : params.softmax_lseaccum_ptr),
        {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, !Split ? 0 : params.h * seqlen_q * batch_q},  // stride_LSE
        params.h_k,
        params.cu_seqlens_q, params.seqused_q
    };
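
    // With PackGQA the query heads that share a KV head are packed along the row (M) dimension, so the
    // number of M blocks is computed over seqlen_q * qhead_per_khead and then rounded up to a multiple
    // of the cluster size.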
    int qhead_per_khead = !PackGQA ? 1 : cutlass::ceil_div(params.h, params.h_k);
    int num_blocks_m = cutlass::ceil_div(params.seqlen_q * qhead_per_khead, get<0>(TileShape_MNK{}));
    num_blocks_m = cutlass::round_up(num_blocks_m, size<0>(ClusterShape{}));
    typename flash::TileSchedulerArguments scheduler_args {
        num_blocks_m, !PackGQA ? params.h : params.h_k, params.b, params.num_splits,
        params.h / params.h_k,
        params.seqlen_q,
        params.seqlen_k, params.d, sizeof(Element),
        params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q
    };
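
    // Fold the argument structs into the kernel's Params and query the launch geometry
    // (grid, block, dynamic shared memory size) from the kernel type.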
    int device;
    CHECK_CUDA(cudaGetDevice(&device));
    typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({
        mainloop_args, epilogue_args, {device, params.num_sm}, scheduler_args
    });

    dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params);
    dim3 block_dims = AttnKernel::get_block_shape();
    int smem_size = AttnKernel::SharedStorageSize;
    // int smem_size_q = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_q));
    // int smem_size_k = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k));
    // int smem_size_v = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v));
    // printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v);
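
    // Dynamic shared memory above 48 KB must be opted into via cudaFuncSetAttribute. Kernels using a
    // thread block cluster (ClusterM > 1) go through CUTLASS's cluster launch API; the rest use a plain
    // <<<grid, block, smem, stream>>> launch.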
    // Get the ptr to kernel function.
    if constexpr (size(ClusterShape{}) > 1) {
        void const* kernel = (void const*) cutlass::device_kernel<AttnKernel>;
        if (smem_size >= 48 * 1024) {
            CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        }
        dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
        cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
        cutlass::launch_kernel_on_cluster(launch_params, kernel, kernel_params);
    } else {
        auto kernel = cutlass::device_kernel<AttnKernel>;
        if (smem_size >= 48 * 1024) {
            CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        }
        kernel<<<grid_dims, block_dims, smem_size, stream>>>(kernel_params);
    }
    CHECK_CUDA_KERNEL_LAUNCH();
}
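
// run_mha_fwd_ maps the runtime flags in Flash_fwd_params onto run_flash_fwd's compile-time template
// parameters via the *_SWITCH macros from static_switch.h, then dispatches to the matching instantiation.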
template<int Arch, typename T, int kHeadDim, bool Split, bool PagedKV, bool Has_softcap, bool PackGQA>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
    static_assert(sizeof(T) == 2 || sizeof(T) == 1, "Only 16bit and 8bit are supported");
    static constexpr bool Is_FP8 = cute::is_same_v<T, cutlass::float_e4m3_t> || cute::is_same_v<T, cutlass::float_e5m2_t>;
    using T_out = std::conditional_t<!Split, std::conditional_t<!Is_FP8, T, cutlass::bfloat16_t>, float>;
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        VCOLMAJOR_SWITCH(params.v_dim_stride != 1, V_colmajor_, [&] {
            static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1;
            VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] {
                // Only needed here to decide if we should use cluster
                static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKV, Has_softcap)) : 128;
                static constexpr bool Enable_cluster = Arch >= 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKV && !Varlen;
                APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] {
                    // Only use Cluster if number of tiles along seqlen_q is even and not varlen
                    CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] {
                        static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1;
                        run_flash_fwd<Arch, kHeadDim, ClusterM, T, T_out, Is_causal, Is_local, Has_softcap, Varlen, PagedKV, AppendKV && Varlen, PackGQA, Split, V_colmajor>(params, stream);
                    });
                });
            });
        });
    });
}
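
// Illustrative only (not part of this header): the remaining compile-time parameters (arch, dtype,
// head dim, Split/PagedKV/softcap/PackGQA) are expected to be chosen by the caller from the runtime
// params and passed as an explicit instantiation, along the lines of
//
//   run_mha_fwd_<90 /*Arch*/, cutlass::bfloat16_t, 128 /*kHeadDim*/,
//                false /*Split*/, false /*PagedKV*/, false /*Has_softcap*/, false /*PackGQA*/>(params, stream);
//
// with the dtype / head-dim / feature dispatch handled on the caller's side.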