flash_bwd_launch_template.h 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include "cute/tensor.hpp"
  6. #include "cutlass/cluster_launch.hpp"
  7. #include "cutlass/device_kernel.h" // For device_kernel
  8. #include "static_switch.h"
  9. #include "flash.h"
  10. #include "flash_bwd_preprocess_kernel.h"
  11. #include "flash_bwd_postprocess_kernel.h"
  12. #include "tile_scheduler_bwd.hpp"
  13. #include "mainloop_bwd_sm90_tma_gmma_ws.hpp"
  14. #include "epilogue_bwd_sm90_tma.hpp"
  15. #include "flash_bwd_kernel.h"
  16. using namespace cute;
  17. template <int kHeadDim, int kBlockM, int kBlockN, typename Element, bool Is_causal, bool Varlen, bool Deterministic,
  18. bool dKV_swapAB, bool dQ_swapAB, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1>
  19. void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
  20. using TileShape_MK = cute::Shape<Int<kBlockM>, Int<kHeadDim>>;
  21. using ElementAccum = float;
  22. using PreprocessKernel = flash::FlashAttnBwdPreprocess<TileShape_MK, Element, ElementAccum, cutlass::arch::Sm90, /*Clear_dQaccum=*/true, Varlen>;
  23. int const total_q_padded_rounded = cute::round_up(params.total_q + params.b * 128, 128);
  24. typename PreprocessKernel::Arguments preprocess_args {
  25. static_cast<Element const*>(params.o_ptr),
  26. {!Varlen ? params.seqlen_q : params.total_q, params.d, params.h, !Varlen ? params.b : 1}, // shape_O
  27. {params.o_row_stride, _1{}, params.o_head_stride, !Varlen ? params.o_batch_stride : 0}, // stride_O
  28. static_cast<Element const*>(params.do_ptr),
  29. {params.do_row_stride, _1{}, params.do_head_stride, !Varlen ? params.do_batch_stride : 0}, // stride_dO
  30. static_cast<float*>(params.dsoftmax_sum),
  31. {!Varlen ? params.seqlen_q_rounded : total_q_padded_rounded, params.h, !Varlen ? params.b : 1}, // shape_dPsum
  32. {_1{}, !Varlen ? params.seqlen_q_rounded : total_q_padded_rounded, !Varlen ? params.h * params.seqlen_q_rounded : 0}, // stride_dPsum
  33. static_cast<float*>(params.softmax_lse_ptr),
  34. {_1{}, !Varlen ? params.seqlen_q : params.total_q, !Varlen ? params.h * params.seqlen_q : 0}, // stride_LSE
  35. static_cast<float*>(params.softmax_lse_log2_ptr),
  36. {_1{}, !Varlen ? params.seqlen_q_rounded : total_q_padded_rounded, !Varlen ? params.h * params.seqlen_q_rounded : 0}, // stride_LSE_log2
  37. static_cast<ElementAccum*>(params.dq_accum_ptr),
  38. {!Varlen ? params.seqlen_q_rounded : total_q_padded_rounded, params.d_rounded, params.h, !Varlen ? params.b : 1}, // shape_dQaccum
  39. {params.d_rounded, _1{}, params.d_rounded * (!Varlen ? params.seqlen_q_rounded : total_q_padded_rounded), !Varlen ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQ
  40. params.b,
  41. params.dq_semaphore,
  42. params.cu_seqlens_q,
  43. params.seqused_q
  44. };
  45. typename PreprocessKernel::Params preprocess_params = PreprocessKernel::to_underlying_arguments(preprocess_args);
  46. int num_m_block = cute::ceil_div(params.seqlen_q, kBlockM);
  47. dim3 grid_m(num_m_block, params.h, params.b);
  48. cutlass::device_kernel<PreprocessKernel><<<grid_m, PreprocessKernel::MaxThreadsPerBlock, PreprocessKernel::SharedStorageSize, stream>>>(preprocess_params);
  49. using TileShape_MNK = cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
  50. using ClusterShape = cute::Shape<_1, Int<1>, _1>;
  51. static constexpr int Stages = 2;
  52. using CollectiveMainloop = flash::CollectiveMainloopBwd<Stages, ClusterShape, TileShape_MNK, Element, ElementAccum, cutlass::arch::Sm90,
  53. Is_causal, Varlen, Deterministic,
  54. dKV_swapAB, dQ_swapAB, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ>;
  55. using CollectiveEpilogue = flash::CollectiveEpilogueBwd<TileShape_MNK, Element, CollectiveMainloop::NumMmaThreads, Varlen>;
  56. using Scheduler = flash::SingleTileSchedulerBwd;
  57. using AttnKernel = flash::FlashAttnBwd<CollectiveMainloop, CollectiveEpilogue, Scheduler>;
  58. typename CollectiveMainloop::Arguments mainloop_args {
  59. static_cast<Element const*>(params.q_ptr),
  60. {!Varlen ? params.seqlen_q : params.total_q, params.d, params.h, !Varlen ? params.b : 1}, // shape_Q
  61. {params.q_row_stride, _1{}, params.q_head_stride, !Varlen ? params.q_batch_stride : 0}, // stride_Q
  62. static_cast<Element const*>(params.k_ptr),
  63. {!Varlen ? params.seqlen_k : params.total_k, params.d, params.h_k, !Varlen ? params.b : 1}, // shape_K
  64. {params.k_row_stride, _1{}, params.k_head_stride, !Varlen ? params.k_batch_stride : 0}, // stride_K
  65. static_cast<Element const*>(params.v_ptr),
  66. {params.v_row_stride, _1{}, params.v_head_stride, !Varlen ? params.v_batch_stride : 0}, // stride_V
  67. static_cast<Element const*>(params.do_ptr),
  68. {params.do_row_stride, _1{}, params.do_head_stride, !Varlen ? params.do_batch_stride : 0}, // stride_dO
  69. static_cast<ElementAccum*>(params.dq_accum_ptr),
  70. // {params.seqlen_q_rounded, params.d_rounded, params.h, params.b}, // shape_dQaccum
  71. // {params.d_rounded, _1{}, params.d_rounded * params.seqlen_q_rounded, params.d_rounded * params.seqlen_q_rounded * params.h}, // stride_dQaccum
  72. {(!Varlen ? params.seqlen_q_rounded : total_q_padded_rounded) * (params.d_rounded / 32), 32, params.h, !Varlen ? params.b : 1}, // shape_dQaccum
  73. {32, _1{}, params.d_rounded * (!Varlen ? params.seqlen_q_rounded : total_q_padded_rounded), !Varlen ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum
  74. static_cast<float*>(params.softmax_lse_log2_ptr),
  75. {!Varlen ? params.seqlen_q_rounded : total_q_padded_rounded, params.h, !Varlen ? params.b : 1}, // shape_LSE
  76. {_1{}, !Varlen ? params.seqlen_q_rounded : total_q_padded_rounded, !Varlen ? params.h * params.seqlen_q_rounded : 0}, // stride_LSE_log2
  77. static_cast<float*>(params.dsoftmax_sum),
  78. {_1{}, !Varlen ? params.seqlen_q_rounded : total_q_padded_rounded, !Varlen ? params.h * params.seqlen_q_rounded : 0}, // stride_dPsum
  79. params.scale_softmax,
  80. params.b,
  81. params.dq_semaphore,
  82. params.cu_seqlens_q, params.cu_seqlens_k,
  83. params.seqused_q, params.seqused_k
  84. };
  85. typename CollectiveEpilogue::Arguments epilogue_args {
  86. static_cast<Element*>(params.dk_ptr),
  87. {!Varlen ? params.seqlen_k : params.total_k, params.d, params.h, !Varlen ? params.b : 1}, // shape_dK
  88. {params.dk_row_stride, _1{}, params.dk_head_stride, !Varlen ? params.dk_batch_stride : 0}, // stride_dK
  89. static_cast<Element*>(params.dv_ptr),
  90. {params.dv_row_stride, _1{}, params.dv_head_stride, !Varlen ? params.dv_batch_stride : 0},
  91. params.cu_seqlens_k
  92. };
  93. int num_blocks_n = cutlass::ceil_div(params.seqlen_k, get<1>(TileShape_MNK{}));
  94. num_blocks_n = cutlass::round_up(num_blocks_n, size<1>(ClusterShape{}));
  95. typename Scheduler::Arguments scheduler_args {
  96. num_blocks_n, params.h, params.b, params.tile_count_semaphore, params.cu_seqlens_k
  97. };
  98. int device;
  99. cudaGetDevice(&device);
  100. typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({
  101. mainloop_args, epilogue_args, {device}, scheduler_args
  102. });
  103. // Get the ptr to kernel function.
  104. void const* kernel = (void const*) cutlass::device_kernel<AttnKernel>;
  105. int smem_size = AttnKernel::SharedStorageSize;
  106. // int smem_size_q = sizeof(decltype((typename AttnKernel::SharedStorage{}).mainloop.smem_q));
  107. // int smem_size_do = sizeof(decltype((typename AttnKernel::SharedStorage{}).mainloop.smem_do));
  108. // int smem_size_ds = sizeof(decltype((typename AttnKernel::SharedStorage{}).mainloop.smem_ds));
  109. // int smem_size_dqacc = sizeof(decltype((typename AttnKernel::SharedStorage{}).mainloop.smem_dqacc));
  110. // int smem_size_k = sizeof(decltype((typename AttnKernel::SharedStorage{}).mainloop.smem_k));
  111. // int smem_size_v = sizeof(decltype((typename AttnKernel::SharedStorage{}).mainloop.smem_v));
  112. // printf("smem_size = %d, q = %d, k = %d, v = %d, do = %d, ds = %d, dqacc = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v, smem_size_do, smem_size_ds, smem_size_dqacc);
  113. if (smem_size >= 48 * 1024) {
  114. CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
  115. }
  116. dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params);
  117. dim3 block_dims = AttnKernel::get_block_shape();
  118. dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
  119. cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
  120. cutlass::launch_kernel_on_cluster(launch_params, kernel, kernel_params);
  121. CHECK_CUDA_KERNEL_LAUNCH();
  122. using PostprocessKernel = flash::FlashAttnBwdPostprocessConvertdQ<TileShape_MK, Element, ElementAccum, cutlass::arch::Sm90,
  123. AttnKernel::CollectiveMainloop::kNThreadsdQ,
  124. typename AttnKernel::CollectiveMainloop::SmemLayoutdQaccumTMA,
  125. typename AttnKernel::CollectiveMainloop::TiledMmadQ,
  126. AttnKernel::CollectiveMainloop::dQ_swapAB
  127. >;
  128. typename PostprocessKernel::Arguments postprocess_args {
  129. static_cast<ElementAccum const*>(params.dq_accum_ptr),
  130. // {params.seqlen_q_rounded, params.d_rounded, params.h, params.b}, // shape_dQaccum
  131. // {params.d_rounded, _1{}, params.d_rounded * params.seqlen_q_rounded, params.d_rounded * params.seqlen_q_rounded * params.h}, // stride_dQaccum
  132. {(!Varlen ? params.seqlen_q_rounded : total_q_padded_rounded) * (params.d_rounded / 32), 32, params.h, !Varlen ? params.b : 1}, // shape_dQaccum
  133. {32, _1{}, params.d_rounded * (!Varlen ? params.seqlen_q_rounded : total_q_padded_rounded), !Varlen ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum
  134. static_cast<Element*>(params.dq_ptr),
  135. {!Varlen ? params.seqlen_q : params.total_q, params.d, params.h, !Varlen ? params.b : 1}, // shape_dQ
  136. {params.dq_row_stride, _1{}, params.dq_head_stride, params.dq_batch_stride}, // stride_dQ
  137. params.scale_softmax,
  138. params.cu_seqlens_q,
  139. params.seqused_q
  140. };
  141. typename PostprocessKernel::Params postprocess_params = PostprocessKernel::to_underlying_arguments(postprocess_args);
  142. int num_m_block_postprocess = cute::ceil_div(params.seqlen_q, get<0>(TileShape_MK{}));
  143. dim3 grid_m_postprocess(num_m_block_postprocess, params.h, params.b);
  144. // Get the ptr to kernel function.
  145. auto postprocess_kernel = cutlass::device_kernel<PostprocessKernel>;
  146. int smem_size_postprocess = PostprocessKernel::SharedStorageSize;
  147. if (smem_size_postprocess >= 48 * 1024) {
  148. CHECK_CUDA(cudaFuncSetAttribute(postprocess_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
  149. }
  150. postprocess_kernel<<<grid_m_postprocess, PostprocessKernel::MaxThreadsPerBlock, smem_size_postprocess, stream>>>(postprocess_params);
  151. CHECK_CUDA_KERNEL_LAUNCH();
  152. }
  153. template<typename T>
  154. void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
  155. constexpr static int Headdim = 64;
  156. BOOL_SWITCH(params.is_causal, Is_causal, [&] {
  157. BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
  158. BOOL_SWITCH(params.deterministic, Deterministic, [&] {
  159. run_flash_bwd<Headdim, 128, 128, T, Is_causal, Varlen, Deterministic, false, false, 1, 2, 2>(params, stream);
  160. });
  161. });
  162. });
  163. }
  164. template<typename T>
  165. void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
  166. constexpr static int Headdim = 96;
  167. BOOL_SWITCH(params.is_causal, Is_causal, [&] {
  168. BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
  169. BOOL_SWITCH(params.deterministic, Deterministic, [&] {
  170. run_flash_bwd<Headdim, 64, 128, T, Is_causal, Varlen, Deterministic, false, false, 1, 2, 1>(params, stream);
  171. });
  172. });
  173. });
  174. }
  175. template<typename T>
  176. void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
  177. constexpr static int Headdim = 128;
  178. BOOL_SWITCH(params.is_causal, Is_causal, [&] {
  179. BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
  180. BOOL_SWITCH(params.deterministic, Deterministic, [&] {
  181. run_flash_bwd<Headdim, 64, 128, T, Is_causal, Varlen, Deterministic, false, false, 1, 2, 1>(params, stream);
  182. });
  183. });
  184. });
  185. }