/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include <tuple>        // For std::tuple, std::get
#include <type_traits>  // For std::conditional_t

#include "cute/tensor.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"  // For device_kernel
#include "cutlass/kernel_hardware_info.h"
#include "cutlass/cluster_launch.hpp"

#include "static_switch.h"
#include "flash.h"
#include "tile_size.h"
#include "tile_scheduler.hpp"
#include "flash_fwd_kernel.h"
#include "mainloop_fwd_sm90_tma_gmma_ws.hpp"
#include "epilogue_fwd_sm90_tma.hpp"

using namespace cute;
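
// Instantiates and launches one fully specialized SM90 forward attention kernel.
// The compile-time flags select the mainloop/epilogue variants: Is_causal/Is_local
// pick the masking mode, Varlen enables packed variable-length batches, PagedKV and
// AppendKV enable paged / appendable KV caches, PackGQA folds query heads that share
// a KV head into the M dimension, and Split enables split-KV partial results.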
template <int kHeadDim, int kBlockM, int kBlockN, int kStages, int ClusterM, typename Element, typename ElementOut,
          bool Is_causal, bool Is_local, bool Has_softcap, bool Varlen, bool PagedKV, bool AppendKV,
          bool Mma1_is_RS, bool IntraWGOverlap, bool PackGQA, bool Split, bool V_colmajor>
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time");
    static_assert(!(AppendKV && V_colmajor), "AppendKV and V_colmajor cannot be enabled at the same time");
    static_assert(!(AppendKV && !Varlen), "AppendKV requires Varlen");
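    // FP8 WGMMA wants V in K-major layout, so when V arrives row-major (the usual
    // case) the kernel transposes it in shared memory; a column-major V skips that.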
    static constexpr bool Is_FP8 = cute::is_same_v<Element, cutlass::float_e4m3_t> || cute::is_same_v<Element, cutlass::float_e5m2_t>;
    static constexpr bool FP8_TransposeV = Is_FP8 && !V_colmajor;
    using TileShape_MNK = cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
    using ClusterShape = cute::Shape<Int<ClusterM>, _1, _1>;
    using CollectiveMainloop = flash::CollectiveMainloopFwd<kStages, ClusterShape, TileShape_MNK, Element, float, cutlass::arch::Sm90,
                                                            Is_causal, Is_local, Has_softcap, Varlen, PagedKV, AppendKV,
                                                            Mma1_is_RS, IntraWGOverlap, PackGQA, Split, V_colmajor>;
    using CollectiveEpilogue = flash::CollectiveEpilogueFwd<TileShape_MNK, ClusterShape, ElementOut, CollectiveMainloop::NumMmaThreads,
                                                            Varlen, PackGQA, FP8_TransposeV>;
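    // Scheduler selection: varlen batches need the dynamic varlen scheduler; for
    // non-causal, non-local workloads every tile costs the same, so a static
    // persistent schedule suffices; otherwise a dynamic persistent schedule
    // balances the uneven per-tile work that causal/local masking creates.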
    using SchedulerPersistent = std::conditional_t<Varlen,
        flash::VarlenDynamicPersistentTileScheduler<kBlockM, CollectiveMainloop::NumMmaThreads, CollectiveMainloop::NumProducerThreads, Split, PackGQA>,
        std::conditional_t<!Is_causal && !Is_local,
            flash::StaticPersistentTileScheduler<Split>,
            flash::DynamicPersistentTileScheduler<CollectiveMainloop::NumMmaThreads, CollectiveMainloop::NumProducerThreads, Split, PackGQA>
        >
    >;
    using SchedulerSingleTile = flash::SingleTileScheduler<Varlen, Split, PackGQA, kBlockM>;
    // If Split, PagedKV, or AppendKV then we probably don't have enough work for PersistentScheduler to be useful.
    using Scheduler = std::conditional_t<Split || PagedKV || AppendKV, SchedulerSingleTile, SchedulerPersistent>;
    using AttnKernel = flash::FlashAttnFwd<CollectiveMainloop, CollectiveEpilogue, Scheduler>;
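
    // With varlen inputs the sequences are packed back to back, so the batch
    // dimension collapses to 1 and per-sequence offsets come from cu_seqlens_*.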
    bool const is_varlen_q = params.cu_seqlens_q;
    bool const is_varlen_k = params.cu_seqlens_k;
    bool const is_varlen_k_new = params.cu_seqlens_knew;
    int seqlen_q = !is_varlen_q ? params.seqlen_q : params.total_q;
    int batch_q = !is_varlen_q ? params.b : 1;
    int batch_k = !is_varlen_k ? (params.kv_batch_idx ? params.b_k : params.b) : 1;
    typename CollectiveMainloop::StrideV v_strides =
        cute::conditional_return<!V_colmajor>(
            make_stride(params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0),
            make_stride(_1{}, params.v_dim_stride, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0));
    typename CollectiveMainloop::Arguments mainloop_args {
        static_cast<Element const*>(params.q_ptr),
        {seqlen_q, params.d, params.h, batch_q},  // shape_Q
        {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0},  // stride_Q
        static_cast<Element*>(params.k_ptr),
        {!PagedKV ? (!is_varlen_k ? params.seqlen_k : params.total_k) : params.page_size,
         params.d, params.h_k, !PagedKV ? batch_k : params.num_pages},  // shape_K
        {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0},  // stride_K
        static_cast<Element*>(params.v_ptr),
        v_strides,  // stride_V
        static_cast<Element const*>(params.knew_ptr),
        {!is_varlen_k_new ? params.seqlen_knew : params.total_knew, params.d, params.h_k, !is_varlen_k_new ? params.b : 1},  // shape_K_new
        {params.knew_row_stride, _1{}, params.knew_head_stride, !is_varlen_k_new ? params.knew_batch_stride : 0},  // stride_K_new
        static_cast<Element const*>(params.vnew_ptr),
        {params.vnew_row_stride, _1{}, params.vnew_head_stride, !is_varlen_k_new ? params.vnew_batch_stride : 0},  // stride_V_new
        static_cast<Element const*>(params.rotary_cos_ptr),
        {params.seqlen_k, params.rotary_dim / 2},  // shape_rotary; the seqlen shape doesn't matter
        {params.rotary_dim / 2, _1{}},  // stride_rotary_cos
        static_cast<Element const*>(params.rotary_sin_ptr),
        {params.rotary_dim / 2, _1{}},  // stride_rotary_sin
        params.is_rotary_interleaved,
        params.page_table,
        // If page_size is not set, avoid dividing by zero.
        {params.kv_batch_idx ? params.b_k : params.b, !PagedKV ? 0 : params.seqlen_k / params.page_size},  // shape_page_table
        {params.page_table_batch_stride, _1{}},  // stride_page_table
        params.scale_softmax,
        params.q_descale_ptr, params.k_descale_ptr, params.v_descale_ptr,
        {params.q_descale_batch_stride, params.q_descale_head_stride},
        {params.k_descale_batch_stride, params.k_descale_head_stride},
        {params.v_descale_batch_stride, params.v_descale_head_stride},
        params.window_size_left, params.window_size_right, params.sink_token_length,
        params.softcap,
        params.num_splits,
        params.kv_batch_idx,
        params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew,
        params.seqused_q, params.seqused_k,
        params.leftpad_k,
    };
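
    // When Split is on, each split writes partial results to the *accum buffers
    // (float output accumulator plus per-split LSE); a separate combine pass then
    // reduces them into the final O and LSE.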
    typename CollectiveEpilogue::Arguments epilogue_args {
        static_cast<ElementOut*>(!Split ? params.o_ptr : params.oaccum_ptr),
        {seqlen_q, params.d, params.h, batch_q, params.num_splits},  // shape_O
        {!Split ? params.o_row_stride : params.oaccum_row_stride,
         _1{},
         !Split ? params.o_head_stride : params.oaccum_head_stride,
         !is_varlen_q ? (!Split ? params.o_batch_stride : params.oaccum_batch_stride) : 0,
         !Split ? 0 : params.oaccum_split_stride},  // stride_O
        static_cast<float*>(!Split ? params.softmax_lse_ptr : params.softmax_lseaccum_ptr),
        {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, !Split ? 0 : params.h * seqlen_q * batch_q},  // stride_LSE
        params.h_k,
        params.cu_seqlens_q, params.seqused_q
    };
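
    // PackGQA folds the qhead_per_khead query heads that share one KV head into the
    // row (M) dimension, so the number of M blocks grows by that factor while the
    // head count handed to the scheduler shrinks to h_k.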
    int qhead_per_khead = !PackGQA ? 1 : cutlass::ceil_div(params.h, params.h_k);
    int num_blocks_m = cutlass::ceil_div(params.seqlen_q * qhead_per_khead, get<0>(TileShape_MNK{}));
    num_blocks_m = cutlass::round_up(num_blocks_m, size<0>(ClusterShape{}));
    typename flash::TileSchedulerArguments scheduler_args {
        num_blocks_m, !PackGQA ? params.h : params.h_k, params.b, params.num_splits,
        params.h / params.h_k,
        params.seqlen_q,
        params.seqlen_k, params.d, sizeof(Element),
        params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q
    };
    int device;
    CHECK_CUDA(cudaGetDevice(&device));
    typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({
        mainloop_args, epilogue_args, {device}, scheduler_args
    });

    dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params);
    dim3 block_dims = AttnKernel::get_block_shape();
    int smem_size = AttnKernel::SharedStorageSize;
    // int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q));
    // int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k));
    // int smem_size_v = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_v));
    // printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v);
    // Get the ptr to kernel function.
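    // Kernels that want more than 48 KB of dynamic shared memory must opt in via
    // cudaFuncSetAttribute before launch, and thread-block clusters have to go
    // through the cluster launch API rather than the <<<...>>> syntax.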
    if constexpr (size(ClusterShape{}) > 1) {
        void const* kernel = (void const*) cutlass::device_kernel<AttnKernel>;
        if (smem_size >= 48 * 1024) {
            CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        }
        dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
        cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
        cutlass::launch_kernel_on_cluster(launch_params, kernel, kernel_params);
    } else {
        auto kernel = cutlass::device_kernel<AttnKernel>;
        if (smem_size >= 48 * 1024) {
            CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        }
        kernel<<<grid_dims, block_dims, smem_size, stream>>>(kernel_params);
    }
    CHECK_CUDA_KERNEL_LAUNCH();
}
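
// Resolves the remaining runtime flags (varlen, append-KV, GQA packing, cluster use)
// into compile-time booleans via the *_SWITCH macros from static_switch.h, then
// instantiates the matching run_flash_fwd specialization.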
template<typename T, int kBlockM, int kBlockN, int kHeadDim, int kStages,
         bool Is_causal, bool Is_local, bool Has_softcap, bool PagedKV, bool Mma1_is_RS, bool IntraWGOverlap,
         bool Split, bool V_colmajor, bool Enable_cluster>
void run_mha_fwd_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
    static constexpr bool Is_FP8 = cute::is_same_v<T, cutlass::float_e4m3_t> || cute::is_same_v<T, cutlass::float_e5m2_t>;
    using T_out = std::conditional_t<!Split, std::conditional_t<!Is_FP8, T, cutlass::bfloat16_t>, float>;
    VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] {
        APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] {
            PACKGQA_SWITCH(params.pack_gqa, PackGQA, [&] {
                // Only use Cluster if the number of tiles along seqlen_q is even and not varlen.
                CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] {
                    static constexpr int ClusterM = !Varlen && Enable_cluster && Use_cluster ? 2 : 1;
                    run_flash_fwd<kHeadDim, kBlockM, kBlockN, kStages, ClusterM, T, T_out, Is_causal, Is_local, Has_softcap,
                                  Varlen, PagedKV, AppendKV && Varlen, Mma1_is_RS, IntraWGOverlap, PackGQA, Split, V_colmajor>(params, stream);
                });
            });
        });
    });
}
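
// Entry points for 16-bit (fp16/bf16) and 8-bit (fp8) element types. Block sizes
// and the Mma1-in-registers / intra-warpgroup-overlap choices come from
// tile_size_fwd; clusters of two CTAs along M are only enabled for the larger,
// mask-free head dimensions, where there is enough uniform work to amortize them.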
template<typename T, int kHeadDim, bool Split, bool PagedKV>
void run_mha_fwd_16b(Flash_fwd_params &params, cudaStream_t stream) {
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        SOFTCAP_SWITCH(params.softcap > 0.0, Has_softcap, [&] {
            // Can't use structured binding since it's not compatible with constexpr
            static constexpr std::tuple<int, int, bool, bool> kBlockMN_RS_IntraWGOverlap = tile_size_fwd(kHeadDim, Is_causal, Is_local, sizeof(T) /*element_size*/, false /*V_colmajor*/, PagedKV, Has_softcap);
            static constexpr bool Enable_cluster = kHeadDim >= 128 && !Is_causal && !Is_local && !Split && !PagedKV;
            run_mha_fwd_dispatch<T, std::get<0>(kBlockMN_RS_IntraWGOverlap), std::get<1>(kBlockMN_RS_IntraWGOverlap), kHeadDim, 2 /*kStages*/,
                                 Is_causal, Is_local, Has_softcap, PagedKV, std::get<2>(kBlockMN_RS_IntraWGOverlap), std::get<3>(kBlockMN_RS_IntraWGOverlap),
                                 Split, false /*V_colmajor*/, Enable_cluster>(params, stream);
        });
    });
}
template<typename T, int kHeadDim, bool Split, bool PagedKV>
void run_mha_fwd_8b(Flash_fwd_params &params, cudaStream_t stream) {
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        VCOLMAJOR_SWITCH(params.v_dim_stride != 1, V_colmajor, [&] {
            SOFTCAP_SWITCH(params.softcap > 0.0, Has_softcap, [&] {
                // Can't use structured binding since it's not compatible with constexpr
                static constexpr std::tuple<int, int, bool, bool> kBlockMN_RS_IntraWGOverlap = tile_size_fwd(kHeadDim, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKV, Has_softcap);
                static constexpr bool Enable_cluster = kHeadDim == 192 && !Is_causal && !Is_local && !Split && !PagedKV;
                run_mha_fwd_dispatch<T, std::get<0>(kBlockMN_RS_IntraWGOverlap), std::get<1>(kBlockMN_RS_IntraWGOverlap), kHeadDim, 2 /*kStages*/,
                                     Is_causal, Is_local, Has_softcap, PagedKV, std::get<2>(kBlockMN_RS_IntraWGOverlap), std::get<3>(kBlockMN_RS_IntraWGOverlap),
                                     Split, V_colmajor, Enable_cluster>(params, stream);
            });
        });
    });
}
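
// A minimal (hypothetical) call-site sketch: it assumes a HEADDIM_SWITCH macro
// exists alongside the other *_SWITCH helpers in static_switch.h, and is shown
// only to illustrate how a caller might pick the head-dimension instantiation.
//
//   void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
//       HEADDIM_SWITCH(params.d, kHeadDim, [&] {
//           run_mha_fwd_16b<cutlass::bfloat16_t, kHeadDim, false /*Split*/, false /*PagedKV*/>(params, stream);
//       });
//   }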