/***************************************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cute/tensor.hpp>

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include "cutlass/pipeline/pipeline.hpp"

#include "flash.h"
#include "block_info.h"
#include "kernel_traits.h"
#include "utils.h"

namespace flash {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////
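// dot_do_o: for each row of the per-thread dO and O fragments, compute
// dP_sum(row) = scale * sum_d dO(row, d) * O(row, d), reducing the partial sums across the
// THREADS_PER_ROW threads that share a row, and have the first thread of each row group write
// the result (softmax_d) to the dP_sum tensor in global memory.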
template <int THREADS_PER_ROW, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engine0, Layout0> const &o,
                                Tensor<Engine1, Layout1> &dP_sum, const int gdP_col_stride, const float scale) {
    static_assert(Layout0::rank == 3, "Only support 3D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(do_.layout() == o.layout());
    // Reshape do_ and o from (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, 8 * kHeadDim / 64)
    // The last coordinate is the "page".
    Tensor do_reshaped = make_tensor(do_.data(), make_layout(get<1>(do_.layout()),
                                                             make_layout(get<0>(do_.layout()),
                                                                         get<2>(do_.layout()))));
    Tensor o_reshaped = make_tensor(o.data(), do_reshaped.layout());
    Tensor do_fp32 = flash::convert_type<float>(do_reshaped);
    Tensor o_fp32 = flash::convert_type<float>(o_reshaped);
    #pragma unroll
    for (int mi = 0; mi < size<0>(do_reshaped); ++mi) {
        float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0);
        #pragma unroll
        for (int ni = 1; ni < size<1>(do_reshaped); ni++) {
            dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni);
        }
        flash::SumOp<float> sum_op;
        dP_sum_cur = flash::Allreduce<THREADS_PER_ROW>::run(dP_sum_cur, sum_op) * scale;
        if (threadIdx.x % THREADS_PER_ROW == 0) {
            dP_sum(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum_cur;
        }
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////

// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.
// This is used in the case where we want to parallelize the backward across seqlen_k.
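// Grid mapping (see the index computations below): blockIdx.x selects the kBlockM-row block of
// queries, blockIdx.y the batch index, and blockIdx.z the head index.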
template<bool Clear_dQaccum=true, typename Kernel_traits, typename Params>
inline __device__ void compute_dot_do_o(const Params &params) {
    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    const int m_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;
    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;

    const BlockInfo binfo(params, bidb);
    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;

    const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
        + m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
    const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
        + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
    const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
        + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
    const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM;

    Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
                             make_stride(params.do_row_stride, _1{}));
    Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
                            Shape<Int<kBlockM>, Int<kHeadDim>>{},
                            make_stride(params.o_row_stride, _1{}));
    Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},
                                  make_stride(params.h * params.d_rounded, _1{}));
    Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
                                Shape<Int<kBlockM>>{}, Stride<_1>{});

    typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO;
    auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
    // TODO: careful, we're zeroing out dQaccum with type float4, but when
    // we do atomicAdds, we use type float. The layouts are different. Check this.
    typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum;
    auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);

    Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO);
    Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO);
    Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);

    Tensor cdO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor tdOcdO = gmem_thr_copy_dO.partition_S(cdO);

    // Allocate predicate tensors for k
    Tensor tdOpdO = make_tensor<bool>(make_shape(size<2>(tdOgdO)));
    // Set predicates for k bounds
    #pragma unroll
    for (int k = 0; k < size(tdOpdO); ++k) { tdOpdO(k) = get<1>(tdOcdO(0, 0, k)) < params.d; }

    Tensor tdOrdO = make_fragment_like(tdOgdO);
    Tensor tdOrO = make_fragment_like(tdOgO);
    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
        gmem_tiled_copy_dO, tdOgdO, tdOrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
    );
    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
        gmem_tiled_copy_dO, tdOgO, tdOrO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
    );
    // Strictly speaking, we should scale dP up by 1/p_dropout; instead we leave it as is and only
    // scale the final results (dQ and dK) by 1/p_dropout. So we keep dP_sum scaled down by
    // p_dropout here, so that (dP - dP_sum) stays on the same scale.
    dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, dP_sum,
                                                // Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
                                                Kernel_traits::kNThreadsNonWS / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
    if (Clear_dQaccum) {
        // We're not actually zeroing out all of dQaccum, only the part that we're going to
        // do atomicAdds on.
        Tensor zero = make_fragment_like(tdQgdQaccum);
        clear(zero);
        cute::copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum);
    }
}
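// Illustrative only (not part of this header): compute_dot_do_o is a __device__ helper, so the
// launch code elsewhere in the repo is expected to wrap it in a thin __global__ kernel. A minimal
// sketch, with the wrapper name and launch parameters chosen here purely for illustration
// (kNThreadsNonWS matches the gdP_col_stride passed to dot_do_o above):
//
//   template<bool Clear_dQaccum, typename Kernel_traits>
//   __global__ void flash_bwd_dot_do_o_kernel_sketch(Flash_bwd_params params) {
//       flash::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params);
//   }
//
//   // One block per kBlockM rows of queries, per batch element, per head:
//   dim3 grid(cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM), params.b, params.h);
//   flash_bwd_dot_do_o_kernel_sketch<true, Kernel_traits>
//       <<<grid, Kernel_traits::kNThreadsNonWS, 0, stream>>>(params);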
////////////////////////////////////////////////////////////////////////////////////////////////////

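// Zero out one (kBlockN x kHeadDim) tile of the float dK and dV accumulation buffers for a given
// (batch, KV head, seqlen_k block). From the indexing below: blockIdx.x selects the kBlockN block
// of keys, blockIdx.y the batch index, and blockIdx.z the KV head index (params.h_k).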
template<typename Kernel_traits, typename Params>
inline __device__ void clear_dKVaccum(const Params &params) {
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    const int n_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;
    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;

    const BlockInfo binfo(params, bidb);
    if (n_block * kBlockN >= binfo.actual_seqlen_k) return;

    const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + n_block * kBlockN) * params.d_rounded;

    Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
                                  Shape<Int<kBlockN>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
    Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
                                  Shape<Int<kBlockN>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});

    typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum;
    auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);
    Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum);
    Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum);
    Tensor zero = make_fragment_like(tdKgdKaccum);
    clear(zero);
    cute::copy(gmem_tiled_copy_dKVaccum, zero, tdKgdKaccum);
    cute::copy(gmem_tiled_copy_dKVaccum, zero, tdVgdVaccum);
}
////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert dQ from dQaccum (in float) to fp16/bf16.
// This is used in the case where we want to parallelize the backward across seqlen_k.
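// The (kBlockM x kHeadDim) dQaccum tile is loaded into shared memory with TMA (a transaction
// barrier tracks the expected bytes), scaled by scale_softmax_rp_dropout, converted to Element
// (fp16/bf16), staged back through shared memory using the dQ smem layout, and finally written to
// gmem with a predicated copy. Note that sdQTMA, sdQaccum, sdQ and sdQt all alias the same shared
// memory, hence the __syncthreads() between reading dQaccum out of smem and writing dQ into it.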
// template<typename Kernel_traits, typename Params, typename TiledCopydQaccum>
template<typename Kernel_traits, typename TiledCopydQaccum>
// inline __device__ void convert_dQ(const Params &params,
__global__ void convert_dQ(CUTE_GRID_CONSTANT Flash_bwd_params const params,
                           CUTE_GRID_CONSTANT TiledCopydQaccum const tma_load_dQaccum) {
    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    // Shared memory.
    extern __shared__ char smem_[];

    int lane_predicate = cute::elect_one_sync();
    int warp_idx = cutlass::canonical_warp_idx_sync();
    // Issue Tma Descriptor Prefetch from a single thread
    if (warp_idx == 0 && lane_predicate) {
        cute::prefetch_tma_descriptor(tma_load_dQaccum.get_tma_descriptor());
    }

    const int m_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;
    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;
    static constexpr bool dQ_swapAB = Kernel_traits::dQ_swapAB;

    Tensor mdQaccum = tma_load_dQaccum.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b));
    Tensor gdQaccum = local_tile(mdQaccum(_, _, bidh, bidb), Shape<Int<kBlockM>, Int<kHeadDim>>{}, make_coord(m_block, _0{}));  // (M, K)

    const BlockInfo binfo(params, bidb);
    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;

    const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
        + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
    const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
        + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;

    Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
                             make_stride(params.dq_row_stride, _1{}));
    // Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
    //                               Shape<Int<kBlockM>, Int<kHeadDim>>{},
    //                               make_stride(params.h * params.d_rounded, _1{}));
    Tensor sdQTMA = make_tensor(make_smem_ptr(reinterpret_cast<ElementAccum *>(smem_)),
                                typename Kernel_traits::SmemLayoutdQaccTMA{});
    Tensor sdQaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementAccum *>(smem_)),
                                  typename Kernel_traits::SmemLayoutdQacc{});
    Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                             typename Kernel_traits::SmemLayoutdQ{});
    Tensor sdQt = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                              typename Kernel_traits::SmemLayoutdQt{});

    auto &barrier_dQaccum = *reinterpret_cast<cutlass::arch::ClusterTransactionBarrier*>(smem_ + sizeof(ElementAccum) * size(sdQTMA));

    auto block_tma_dQ = tma_load_dQaccum.get_slice(_0{});
    Tensor tdQgdQaccumTMA = block_tma_dQ.partition_S(gdQaccum);  // (TMA, TMA_M, TMA_K)
    Tensor tdQsdQaccumTMA = block_tma_dQ.partition_D(sdQTMA);  // (TMA, TMA_M, TMA_K)

    typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
    auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
    // typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dQaccum;
    // typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum;
    // auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
    typename Kernel_traits::TiledMmadQ tiled_mma_dq;
    auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq);
    auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx);
    Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ);    // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
    // Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum);
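    // The transaction barrier is initialized by a single thread and armed with the number of bytes
    // the TMA load is expected to deposit: size<0>(sdQTMA) * size<1>(sdQTMA) elements of
    // ElementAccum (float), i.e. the dQaccum tile staged in shared memory. barrier_dQaccum.wait(0)
    // then blocks until the TMA copy has delivered that many bytes (phase 0).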
    constexpr uint32_t TmaTransactionBytesdQaccum = static_cast<uint32_t>(size<0>(sdQTMA) * size<1>(sdQTMA) * cutlass::sizeof_bits_v<ElementAccum> / 8);
    if (warp_idx == 0 && lane_predicate) {
        barrier_dQaccum.init(1 /*numThreads*/);
    }
    __syncthreads();
    if (warp_idx == 0 && lane_predicate) {
        barrier_dQaccum.arrive_and_expect_tx(TmaTransactionBytesdQaccum);
        copy(tma_load_dQaccum.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(barrier_dQaccum), 0 /*mcast_mask*/), tdQgdQaccumTMA, tdQsdQaccumTMA);
    }
    barrier_dQaccum.wait(0);
    // if (cute::thread0()) { print_tensor(sdQTMA); printf("\n"); }

    typename Kernel_traits::RmemTiledCopydQacc rmem_tiled_copy_dQaccum;
    auto rmem_thr_copy_dQaccum = rmem_tiled_copy_dQaccum.get_thread_slice(threadIdx.x);
    Tensor tdQsdQaccum = rmem_thr_copy_dQaccum.partition_S(sdQaccum);

    Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<!dQ_swapAB ? kBlockM : kHeadDim>, Int<!dQ_swapAB ? kHeadDim : kBlockM>>{});  // MMA, MMA_N, MMA_K
    CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQsdQaccum));

    Tensor tdQrdQaccum = rmem_thr_copy_dQaccum.retile_D(acc_dq);
    cute::copy(rmem_tiled_copy_dQaccum, tdQsdQaccum, tdQrdQaccum);
    // Tensor dQ_tmp = make_tensor(acc_dq.data(), flash::convert_layout_acc_rowcol(acc_dq.layout()));
    // if (blockIdx.x == 0 && threadIdx.x == 0) { print_tensor(dQ_tmp); printf("\n"); }
    #pragma unroll
    for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; }
    // Convert acc_dq from fp32 to fp16
    Tensor rdQ = flash::convert_type<Element>(acc_dq);
    Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);  // ((Atom,AtomNum), MMA_N, MMA_N)
    // dQaccum and dQ use the same shared memory, so we need to wait for all threads to finish
    // reading smem before overwriting it.
    __syncthreads();
    if constexpr (!dQ_swapAB) {
        Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
        cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
    } else {
        Tensor taccdQsdQt = smem_thr_copy_dQ.partition_D(sdQt);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
        cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQt);
    }
    __syncthreads();
    Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));
    cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);

    Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
    Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));
    #pragma unroll
    for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; }
    // Clear_OOB_K must be false since we don't want to write zeros to gmem
    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM
    );
}
////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert dK and dV from dKaccum and dVaccum (in float) to fp16/bf16.
// This is used in the case where we want to parallelize the backward across seqlen_q.
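// Unlike convert_dQ above, this path reads the float accumulators directly from gmem with a tiled
// copy (no TMA): dK is scaled by scale_softmax_rp_dropout and dV by rp_dropout, both are converted
// to Element (fp16/bf16), staged through shared memory, and written out with predicated copies.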
template<typename Kernel_traits, typename Params>
inline __device__ void convert_dKV(const Params &params) {
    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    // Shared memory.
    extern __shared__ char smem_[];

    const int n_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;
    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;
    static constexpr bool dKV_swapAB = Kernel_traits::dKV_swapAB;

    const BlockInfo binfo(params, bidb);
    if (n_block * kBlockN >= binfo.actual_seqlen_k) return;

    const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
        + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
    const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
        + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
    const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded
        + n_block * kBlockN) * params.d_rounded;

    Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
                             make_stride(params.dk_row_stride, _1{}));
    Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
                             make_stride(params.dv_row_stride, _1{}));
    Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
                                  Shape<Int<kBlockN>, Int<kHeadDim>>{},
                                  Stride<Int<kHeadDim>, _1>{});
    Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
                                  Shape<Int<kBlockN>, Int<kHeadDim>>{},
                                  Stride<Int<kHeadDim>, _1>{});

    Tensor sdK = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                             typename Kernel_traits::SmemLayoutdKV{});
    Tensor sdKt = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                              typename Kernel_traits::SmemLayoutdKVt{});
    Tensor sdV = make_tensor(sdK.data() + size(sdK),
                             typename Kernel_traits::SmemLayoutdKV{});  // (SMEM_N, SMEM_K)
    Tensor sdVt = make_tensor(make_smem_ptr(sdK.data() + size(sdK)),
                              typename Kernel_traits::SmemLayoutdKVt{});

    typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dKV;
    auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
    // typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dKVaccum;
    typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum;
    auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);
    typename Kernel_traits::TiledMmadKV tiled_mma_dkv;
    auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv);
    auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx);

    Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK);  // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
    Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV);  // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
    Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_S(gdKaccum);
    Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_S(gdVaccum);

    Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<!dKV_swapAB ? kBlockN : kHeadDim>, Int<!dKV_swapAB ? kHeadDim : kBlockN>>{});  // MMA, MMA_N, MMA_K
    Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<!dKV_swapAB ? kBlockN : kHeadDim>, Int<!dKV_swapAB ? kHeadDim : kBlockN>>{});  // MMA, MMA_N, MMA_K
    CUTE_STATIC_ASSERT_V(size(acc_dk) == size(tdKgdKaccum));
    CUTE_STATIC_ASSERT_V(size(acc_dv) == size(tdVgdVaccum));

    Tensor tdKrdKaccum = make_fragment_like(tdKgdKaccum);
    Tensor tdVrdVaccum = make_fragment_like(tdVgdVaccum);
    cute::copy(gmem_tiled_copy_dKVaccum, tdKgdKaccum, tdKrdKaccum);
    cute::copy(gmem_tiled_copy_dKVaccum, tdVgdVaccum, tdVrdVaccum);
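    // Scaling note: dK is multiplied by scale_softmax_rp_dropout (the softmax scale folded together
    // with 1/(1 - p_dropout)), presumably deferring the softmax scaling to this single final pass,
    // while dV only needs the dropout correction rp_dropout.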
    #pragma unroll
    for (int i = 0; i < size(acc_dk); ++i) {
        acc_dk(i) = tdKrdKaccum(i) * params.scale_softmax_rp_dropout;
    }
    #pragma unroll
    for (int i = 0; i < size(acc_dv); ++i) {
        acc_dv(i) = tdVrdVaccum(i) * params.rp_dropout;
    }
    // Convert acc_dk from fp32 to fp16
    Tensor rdK = flash::convert_type<Element>(acc_dk);
    Tensor rdV = flash::convert_type<Element>(acc_dv);
    Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK);  // ((Atom,AtomNum), MMA_N, MMA_N)
    Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV);  // ((Atom,AtomNum), MMA_N, MMA_N)
    if constexpr (!dKV_swapAB) {
        Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
        Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
        cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
        cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
    } else {
        Tensor taccdKsdKt = smem_thr_copy_dKV.partition_D(sdKt);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
        Tensor taccdVsdVt = smem_thr_copy_dKV.partition_D(sdVt);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
        cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdKt);
        cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdVt);
    }
    __syncthreads();
    Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
    Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
    cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK);
    cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV);
    // if (cute::thread0()) { print_tensor(tdKrdK); printf("\n"); }
    // if (cute::thread0()) { print_tensor(tdVrdV); printf("\n"); }

    Tensor cdKV = make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
    Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
    #pragma unroll
    for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
    // Clear_OOB_K must be false since we don't want to write zeros to gmem
    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
    );
    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
    );
}

} // namespace flash