// flash_bwd_launch_template.h
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

#include "cutlass/cluster_launch.hpp"

#include "static_switch.h"
#include "flash.h"
#include "flash_bwd_preprocess_kernel.h"
#include "flash_bwd_kernel.h"
#include "kernel_traits.h"
#include "utils.h"
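
// Preprocessing kernel for the backward pass: computes the rowwise dot
// product D = rowsum(dO * O) used by the softmax backward and, when
// Clear_dQaccum is true, zeroes the fp32 dQ accumulator buffer.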
template<bool Clear_dQaccum=true, typename Kernel_traits>
__global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {
    flash::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params);
}
// template<typename Kernel_traits>
// __global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) {
//     flash::convert_dQ<Kernel_traits>(params, nsplits);
// }
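
// Postprocessing kernel that converts the dK/dV accumulators to the final
// output element type. Currently unused on this path: its launch at the end
// of run_flash_bwd is commented out, and dK/dV are written directly by TMA
// stores instead.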
template<typename Kernel_traits>
__global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
    flash::convert_dKV<Kernel_traits>(params);
}
template<typename Kernel_traits, bool Is_causal>
void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
    int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
    dim3 grid_m(num_m_block, params.b, params.h);
    flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreadsNonWS, 0, stream>>>(params);
    // If we use both TMA_STORE (for n_block=0) and TMA_REDUCE_ADD (for n_block>0), we don't need to clear dQaccum.
    // flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m, Kernel_traits::kNThreadsNonWS, 0, stream>>>(params);
    CHECK_CUDA_KERNEL_LAUNCH();
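
    // Build CuTe gmem tensors and host-side TMA descriptors; each descriptor
    // is passed to the kernel by value and views its tensor as
    // (seqlen, headdim, nheads, batch).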
    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using TileShape_MNK = typename Kernel_traits::TileShape_MNK;
    using ClusterShape = typename Kernel_traits::ClusterShape_MNK;

    Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr)),
                            make_shape(params.seqlen_q, params.d, params.h, params.b),
                            make_stride(params.q_row_stride, _1{}, params.q_head_stride, params.q_batch_stride));
    auto tma_load_Q = make_tma_copy(
        typename Kernel_traits::GmemTiledCopyQdO{},
        mQ,
        typename Kernel_traits::SmemLayoutQ{}(_, _, _0{}),
        // typename Kernel_traits::SmemLayoutQ{},
        select<0, 2>(TileShape_MNK{}),
        size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
    Tensor mdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.do_ptr)),
                             make_shape(params.seqlen_q, params.d, params.h, params.b),
                             make_stride(params.do_row_stride, _1{}, params.do_head_stride, params.do_batch_stride));
    auto tma_load_dO = make_tma_copy(
        typename Kernel_traits::GmemTiledCopyQdO{},
        mdO,
        typename Kernel_traits::SmemLayoutdO{}(_, _, _0{}),
        // typename Kernel_traits::SmemLayoutdO{},
        select<0, 2>(TileShape_MNK{}),
        size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
    Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.k_ptr)),
                            make_shape(params.seqlen_k, params.d, params.h, params.b),
                            make_stride(params.k_row_stride, _1{}, params.k_head_stride, params.k_batch_stride));
    auto tma_load_K = make_tma_copy(
        typename Kernel_traits::GmemTiledCopyKV{},
        mK,
        typename Kernel_traits::SmemLayoutK{},
        // typename Kernel_traits::SmemLayoutK{}(_, _, _0{}),
        select<1, 2>(TileShape_MNK{}),
        _1{}); // no mcast for K
    Tensor mV = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.v_ptr)),
                            make_shape(params.seqlen_k, params.d, params.h, params.b),
                            make_stride(params.v_row_stride, _1{}, params.v_head_stride, params.v_batch_stride));
    auto tma_load_V = make_tma_copy(
        typename Kernel_traits::GmemTiledCopyKV{},
        mV,
        typename Kernel_traits::SmemLayoutV{},
        // typename Kernel_traits::SmemLayoutV{}(_, _, _0{}),
        select<1, 2>(TileShape_MNK{}),
        _1{}); // no mcast for V
    Tensor mdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.dk_ptr)),
                             make_shape(params.seqlen_k, params.d, params.h, params.b),
                             make_stride(params.dk_row_stride, _1{}, params.dk_head_stride, params.dk_batch_stride));
    auto tma_store_dK = make_tma_copy(
        typename Kernel_traits::GmemTiledCopydKV{},
        mdK,
        typename Kernel_traits::SmemLayoutdK{},
        select<1, 2>(TileShape_MNK{}),
        _1{}); // no mcast for output
    Tensor mdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.dv_ptr)),
                             make_shape(params.seqlen_k, params.d, params.h, params.b),
                             make_stride(params.dv_row_stride, _1{}, params.dv_head_stride, params.dv_batch_stride));
    auto tma_store_dV = make_tma_copy(
        typename Kernel_traits::GmemTiledCopydKV{},
        mdV,
        typename Kernel_traits::SmemLayoutdV{},
        select<1, 2>(TileShape_MNK{}),
        _1{}); // no mcast for output
    Tensor mdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.dq_ptr)),
                             make_shape(params.seqlen_q, params.d, params.h, params.b),
                             make_stride(params.dq_row_stride, _1{}, params.dq_head_stride, params.dq_batch_stride));
    Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.dq_accum_ptr)),
                                  make_shape(params.seqlen_q, params.d, params.h, params.b),
                                  make_stride(params.d * params.h, _1{}, params.d, params.d * params.h * params.seqlen_q_rounded));
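    // dQ is accumulated in fp32 (dQaccum) because every n_block contributes a
    // partial dQ to the same rows. The strides above lay the accumulator out
    // contiguously as (batch, seqlen_q_rounded, nheads, headdim); it is
    // converted to Element by the convert_dQ kernel at the end.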
    auto tma_store_dQaccum = make_tma_copy(
        // typename Kernel_traits::GmemTiledCopydKV{},
        typename cute::SM90_TMA_STORE{},
        // mdQ,
        mdQaccum,
        // typename Kernel_traits::SmemLayoutdQTMA{},
        typename Kernel_traits::SmemLayoutdQaccTMA{},
        select<0, 2>(TileShape_MNK{}),
        _1{}); // no mcast for output
    auto tma_reduce_add_dQaccum = make_tma_copy(
        // typename Kernel_traits::GmemTiledCopydKV{},
        typename cute::SM90_TMA_REDUCE_ADD{},
        // mdQ,
        mdQaccum,
        // typename Kernel_traits::SmemLayoutdQTMA{},
        typename Kernel_traits::SmemLayoutdQaccTMA{},
        select<0, 2>(TileShape_MNK{}),
        _1{}); // no mcast for output
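    // Note: two TMA ops target the same dQaccum tensor -- a plain TMA_STORE
    // that overwrites, and a TMA_REDUCE_ADD that accumulates into gmem. Per
    // the comment near the top of this function, using TMA_STORE for a CTA's
    // first n_block would make the dQaccum clear unnecessary.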
    // print(typename Kernel_traits::SmemLayoutVt{}); printf("\n"); print(typename Kernel_traits::SmemLayoutVt_tmp{});
    // print(typename Kernel_traits::TiledMmaSdP{}); printf("\n");
    // print(typename Kernel_traits::TiledMmadKV{}); printf("\n");
    // print(typename Kernel_traits::TiledMmadQ{}); printf("\n");
    // print(typename Kernel_traits::SmemLayoutAtomK{}); printf("\n");
    // print(typename Kernel_traits::SmemLayoutK{}); printf("\n");
    // print(typename Kernel_traits::SmemLayoutKt{}); printf("\n");
    // Get the ptr to the kernel function.
    void *kernel;
    if constexpr (!Kernel_traits::Is_WS) {
        kernel = (void *)flash::compute_dqkv<Kernel_traits, Is_causal, decltype(tma_load_Q), decltype(tma_load_dO),
                                             decltype(tma_load_K), decltype(tma_load_V), decltype(tma_store_dK), decltype(tma_store_dV)>;
    } else {
        kernel = (void *)flash::compute_dqkv_ws<Kernel_traits, Is_causal, decltype(tma_load_Q), decltype(tma_load_dO),
                                                decltype(tma_load_K), decltype(tma_load_V), decltype(tma_store_dK), decltype(tma_store_dV),
                                                decltype(tma_store_dQaccum), decltype(tma_reduce_add_dQaccum)>;
    }
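    // Is_WS selects the warp-specialized (producer/consumer) variant of the
    // kernel; the non-WS variant does not take the dQaccum TMA descriptors.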
    // void *kernel = (void *)flash::compute_dqkv_seqqpar<Kernel_traits, Is_causal, decltype(tma_load_Q), decltype(tma_load_dO),
    //                                                    decltype(tma_load_K), decltype(tma_load_V), decltype(tma_store_dQaccum), decltype(tma_store_dK), decltype(tma_store_dV)>;
    auto shared_storage = typename Kernel_traits::SharedStorage{};
    int smem_size = sizeof(typename Kernel_traits::SharedStorage);
    int smem_size_q = sizeof(decltype(shared_storage.smem_q));
    int smem_size_do = sizeof(decltype(shared_storage.smem_do));
    int smem_size_k = sizeof(decltype(shared_storage.smem_k));
    int smem_size_v = sizeof(decltype(shared_storage.smem_v));
    // int smem_size_p = sizeof(decltype(shared_storage.smem_p));
    int smem_size_ds = sizeof(decltype(shared_storage.smem_ds));
    // printf("smem_size = %d, q = %d, do = %d, k = %d, v = %d, p = %d, ds = %d\n", smem_size, smem_size_q, smem_size_do, smem_size_k, smem_size_v, smem_size_p, smem_size_ds);
    // printf("smem_size = %d, q = %d, do = %d, k = %d, v = %d, ds = %d\n", smem_size, smem_size_q, smem_size_do, smem_size_k, smem_size_v, smem_size_ds);
    if (smem_size >= 48 * 1024) {
        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    }
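    // Kernels requesting more than 48 KB of dynamic shared memory must opt in
    // via cudaFuncAttributeMaxDynamicSharedMemorySize before launch.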
    static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
    int num_blocks_n = cutlass::ceil_div(params.seqlen_k, Kernel_traits::kBlockN);
    num_blocks_n = cutlass::ceil_div(num_blocks_n, size<1>(ClusterShape{})) * size<1>(ClusterShape{});
    dim3 grid_dims(num_blocks_n, params.h, params.b);
    // int num_blocks_m = cutlass::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
    // num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{});
    // dim3 grid_dims(num_blocks_m, params.h, params.b);
    dim3 block_dims(ctaSize);
    dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
    cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
    if constexpr (!Kernel_traits::Is_WS) {
        cutlass::launch_kernel_on_cluster(launch_params, kernel, params, tma_load_Q, tma_load_dO,
                                          tma_load_K, tma_load_V, tma_store_dK, tma_store_dV);
    } else {
        cutlass::launch_kernel_on_cluster(launch_params, kernel, params, tma_load_Q, tma_load_dO,
                                          tma_load_K, tma_load_V, tma_store_dK, tma_store_dV, tma_store_dQaccum, tma_reduce_add_dQaccum);
    }
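    // Note: the main kernel parallelizes over n_blocks (seqlen_k), with the
    // grid rounded up to a multiple of the cluster's N extent. The launch
    // goes through CUTLASS's cluster launch API, since thread block clusters
    // require the extended launch path.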
    // cutlass::launch_kernel_on_cluster(launch_params, kernel, params, tma_load_Q, tma_load_dO,
    //                                   tma_load_K, tma_load_V, tma_store_dQaccum, tma_store_dK, tma_store_dV);
    CHECK_CUDA_KERNEL_LAUNCH();
    auto tma_load_dQaccum = make_tma_copy(
        typename cute::SM90_TMA_LOAD{},
        mdQaccum,
        typename Kernel_traits::SmemLayoutdQaccTMA{},
        select<0, 2>(TileShape_MNK{}),
        _1{}); // no mcast
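    // Postprocess: convert_dQ reads the fp32 dQaccum tiles back through TMA
    // and writes dQ in the output element type.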
    // auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
    auto kernel_dq = &flash::convert_dQ<Kernel_traits, decltype(tma_load_dQaccum)>;
    if (Kernel_traits::kSmemdQSize * 2 + 8 >= 48 * 1024) {
        CHECK_CUDA(cudaFuncSetAttribute(
            kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize * 2 + 8));
    }
    kernel_dq<<<grid_m, Kernel_traits::kNThreadsdQ, Kernel_traits::kSmemdQSize * 2 + 8, stream>>>(params, tma_load_dQaccum);
    CHECK_CUDA_KERNEL_LAUNCH();
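    // Dynamic smem for convert_dQ is sized as kSmemdQSize * 2 + 8: presumably
    // two dQ-tile buffers plus 8 bytes for a TMA mbarrier.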
    // auto kernel_dkv = &flash_bwd_convert_dkv_kernel<Kernel_traits>;
    // if (Kernel_traits::kSmemdKVSize >= 48 * 1024) {
    //     CHECK_CUDA(cudaFuncSetAttribute(
    //         kernel_dkv, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdKVSize));
    // }
    // int num_n_block = cute::ceil_div(params.seqlen_k, Kernel_traits::kBlockN);
    // dim3 grid_n(num_n_block, params.b, params.h);
    // kernel_dkv<<<grid_n, Kernel_traits::kNThreads, Kernel_traits::kSmemdKVSize, stream>>>(params);
    // CHECK_CUDA_KERNEL_LAUNCH();
}
template<typename T>
void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 64;
    // BOOL_SWITCH(params.is_causal, Is_causal, [&] {
    //     run_flash_bwd<T, Headdim, Is_causal>(params, stream);
    // });
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, false, false, false, 2, 2, 2, 1, T>, false>(params, stream);
    run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 12, true, false, false, 1, 2, 2, 1, T>, false>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 96, 128, 12, true, false, true, 1, 2, 2, 1, T>, false>(params, stream);
}
template<typename T>
void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 128;
    // BOOL_SWITCH(params.is_causal, Is_causal, [&] {
    //     run_flash_bwd<T, Headdim, Is_causal>(params, stream);
    // });
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, false, 2, 1, 2, 1, T>, false>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, false, false, false, 1, 2, 1, 1, T>, false>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 96, 8, false, true, false, 2, 1, 2, 1, T>, false>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 96, 8, false, true, true, 2, 1, 1, 1, T>, false>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, true, false, true, 1, 2, 1, 1, T>, false>(params, stream);
    run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 12, true, false, true, 1, 2, 1, 1, T>, false>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 12, true, false, false, 1, 2, 1, 1, T>, false>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 12, false, false, false, 1, 2, 1, 1, T>, false>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 80, 128, 12, true, false, true, 1, 2, 1, 1, T>, false>(params, stream);
    // run_flash_bwd<Flash_bwd_seqqpar_kernel_traits<Headdim, 128, 64, 8, false, true, false, 2, 1, 2, 1, T>, false>(params, stream);
    // run_flash_bwd<Flash_bwd_seqqpar_kernel_traits<Headdim, 96, 128, 8, true, false, true, 1, 2, 1, 1, T>, false>(params, stream);
}
template<typename T>
void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
    // constexpr static int Headdim = 256;
    // BOOL_SWITCH(params.is_causal, Is_causal, [&] {
    //     run_flash_bwd<T, Headdim, Is_causal>(params, stream);
    // });
}
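
// Example dispatch (a sketch, not part of this header): a caller would
// typically switch on element type and head dimension to pick one of the
// instantiations above. FP16_SWITCH comes from static_switch.h and binds
// elem_type to cutlass::half_t or cutlass::bfloat16_t; the exact caller
// shape is an assumption here.
//
//   FP16_SWITCH(!params.is_bf16, [&] {
//       if (params.d <= 64) { run_mha_bwd_hdim64<elem_type>(params, stream); }
//       else if (params.d <= 128) { run_mha_bwd_hdim128<elem_type>(params, stream); }
//       else { run_mha_bwd_hdim256<elem_type>(params, stream); }
//   });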