// flash_fwd_launch_template.h
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include <ATen/cuda/CUDAContext.h>
  6. #include "cute/tensor.hpp"
  7. #include "cutlass/cutlass.h"
  8. #include "cutlass/cluster_launch.hpp"
  9. #include "static_switch.h"
  10. #include "flash.h"
  11. #include "tile_scheduler.hpp"
  12. #include "flash_fwd_kernel.h"
  13. #include "kernel_traits.h"
  14. template<typename Kernel_traits, bool Is_causal>
  15. void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
  16. using Element = typename Kernel_traits::Element;
  17. using TileShape_MNK = typename Kernel_traits::TileShape_MNK;
  18. using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
  19. // print(typename Kernel_traits::SmemLayoutVt{}); printf("\n"); print(typename Kernel_traits::SmemLayoutVt_tmp{});
  20. BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
  21. using CollectiveMainloop = flash::CollectiveMainloopFwd<Kernel_traits, Is_causal, Varlen>;
  22. // using CollectiveEpilogue = flash::CollectiveEpilogueFwd<Kernel_traits>;
  23. using CollectiveEpilogue = flash::CollectiveEpilogueFwd<Kernel_traits, Varlen>;
  24. using Scheduler = std::conditional_t<Varlen,
  25. flash::SingleTileScheduler,
  26. std::conditional_t<!Is_causal,
  27. flash::StaticPersistentTileScheduler,
  28. flash::DynamicPersistentTileScheduler<Kernel_traits::kNThreads - cutlass::NumThreadsPerWarpGroup>>
  29. >;
  30. // using Scheduler = flash::SingleTileScheduler;
  31. typename CollectiveMainloop::Params mainloop_params =
  32. CollectiveMainloop::to_underlying_arguments({
  33. static_cast<Element const*>(params.q_ptr),
  34. {!Varlen ? params.seqlen_q : params.total_q, params.d, params.h, !Varlen ? params.b : 1}, // shape_Q
  35. {params.q_row_stride, _1{}, params.q_head_stride, params.q_batch_stride}, // stride_Q
  36. static_cast<Element const*>(params.k_ptr),
  37. {!Varlen ? params.seqlen_k : params.total_k, params.d, params.h_k, !Varlen ? params.b : 1}, // shape_K
  38. {params.k_row_stride, _1{}, params.k_head_stride, params.k_batch_stride}, // stride_K
  39. static_cast<Element const*>(params.v_ptr),
  40. {params.v_row_stride, _1{}, params.v_head_stride, params.v_batch_stride}, // stride_V
  41. params.scale_softmax,
  42. params.cu_seqlens_q, params.cu_seqlens_k,
  43. });
  44. typename CollectiveEpilogue::Params epilogue_params =
  45. CollectiveEpilogue::to_underlying_arguments({
  46. static_cast<Element*>(params.o_ptr),
  47. {!Varlen ? params.seqlen_q : params.total_q, params.d, params.h, !Varlen ? params.b : 1}, // shape_O
  48. {params.o_row_stride, _1{}, params.o_head_stride, params.o_batch_stride}, // stride_O
  49. static_cast<float*>(params.softmax_lse_ptr),
  50. {_1{}, !Varlen ? params.seqlen_q : params.total_q, params.h * (!Varlen ? params.seqlen_q : params.total_q)}, // stride_LSE
  51. params.cu_seqlens_q
  52. });
  53. int num_blocks_m = cutlass::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
  54. num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{});
  55. typename Scheduler::Arguments scheduler_args = {num_blocks_m, params.h, params.b, params.tile_count_semaphore};
  56. typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args);
  57. // Get the ptr to kernel function.
  58. void *kernel;
  59. kernel = (void *)flash::compute_attn_ws<Kernel_traits, Is_causal, Varlen, Scheduler>;
  60. int smem_size = sizeof(typename Kernel_traits::SharedStorage);
  61. // int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q));
  62. // int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k));
  63. // int smem_size_v = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_v));
  64. // printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v);
  65. if (smem_size >= 48 * 1024) {
  66. C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
  67. }
  68. int device;
  69. cudaGetDevice(&device);
  70. int multiprocessor_count;
  71. cudaError status_ = cudaDeviceGetAttribute(
  72. &multiprocessor_count, cudaDevAttrMultiProcessorCount, device);
  73. if (status_ != cudaSuccess) {
  74. C10_CUDA_CHECK(status_);
  75. }
  76. dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count);
  77. static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
  78. dim3 block_dims(ctaSize);
  79. dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
  80. cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
  81. cutlass::launch_kernel_on_cluster(launch_params, kernel, mainloop_params, epilogue_params, scheduler_params);
  82. });
  83. C10_CUDA_KERNEL_LAUNCH_CHECK();
  84. }
  85. template<typename T>
  86. void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
  87. constexpr static int Headdim = 64;
  88. BOOL_SWITCH(params.is_causal, Is_causal, [&] {
  89. run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 192, 128, 16, 2, false, 1, T>, Is_causal>(params, stream);
  90. });
  91. }
  92. template<typename T>
  93. void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
  94. constexpr static int Headdim = 128;
  95. BOOL_SWITCH(params.is_causal, Is_causal, [&] {
  96. // Only use Cluster if number of tiles along seqlen_q is even and not varlen
  97. BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && params.cu_seqlens_q == nullptr, UseCluster, [&] {
  98. run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, Is_causal ? 128 : 176, 12, 2, false, !Is_causal && UseCluster ? 2 : 1, T>, Is_causal>(params, stream);
  99. });
  100. });
  101. }
  102. template<typename T>
  103. void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
  104. constexpr static int Headdim = 256;
  105. BOOL_SWITCH(params.is_causal, Is_causal, [&] {
  106. // Only use Cluster if number of tiles along seqlen_q is even and not varlen
  107. BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && params.cu_seqlens_q == nullptr, UseCluster, [&] {
  108. run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 80, 12, 2, false, !Is_causal && UseCluster ? 2 : 1, T>, Is_causal>(params, stream);
  109. });
  110. });
  111. }