flash_fwd_launch_template.h

/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"  // For device_kernel
#include "cutlass/kernel_hardware_info.h"
#include "cutlass/cluster_launch.hpp"

#include "static_switch.h"
#include "flash.h"
#include "tile_scheduler.hpp"
#include "flash_fwd_kernel.h"
#include "mainloop_fwd_sm90_tma_gmma_ws.hpp"
#include "epilogue_fwd_sm90_tma.hpp"

using namespace cute;
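
// run_flash_fwd instantiates the collective mainloop, epilogue, and tile
// scheduler for one tile configuration, assembles their argument structs from
// Flash_fwd_params, and launches the kernel (via cluster launch when the
// cluster size is > 1).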
  17. template <int kHeadDim, int kBlockM, int kBlockN, int Stages, int ClusterM, typename Element, typename ElementOut,
  18. bool Is_causal, bool Is_local, bool Has_softcap, bool Varlen, bool V_colmajor>
  19. void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
  20. static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time");
  21. static constexpr bool Is_FP8 = cute::is_same_v<Element, cutlass::float_e4m3_t> || cute::is_same_v<Element, cutlass::float_e5m2_t>;;
  22. static constexpr bool FP8_TransposeV = Is_FP8 && !V_colmajor;
  23. using TileShape_MNK = cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
  24. using ClusterShape = cute::Shape<Int<ClusterM>, _1, _1>;
  25. using CollectiveMainloop = flash::CollectiveMainloopFwd<Stages, ClusterShape, TileShape_MNK, Element, float, cutlass::arch::Sm90, Is_causal, Is_local, Has_softcap, Varlen, V_colmajor>;
  26. using CollectiveEpilogue = flash::CollectiveEpilogueFwd<TileShape_MNK, ElementOut, CollectiveMainloop::NumMmaThreads, Varlen, FP8_TransposeV>;
  27. using Scheduler = std::conditional_t<Varlen,
  28. // flash::SingleTileScheduler<Varlen, kBlockM>,
  29. flash::VarlenDynamicPersistentTileScheduler<kBlockM, CollectiveMainloop::NumMmaThreads, CollectiveMainloop::NumProducerThreads>,
  30. std::conditional_t<!Is_causal && !Is_local,
  31. flash::StaticPersistentTileScheduler,
  32. flash::DynamicPersistentTileScheduler<CollectiveMainloop::NumMmaThreads, CollectiveMainloop::NumProducerThreads>>
  33. // flash::SingleTileScheduler<Varlen, kBlockM>>
  34. >;
  35. // using Scheduler = flash::SingleTileScheduler<Varlen, kBlockM>;
  36. using AttnKernel = std::conditional_t<!FP8_TransposeV,
  37. flash::FlashAttnFwd<CollectiveMainloop, CollectiveEpilogue, Scheduler>,
  38. flash::FlashAttnFwdFP8TransposeV<CollectiveMainloop, CollectiveEpilogue, Scheduler>
  39. >;
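
    // FP8 WGMMA consumes its smem operands in K-major layout; row-major V
    // (!V_colmajor) is N-major for the PV GEMM, so the FP8TransposeV kernel
    // variant first transposes V in shared memory. Column-major V is already
    // K-major and needs no transpose.
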
    typename CollectiveMainloop::StrideV v_strides =
        cute::conditional_return<!V_colmajor>(
            make_stride(params.v_row_stride, _1{}, params.v_head_stride, !Varlen ? params.v_batch_stride : 0),
            make_stride(_1{}, params.v_dim_stride, params.v_head_stride, !Varlen ? params.v_batch_stride : 0));
    // print(typename CollectiveMainloop::SmemLayoutVTma{}); printf("\n");
    // print(typename CollectiveMainloop::SmemLayoutVMma{}); printf("\n");
    typename CollectiveMainloop::Arguments mainloop_args {
        static_cast<Element const*>(params.q_ptr),
        {!Varlen ? params.seqlen_q : params.total_q, params.d, params.h, !Varlen ? params.b : 1},  // shape_Q
        {params.q_row_stride, _1{}, params.q_head_stride, !Varlen ? params.q_batch_stride : 0},  // stride_Q
        static_cast<Element const*>(params.k_ptr),
        {!Varlen ? params.seqlen_k : params.total_k, params.d, params.h_k, !Varlen ? params.b : 1},  // shape_K
        {params.k_row_stride, _1{}, params.k_head_stride, !Varlen ? params.k_batch_stride : 0},  // stride_K
        static_cast<Element const*>(params.v_ptr),
        v_strides,  // stride_V
        params.scale_softmax,
        params.q_scale_ptr, params.k_scale_ptr, params.v_scale_ptr,
        params.window_size_left, params.window_size_right,
        params.softcap,
        params.cu_seqlens_q, params.cu_seqlens_k,
        params.seqused_q, params.seqused_k,
    };
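    // With Varlen (packed variable-length batches), the batch stride above is
    // dropped (set to 0) and the batch size is treated as 1: sequences are
    // located via cu_seqlens_{q,k} / seqused_{q,k} instead.
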
    typename CollectiveEpilogue::Arguments epilogue_args {
        static_cast<ElementOut*>(params.o_ptr),
        {!Varlen ? params.seqlen_q : params.total_q, params.d, params.h, !Varlen ? params.b : 1},  // shape_O
        {params.o_row_stride, _1{}, params.o_head_stride, !Varlen ? params.o_batch_stride : 0},  // stride_O
        static_cast<float*>(params.softmax_lse_ptr),
        {_1{}, !Varlen ? params.seqlen_q : params.total_q, !Varlen ? params.h * params.seqlen_q : 0},  // stride_LSE
        params.cu_seqlens_q, params.seqused_q
    };

    int num_blocks_m = cutlass::ceil_div(params.seqlen_q, get<0>(TileShape_MNK{}));
    // Round num_blocks_m up to a multiple of the cluster size along M so that
    // every cluster launches with a full complement of CTAs.
    num_blocks_m = cutlass::round_up(num_blocks_m, size<0>(ClusterShape{}));
    typename Scheduler::Arguments scheduler_args {
        num_blocks_m, params.h, params.b, params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q
    };
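    // The persistent schedulers launch a fixed-size grid and loop over work
    // tiles; the dynamic variants hand out tiles through the
    // tile_count_semaphore counter passed above.
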
    int device;
    CHECK_CUDA(cudaGetDevice(&device));
    typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({
        mainloop_args, epilogue_args, {device}, scheduler_args
    });

    dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params);
    dim3 block_dims = AttnKernel::get_block_shape();
    int smem_size = AttnKernel::SharedStorageSize;
    // int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q));
    // int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k));
    // int smem_size_v = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_v));
    // printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v);
    // Get the ptr to kernel function. Requesting more than 48 KB of dynamic
    // shared memory requires an explicit opt-in via cudaFuncSetAttribute.
    if constexpr (size(ClusterShape{}) > 1) {
        void const* kernel = (void const*) cutlass::device_kernel<AttnKernel>;
        if (smem_size >= 48 * 1024) {
            CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        }
        dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
        cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
        cutlass::launch_kernel_on_cluster(launch_params, kernel, kernel_params);
    } else {
        auto kernel = cutlass::device_kernel<AttnKernel>;
        if (smem_size >= 48 * 1024) {
            CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        }
        kernel<<<grid_dims, block_dims, smem_size, stream>>>(kernel_params);
    }
    CHECK_CUDA_KERNEL_LAUNCH();
}
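
// BOOL_SWITCH (from static_switch.h) turns a runtime boolean into a
// compile-time constant by instantiating its lambda body for both branches,
// roughly (paraphrased, not the exact macro text):
//
//   #define BOOL_SWITCH(COND, CONST_NAME, ...)          \
//     [&] {                                             \
//       if (COND) {                                     \
//         constexpr static bool CONST_NAME = true;      \
//         return __VA_ARGS__();                         \
//       } else {                                        \
//         constexpr static bool CONST_NAME = false;     \
//         return __VA_ARGS__();                         \
//       }                                               \
//     }()
//
// The dispatch helpers below use it to lift the varlen, cluster, and softcap
// decisions into template parameters of run_flash_fwd.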
template<typename T, int kBlockM, int kBlockN, int kHeadDim, bool Is_causal, bool Is_local, bool Enable_cluster>
void run_mha_fwd_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
    BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
        // Only use Cluster if number of tiles along seqlen_q is even and not varlen
        BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0, UseCluster, [&] {
            BOOL_SWITCH(params.softcap > 0.0, Has_softcap, [&] {
                run_flash_fwd<kHeadDim, kBlockM, kBlockN, 2 /*Stages*/, !Is_causal && !Is_local && !Varlen && Enable_cluster && UseCluster ? 2 : 1, T, T, Is_causal, Is_local, Has_softcap, Varlen, false /*V_colmajor*/>(params, stream);
            });
        });
    });
}
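
// Per-head-dimension entry points. Tile sizes (kBlockM x kBlockN) used below:
//   hdim  64: 192 x 128 (cluster disabled)
//   hdim  96: 128 x 128 (causal/local) or 128 x 160
//   hdim 128: 128 x 128 (causal/local) or 128 x 176
//   hdim 192: 128 x 96
//   hdim 256: 128 x 80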
template<typename T>
void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        run_mha_fwd_dispatch<T, 192, 128, 64, Is_causal, Is_local, false /*Enable_cluster*/>(params, stream);
    });
}

template<typename T>
void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        run_mha_fwd_dispatch<T, 128, Is_causal || Is_local ? 128 : 160, 96, Is_causal, Is_local, true /*Enable_cluster*/>(params, stream);
    });
}

template<typename T>
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        run_mha_fwd_dispatch<T, 128, Is_causal || Is_local ? 128 : 176, 128, Is_causal, Is_local, true /*Enable_cluster*/>(params, stream);
    });
}

template<typename T>
void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        run_mha_fwd_dispatch<T, 128, 96, 192, Is_causal, Is_local, true /*Enable_cluster*/>(params, stream);
    });
}

template<typename T>
void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        run_mha_fwd_dispatch<T, 128, 80, 256, Is_causal, Is_local, true /*Enable_cluster*/>(params, stream);
    });
}
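
// FP8 forward path. Inputs are FP8 (e4m3/e5m2) but the output is written as
// bf16 (ElementOut = cutlass::bfloat16_t). Softcap is not supported here
// (Has_softcap is fixed to false), and column-major V is disabled under
// varlen (V_colmajor && !Varlen).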
template<typename T, int kBlockM, int kBlockN, int kHeadDim, int kStages,
         bool Is_causal, bool Is_local, bool V_colmajor, bool Enable_cluster>
void run_mha_fwd_fp8_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
    BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
        // Only use Cluster if number of tiles along seqlen_q is even and not varlen
        BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0, UseCluster, [&] {
            run_flash_fwd<kHeadDim, kBlockM, kBlockN, kStages, !Is_causal && !Is_local && !Varlen && Enable_cluster && UseCluster ? 2 : 1, T, cutlass::bfloat16_t, Is_causal, Is_local, false /*Has_softcap*/, Varlen, V_colmajor && !Varlen>(params, stream);
        });
    });
}

template<typename T>
void run_mha_fwd_fp8_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        BOOL_SWITCH(params.v_dim_stride != 1, V_colmajor, [&] {
            run_mha_fwd_fp8_dispatch<T, 192, 160, 64, 3, Is_causal, Is_local, V_colmajor, false /*Enable_cluster*/>(params, stream);
        });
    });
}

template<typename T>
void run_mha_fwd_fp8_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        BOOL_SWITCH(params.v_dim_stride != 1, V_colmajor, [&] {
            run_mha_fwd_fp8_dispatch<T, 192, 128, 96, 3, Is_causal, Is_local, V_colmajor, false /*Enable_cluster*/>(params, stream);
        });
    });
}

template<typename T>
void run_mha_fwd_fp8_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        BOOL_SWITCH(params.v_dim_stride != 1, V_colmajor, [&] {
            run_mha_fwd_fp8_dispatch<T, 128, V_colmajor ? 192 : 224, 128, 2, Is_causal, Is_local, V_colmajor, true /*Enable_cluster*/>(params, stream);
        });
    });
}

template<typename T>
void run_mha_fwd_fp8_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        BOOL_SWITCH(params.v_dim_stride != 1, V_colmajor, [&] {
            run_mha_fwd_fp8_dispatch<T, 128, 160, 192, 2, Is_causal, Is_local, V_colmajor, true /*Enable_cluster*/>(params, stream);
        });
    });
}

template<typename T>
void run_mha_fwd_fp8_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        BOOL_SWITCH(params.v_dim_stride != 1, V_colmajor, [&] {
            run_mha_fwd_fp8_dispatch<T, 128, 128, 256, 2, Is_causal, Is_local, V_colmajor, true /*Enable_cluster*/>(params, stream);
        });
    });
}
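
// Example usage (a minimal sketch, not the library's actual dispatch code;
// make_params is a hypothetical helper, and only two head dims are shown):
//
//   Flash_fwd_params params = make_params(/* q, k, v, o, seqlens, ... */);
//   cudaStream_t stream = nullptr;  // default stream
//   switch (params.d) {             // head dimension
//       case 64:  run_mha_fwd_hdim64<cutlass::half_t>(params, stream); break;
//       case 128: run_mha_fwd_hdim128<cutlass::half_t>(params, stream); break;
//       // ... remaining head dims, and run_mha_fwd_fp8_hdim* for FP8 inputs
//   }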