// File: flash_fwd_launch_template.h (8.1 KB)
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  #pragma once

#include <stdexcept>
#include <string>

#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/cluster_launch.hpp"
#include "static_switch.h"
#include "flash.h"
#include "tile_scheduler.hpp"
#include "flash_fwd_kernel.h"
#include "kernel_traits.h"
#include "seq_len.h"
#include "utils.h"
  /*
   * Host-side launcher for flash::compute_attn_ws<Kernel_traits, Is_causal, Scheduler, Seqlen_traits>.
   *
   * Steps:
   *   1. Chooses the tile scheduler at compile time:
   *        - variable-length batches (Seqlen_traits::kUseVarSeqLen): flash::SingleTileScheduler
   *        - fixed length, non-causal: flash::StaticPersistentTileScheduler
   *        - fixed length, causal:     flash::DynamicPersistentTileScheduler
   *   2. Builds mainloop params (Q, K, V pointers + global-memory layouts, softmax
   *      log2 scale), epilogue params (O and LSE pointers + layouts) and scheduler
   *      params from `params`.
   *   3. Opts the kernel into large dynamic shared memory when SharedStorage needs
   *      48 KB or more.
   *   4. Sizes the grid through the scheduler using the device's SM count and
   *      launches the kernel as a thread-block cluster on `stream`.
   *
   * Errors: CUDA API failures go through CHECK_CUDA / CHECK_CUDA_KERNEL_LAUNCH.
   * A non-success cutlass::Status from the cluster launch throws std::runtime_error.
   */
template<typename Kernel_traits, bool Is_causal, typename Seqlen_traits>
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    using Element = typename Kernel_traits::Element;
    using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
    using CollectiveMainloop = flash::CollectiveMainloopFwd<Kernel_traits, Is_causal, Seqlen_traits>;
    using CollectiveEpilogue = flash::CollectiveEpilogueFwd<Kernel_traits, Seqlen_traits>;
    // Varlen -> one tile per CTA; otherwise a persistent scheduler (static for
    // non-causal, dynamic for causal).
    using Scheduler = std::conditional_t<
        Seqlen_traits::kUseVarSeqLen,
        flash::SingleTileScheduler,
        std::conditional_t<!Is_causal,
            flash::StaticPersistentTileScheduler,
            flash::DynamicPersistentTileScheduler<Kernel_traits::kNThreads - cutlass::NumThreadsPerWarpGroup>
        >>;

    Seqlen_traits seqlen_traits_q(
        params.total_q, params.seqlen_q, params.cu_seqlens_q);
    Seqlen_traits seqlen_traits_k(
        params.total_k, params.seqlen_k, params.cu_seqlens_k, params.seqused_k);

    typename CollectiveMainloop::Params mainloop_params =
        CollectiveMainloop::to_underlying_arguments({
            static_cast<Element const*>(params.q_ptr),
            seqlen_traits_q.get_gmem_layout(
                params.seqlen_q, params.d, params.h, params.b,
                params.q_row_stride, params.q_head_stride, params.q_batch_stride
            ),  // layout_Q
            static_cast<Element const*>(params.k_ptr),
            seqlen_traits_k.get_gmem_layout(
                params.seqlen_k, params.d, params.h_k, params.b,
                params.k_row_stride, params.k_head_stride, params.k_batch_stride
            ),  // layout_K (uses h_k heads)
            static_cast<Element const*>(params.v_ptr),
            seqlen_traits_k.get_gmem_layout(
                params.seqlen_k, params.d, params.h_k, params.b,
                params.v_row_stride, params.v_head_stride, params.v_batch_stride
            ),  // layout_V (uses h_k heads)
            params.scale_softmax_log2
        });

    typename CollectiveEpilogue::Params epilogue_params =
        CollectiveEpilogue::to_underlying_arguments({
            static_cast<Element*>(params.o_ptr),
            seqlen_traits_q.get_gmem_layout(
                params.seqlen_q, params.d, params.h, params.b,
                params.o_row_stride, params.o_head_stride, params.o_batch_stride
            ),  // layout_O
            static_cast<float*>(params.softmax_lse_ptr),
            seqlen_traits_q.get_lse_gmem_layout(
                params.seqlen_q, params.h, params.b
            )   // layout_LSE
        });

    int num_blocks_m = cutlass::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
    // Round the M-tile count up to a whole number of clusters along M.
    num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{});
    // NOTE(review): tile_count_semaphore is presumably only consumed by the
    // dynamic scheduler; it is passed to every scheduler here.
    typename Scheduler::Arguments scheduler_args = {num_blocks_m, params.h, params.b, params.tile_count_semaphore};
    typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args);

    void *kernel = (void *)flash::compute_attn_ws<Kernel_traits, Is_causal, Scheduler, Seqlen_traits>;
    int smem_size = sizeof(typename Kernel_traits::SharedStorage);
    // Dynamic shared memory of 48 KB or more requires an explicit opt-in.
    if (smem_size >= 48 * 1024) {
        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    }

    // Both queries are checked: an unchecked cudaGetDevice failure would feed an
    // uninitialized device id into cudaDeviceGetAttribute.
    int device;
    CHECK_CUDA(cudaGetDevice(&device));
    int multiprocessor_count;
    CHECK_CUDA(cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device));

    dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count);
    static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
    dim3 block_dims(ctaSize);
    dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
    cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
    cutlass::Status launch_status = cutlass::launch_kernel_on_cluster(
        launch_params, kernel, mainloop_params, epilogue_params,
        scheduler_params, seqlen_traits_q, seqlen_traits_k);
    // First surface any CUDA error recorded by the launch. Then catch failures
    // that cutlass reports only through its Status return value; previously these
    // were discarded, and no kernel ran without any error being raised.
    CHECK_CUDA_KERNEL_LAUNCH();
    if (launch_status != cutlass::Status::kSuccess) {
        throw std::runtime_error(std::string("run_flash_fwd: cluster kernel launch failed: ")
                                 + cutlass::cutlassGetStatusString(launch_status));
    }
}
  /*
   * Forward attention dispatcher for head dimension 64.
   * Turns params.is_causal (BOOL_SWITCH) and params.cu_seqlens_q (SEQLEN_SWITCH)
   * into compile-time parameters. It then launches one fixed kernel configuration,
   * which has no cluster and the same traits for FP8 and non-FP8 element types.
   */
template<typename T>
void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
    static constexpr int kHeadDim = 64;
    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
            using Traits = Flash_fwd_kernel_traits<kHeadDim, 192, 128, 16, 2, false, 1, T>;
            run_flash_fwd<Traits, Is_causal, Seqlen_traits>(params, stream);
        });
    });
}
  /*
   * Forward attention dispatcher for head dimension 128.
   *
   * A 2-CTA cluster along M is used only when all of the following hold:
   *   - cutlass::ceil_div(seqlen_q, 128) is even,
   *   - attention is not causal,
   *   - sequence lengths are fixed (not varlen).
   *
   * kClusterM re-applies these conditions at compile time. BOOL_SWITCH
   * instantiates both UseCluster branches, and without kClusterM the
   * UseCluster=true branch would build cluster-2 kernels for causal or varlen
   * combinations that can never be chosen at runtime. With kClusterM those
   * combinations fold into the cluster-1 kernel. Runtime selection is unchanged.
   *
   * NOTE(review): the trait arguments presumably read
   * <HeadDim, kBlockM, kBlockN, kNWarps, kStages, ?, ClusterM, Element>;
   * confirm against kernel_traits.h.
   */
template<typename T>
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 128;
    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
            BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
                constexpr static int kClusterM = (UseCluster && !Is_causal && !Seqlen_traits::kUseVarSeqLen) ? 2 : 1;
                if constexpr (is_same_v<T, cutlass::float_e4m3_t>) {
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 12, 4, false, kClusterM, T>, Is_causal, Seqlen_traits>(params, stream);
                } else {
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, Is_causal ? 128 : 176, 12, 2, false, kClusterM, T>, Is_causal, Seqlen_traits>(params, stream);
                }
            });
        });
    });
}
  /*
   * Forward attention dispatcher for head dimension 256.
   *
   * A 2-CTA cluster along M is used only when all of the following hold:
   *   - cutlass::ceil_div(seqlen_q, 128) is even,
   *   - attention is not causal,
   *   - sequence lengths are fixed (not varlen).
   *
   * kClusterM re-applies these conditions at compile time. BOOL_SWITCH
   * instantiates both UseCluster branches, and without kClusterM the
   * UseCluster=true branch would build cluster-2 kernels for causal or varlen
   * combinations that can never be chosen at runtime. With kClusterM those
   * combinations fold into the cluster-1 kernel. Runtime selection is unchanged.
   */
template<typename T>
void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 256;
    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
            BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
                constexpr static int kClusterM = (UseCluster && !Is_causal && !Seqlen_traits::kUseVarSeqLen) ? 2 : 1;
                if constexpr (is_same_v<T, cutlass::float_e4m3_t>) {
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 12, 3, false, kClusterM, T>, Is_causal, Seqlen_traits>(params, stream);
                } else {
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 80, 12, 2, false, kClusterM, T>, Is_causal, Seqlen_traits>(params, stream);
                }
            });
        });
    });
}