flash_bwd_launch_template.h 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include "cute/tensor.hpp"
  6. #include "cutlass/cluster_launch.hpp"
  7. #include "cutlass/device_kernel.h" // For device_kernel
  8. #include "static_switch.h"
  9. #include "flash.h"
  10. #include "flash_bwd_preprocess_kernel.h"
  11. #include "flash_bwd_postprocess_kernel.h"
  12. #include "tile_scheduler_bwd.hpp"
  13. #include "mainloop_bwd_sm90_tma_gmma_ws.hpp"
  14. #include "epilogue_bwd_sm90_tma.hpp"
  15. #include "flash_bwd_kernel.h"
  16. using namespace cute;
  // Launches the SM90 backward pass on `stream` as three back-to-back kernels:
  //   1. PreprocessKernel  (flash::FlashAttnBwdPreprocess, Clear_dQaccum=true): is handed
  //      O, dO, LSE and the dPsum (dsoftmax_sum), LSE_log2 and dQaccum buffers;
  //      presumably computes dPsum / LSE_log2 and zeroes dQaccum.
  //   2. AttnKernel        (flash::FlashAttnBwd): main mainloop + epilogue; writes dK / dV
  //      through the epilogue and accumulates into the float dQaccum buffer.
  //   3. PostprocessKernel (flash::FlashAttnBwdPostprocessConvertdQ): reads dQaccum
  //      (ElementAccum) and writes dQ (Element); it is also passed scale_softmax.
  // Every launch is followed by CHECK_CUDA_KERNEL_LAUNCH().
  //
  // Varlen: all sequences are packed along the row dimension, so the batch extent is 1 and
  // every batch stride is 0; sequence boundaries are passed via cu_seqlens_* / seqused_*.
  template <int kHeadDim, int kBlockM, int kBlockN, typename Element, bool Is_causal, bool Is_local, bool Varlen, bool Deterministic,
            bool dKV_swapAB, bool dQ_swapAB, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1>
  void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
      static_assert(!(Is_causal && Is_local), "Is_causal and Is_local cannot be true at the same time.");
      using TileShape_MK = cute::Shape<Int<kBlockM>, Int<kHeadDim>>;
      using ElementAccum = float;

      // Rows of the "rounded" per-row buffers (dPsum, LSE_log2, dQaccum) in varlen mode:
      // total_q plus 128 rows per batch entry, rounded up to 128. NOTE(review): presumably this
      // lets each sequence start on a 128-row boundary -- confirm against the kernels.
      int const total_q_padded_rounded = cute::round_up(params.total_q + params.b * 128, 128);
      // Effective extents: per-batch lengths normally, packed totals (batch = 1) when Varlen.
      auto const seqlen_q_eff = !Varlen ? params.seqlen_q : params.total_q;
      auto const seqlen_k_eff = !Varlen ? params.seqlen_k : params.total_k;
      auto const seqlen_q_rounded_eff = !Varlen ? params.seqlen_q_rounded : total_q_padded_rounded;
      auto const batch_eff = !Varlen ? params.b : 1;
      // NOTE(review): the dQaccum strides below are computed in int arithmetic
      // (d_rounded * rows * h). Very large h * seqlen could overflow int32 before the value is
      // widened -- confirm the stride types and consider widening to int64_t.

      // ---------------------------------------------------------------------------------
      // Stage 1: preprocess.
      // ---------------------------------------------------------------------------------
      using PreprocessKernel = flash::FlashAttnBwdPreprocess<TileShape_MK, Element, ElementAccum, cutlass::arch::Sm90, /*Clear_dQaccum=*/true, Varlen>;
      typename PreprocessKernel::Arguments preprocess_args {
          static_cast<Element const*>(params.o_ptr),
          {seqlen_q_eff, params.d, params.h, batch_eff},  // shape_O
          {params.o_row_stride, _1{}, params.o_head_stride, !Varlen ? params.o_batch_stride : 0},  // stride_O
          static_cast<Element const*>(params.do_ptr),
          {params.do_row_stride, _1{}, params.do_head_stride, !Varlen ? params.do_batch_stride : 0},  // stride_dO
          static_cast<float*>(params.dsoftmax_sum),
          {seqlen_q_rounded_eff, params.h, batch_eff},  // shape_dPsum
          {_1{}, seqlen_q_rounded_eff, !Varlen ? params.h * params.seqlen_q_rounded : 0},  // stride_dPsum
          static_cast<float*>(params.softmax_lse_ptr),
          {_1{}, seqlen_q_eff, !Varlen ? params.h * params.seqlen_q : 0},  // stride_LSE
          static_cast<float*>(params.softmax_lse_log2_ptr),
          {_1{}, seqlen_q_rounded_eff, !Varlen ? params.h * params.seqlen_q_rounded : 0},  // stride_LSE_log2
          static_cast<ElementAccum*>(params.dq_accum_ptr),
          {seqlen_q_rounded_eff, params.d_rounded, params.h, batch_eff},  // shape_dQaccum
          {params.d_rounded, _1{}, params.d_rounded * seqlen_q_rounded_eff, !Varlen ? params.d_rounded * params.seqlen_q_rounded * params.h : 0},  // stride_dQaccum
          params.b,
          params.dq_semaphore,
          params.cu_seqlens_q,
          params.seqused_q
      };
      typename PreprocessKernel::Params preprocess_params = PreprocessKernel::to_underlying_arguments(preprocess_args);
      int num_m_block = cute::ceil_div(params.seqlen_q, kBlockM);
      dim3 grid_m(num_m_block, params.h, params.b);
      cutlass::device_kernel<PreprocessKernel><<<grid_m, PreprocessKernel::MaxThreadsPerBlock, PreprocessKernel::SharedStorageSize, stream>>>(preprocess_params);
      // Check here so a bad preprocess launch is not misattributed to the main kernel.
      CHECK_CUDA_KERNEL_LAUNCH();

      // ---------------------------------------------------------------------------------
      // Stage 2: main backward kernel (cluster launch).
      // ---------------------------------------------------------------------------------
      using TileShape_MNK = cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
      using ClusterShape = cute::Shape<_1, Int<1>, _1>;
      static constexpr int Stages = 2;
      using CollectiveMainloop = flash::CollectiveMainloopBwd<Stages, ClusterShape, TileShape_MNK, Element, ElementAccum, cutlass::arch::Sm90,
          Is_causal, Is_local, Varlen, Deterministic,
          dKV_swapAB, dQ_swapAB, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ>;
      using CollectiveEpilogue = flash::CollectiveEpilogueBwd<TileShape_MNK, Element, CollectiveMainloop::NumMmaThreads, Varlen>;
      using Scheduler = flash::SingleTileSchedulerBwd;
      using AttnKernel = flash::FlashAttnBwd<CollectiveMainloop, CollectiveEpilogue, Scheduler>;
      // The mainloop and postprocess use the same dQaccum buffer as stage 1. Here it is viewed
      // as (rows * d_rounded/32, 32) with stride (32, 1), i.e. the same linear layout split into
      // chunks of 32 floats (assumes d_rounded is a multiple of 32).
      typename CollectiveMainloop::Arguments mainloop_args {
          static_cast<Element const*>(params.q_ptr),
          {seqlen_q_eff, params.d, params.h, batch_eff},  // shape_Q
          {params.q_row_stride, _1{}, params.q_head_stride, !Varlen ? params.q_batch_stride : 0},  // stride_Q
          static_cast<Element const*>(params.k_ptr),
          {seqlen_k_eff, params.d, params.h_k, batch_eff},  // shape_K
          {params.k_row_stride, _1{}, params.k_head_stride, !Varlen ? params.k_batch_stride : 0},  // stride_K
          static_cast<Element const*>(params.v_ptr),
          {params.v_row_stride, _1{}, params.v_head_stride, !Varlen ? params.v_batch_stride : 0},  // stride_V
          static_cast<Element const*>(params.do_ptr),
          {params.do_row_stride, _1{}, params.do_head_stride, !Varlen ? params.do_batch_stride : 0},  // stride_dO
          static_cast<ElementAccum*>(params.dq_accum_ptr),
          {seqlen_q_rounded_eff * (params.d_rounded / 32), 32, params.h, batch_eff},  // shape_dQaccum
          {32, _1{}, params.d_rounded * seqlen_q_rounded_eff, !Varlen ? params.d_rounded * params.seqlen_q_rounded * params.h : 0},  // stride_dQaccum
          static_cast<float*>(params.softmax_lse_log2_ptr),
          {seqlen_q_rounded_eff, params.h, batch_eff},  // shape_LSE
          {_1{}, seqlen_q_rounded_eff, !Varlen ? params.h * params.seqlen_q_rounded : 0},  // stride_LSE_log2
          static_cast<float*>(params.dsoftmax_sum),
          {_1{}, seqlen_q_rounded_eff, !Varlen ? params.h * params.seqlen_q_rounded : 0},  // stride_dPsum
          params.scale_softmax,
          params.b,
          params.dq_semaphore,
          params.cu_seqlens_q, params.cu_seqlens_k,
          params.seqused_q, params.seqused_k,
          params.window_size_left, params.window_size_right
      };
      // dK / dV use params.h heads, whereas K uses params.h_k. NOTE(review): presumably the
      // caller reduces these over GQA/MQA head groups -- confirm.
      typename CollectiveEpilogue::Arguments epilogue_args {
          static_cast<Element*>(params.dk_ptr),
          {seqlen_k_eff, params.d, params.h, batch_eff},  // shape_dK
          {params.dk_row_stride, _1{}, params.dk_head_stride, !Varlen ? params.dk_batch_stride : 0},  // stride_dK
          static_cast<Element*>(params.dv_ptr),
          {params.dv_row_stride, _1{}, params.dv_head_stride, !Varlen ? params.dv_batch_stride : 0},  // stride_dV
          params.cu_seqlens_k
      };
      // One tile per kBlockN keys, rounded up to a whole number of clusters along N.
      int num_blocks_n = cutlass::ceil_div(params.seqlen_k, get<1>(TileShape_MNK{}));
      num_blocks_n = cutlass::round_up(num_blocks_n, size<1>(ClusterShape{}));
      typename Scheduler::Arguments scheduler_args {
          num_blocks_n, params.h, params.b, params.tile_count_semaphore, params.cu_seqlens_k
      };
      int device;
      CHECK_CUDA(cudaGetDevice(&device));
      typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({
          mainloop_args, epilogue_args, {device}, scheduler_args
      });

      void const* kernel = (void const*) cutlass::device_kernel<AttnKernel>;
      int smem_size = AttnKernel::SharedStorageSize;
      // Dynamic shared memory above 48 KB requires an explicit per-kernel opt-in.
      if (smem_size >= 48 * 1024) {
          CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
      }
      dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params);
      dim3 block_dims = AttnKernel::get_block_shape();
      dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
      cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
      cutlass::launch_kernel_on_cluster(launch_params, kernel, kernel_params);
      CHECK_CUDA_KERNEL_LAUNCH();

      // ---------------------------------------------------------------------------------
      // Stage 3: postprocess (dQaccum -> dQ).
      // ---------------------------------------------------------------------------------
      using PostprocessKernel = flash::FlashAttnBwdPostprocessConvertdQ<TileShape_MK, Element, ElementAccum, cutlass::arch::Sm90,
          AttnKernel::CollectiveMainloop::kNThreadsdQ,
          typename AttnKernel::CollectiveMainloop::SmemLayoutdQaccumTMA,
          typename AttnKernel::CollectiveMainloop::TiledMmadQ,
          AttnKernel::CollectiveMainloop::dQ_swapAB
          >;
      typename PostprocessKernel::Arguments postprocess_args {
          static_cast<ElementAccum const*>(params.dq_accum_ptr),
          {seqlen_q_rounded_eff * (params.d_rounded / 32), 32, params.h, batch_eff},  // shape_dQaccum
          {32, _1{}, params.d_rounded * seqlen_q_rounded_eff, !Varlen ? params.d_rounded * params.seqlen_q_rounded * params.h : 0},  // stride_dQaccum
          static_cast<Element*>(params.dq_ptr),
          {seqlen_q_eff, params.d, params.h, batch_eff},  // shape_dQ
          // Batch stride is 0 in varlen mode, consistent with every other tensor above.
          {params.dq_row_stride, _1{}, params.dq_head_stride, !Varlen ? params.dq_batch_stride : 0},  // stride_dQ
          params.scale_softmax,
          params.cu_seqlens_q,
          params.seqused_q
      };
      typename PostprocessKernel::Params postprocess_params = PostprocessKernel::to_underlying_arguments(postprocess_args);
      int num_m_block_postprocess = cute::ceil_div(params.seqlen_q, get<0>(TileShape_MK{}));
      dim3 grid_m_postprocess(num_m_block_postprocess, params.h, params.b);
      auto postprocess_kernel = cutlass::device_kernel<PostprocessKernel>;
      int smem_size_postprocess = PostprocessKernel::SharedStorageSize;
      if (smem_size_postprocess >= 48 * 1024) {
          // Opt in with this kernel's own requirement; the main kernel's smem_size is unrelated.
          CHECK_CUDA(cudaFuncSetAttribute(postprocess_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess));
      }
      postprocess_kernel<<<grid_m_postprocess, PostprocessKernel::MaxThreadsPerBlock, smem_size_postprocess, stream>>>(postprocess_params);
      CHECK_CUDA_KERNEL_LAUNCH();
  }
  // Backward dispatcher for head dimension 64.
  // Nested BOOL_SWITCHes turn four runtime flags from `params` into compile-time template
  // arguments:
  //   - Is_causal: params.is_causal
  //   - Is_local: params.is_local
  //   - Varlen: set when either cu_seqlens pointer is non-null
  //   - Deterministic: params.deterministic
  // It then forwards to run_flash_bwd with 128x128 tiles, no swapAB, and atom layouts
  // MSdP=1, NdKV=2, MdQ=2. When both causal and local are requested, causal wins and the local
  // flag passed on is false.
  template<typename T>
  void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
      static constexpr int kHeadDim = 64;
      static constexpr int kTileM = 128;
      static constexpr int kTileN = 128;
      bool const has_cu_seqlens = params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr;
      BOOL_SWITCH(params.is_causal, Is_causal, [&] {
          BOOL_SWITCH(params.is_local, Is_local, [&] {
              BOOL_SWITCH(has_cu_seqlens, Varlen, [&] {
                  BOOL_SWITCH(params.deterministic, Deterministic, [&] {
                      constexpr bool kLocalOnly = Is_local && !Is_causal;
                      run_flash_bwd<kHeadDim, kTileM, kTileN, T, Is_causal, kLocalOnly, Varlen, Deterministic,
                                    /*dKV_swapAB=*/false, /*dQ_swapAB=*/false,
                                    /*AtomLayoutMSdP=*/1, /*AtomLayoutNdKV=*/2, /*AtomLayoutMdQ=*/2>(params, stream);
                  });
              });
          });
      });
  }
  // Backward dispatcher for head dimension 96.
  // Nested BOOL_SWITCHes turn four runtime flags from `params` into compile-time template
  // arguments:
  //   - Is_causal: params.is_causal
  //   - Is_local: params.is_local
  //   - Varlen: set when either cu_seqlens pointer is non-null
  //   - Deterministic: params.deterministic
  // It then forwards to run_flash_bwd with 64x128 tiles, no swapAB, and atom layouts
  // MSdP=1, NdKV=2, MdQ=1. When both causal and local are requested, causal wins and the local
  // flag passed on is false.
  template<typename T>
  void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
      static constexpr int kHeadDim = 96;
      static constexpr int kTileM = 64;
      static constexpr int kTileN = 128;
      bool const has_cu_seqlens = params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr;
      BOOL_SWITCH(params.is_causal, Is_causal, [&] {
          BOOL_SWITCH(params.is_local, Is_local, [&] {
              BOOL_SWITCH(has_cu_seqlens, Varlen, [&] {
                  BOOL_SWITCH(params.deterministic, Deterministic, [&] {
                      constexpr bool kLocalOnly = Is_local && !Is_causal;
                      run_flash_bwd<kHeadDim, kTileM, kTileN, T, Is_causal, kLocalOnly, Varlen, Deterministic,
                                    /*dKV_swapAB=*/false, /*dQ_swapAB=*/false,
                                    /*AtomLayoutMSdP=*/1, /*AtomLayoutNdKV=*/2, /*AtomLayoutMdQ=*/1>(params, stream);
                  });
              });
          });
      });
  }
  // Backward dispatcher for head dimension 128.
  // Nested BOOL_SWITCHes turn four runtime flags from `params` into compile-time template
  // arguments:
  //   - Is_causal: params.is_causal
  //   - Is_local: params.is_local
  //   - Varlen: set when either cu_seqlens pointer is non-null
  //   - Deterministic: params.deterministic
  // It then forwards to run_flash_bwd with 64x128 tiles, no swapAB, and atom layouts
  // MSdP=1, NdKV=2, MdQ=1. When both causal and local are requested, causal wins and the local
  // flag passed on is false.
  template<typename T>
  void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
      static constexpr int kHeadDim = 128;
      static constexpr int kTileM = 64;
      static constexpr int kTileN = 128;
      bool const has_cu_seqlens = params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr;
      BOOL_SWITCH(params.is_causal, Is_causal, [&] {
          BOOL_SWITCH(params.is_local, Is_local, [&] {
              BOOL_SWITCH(has_cu_seqlens, Varlen, [&] {
                  BOOL_SWITCH(params.deterministic, Deterministic, [&] {
                      constexpr bool kLocalOnly = Is_local && !Is_causal;
                      run_flash_bwd<kHeadDim, kTileM, kTileN, T, Is_causal, kLocalOnly, Varlen, Deterministic,
                                    /*dKV_swapAB=*/false, /*dQ_swapAB=*/false,
                                    /*AtomLayoutMSdP=*/1, /*AtomLayoutNdKV=*/2, /*AtomLayoutMdQ=*/1>(params, stream);
                  });
              });
          });
      });
  }