/***************************************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cute/tensor.hpp>

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include "cutlass/pipeline/pipeline.hpp"

#include "flash.h"
#include "block_info.h"
#include "kernel_traits.h"
#include "utils.h"

namespace flash {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////
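// Computes dP_sum(row) = sum_j dO(row, j) * O(row, j), i.e. the row-wise dot product of dO and O.
// In the FlashAttention backward derivation this is D = rowsum(dO * O) (elementwise product), which
// is later used to form dS = P * (dP - D), so dP_sum holds one scalar per query row of the block.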
template <int THREADS_PER_ROW, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engine0, Layout0> const &o,
                                Tensor<Engine1, Layout1> &dP_sum, const int gdP_col_stride, const float scale) {
    static_assert(Layout0::rank == 3, "Only support 3D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(do_.layout() == o.layout());
    // Reshape do_ and o from (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, 8 * kHeadDim / 64)
    // The last coordinate is the "page".
    Tensor do_reshaped = make_tensor(do_.data(), make_layout(get<1>(do_.layout()),
                                                             make_layout(get<0>(do_.layout()),
                                                                         get<2>(do_.layout()))));
    Tensor o_reshaped = make_tensor(o.data(), do_reshaped.layout());
    Tensor do_fp32 = flash::convert_type<float>(do_reshaped);
    Tensor o_fp32 = flash::convert_type<float>(o_reshaped);
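    // Each thread accumulates a partial dot product for its rows, reduces it across the
    // THREADS_PER_ROW threads that share a row, and the first thread of each row group writes
    // the (scaled) result to dP_sum.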
    #pragma unroll
    for (int mi = 0; mi < size<0>(do_reshaped); ++mi) {
        float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0);
        #pragma unroll
        for (int ni = 1; ni < size<1>(do_reshaped); ni++) {
            dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni);
        }
        flash::SumOp<float> sum_op;
        dP_sum_cur = flash::Allreduce<THREADS_PER_ROW>::run(dP_sum_cur, sum_op) * scale;
        if (threadIdx.x % THREADS_PER_ROW == 0) {
            dP_sum(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum_cur;
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.
// This is used in the case where we want to parallelize the backward across seqlen_k.
template<bool Clear_dQaccum=true, typename Kernel_traits, typename Params>
inline __device__ void compute_dot_do_o(const Params &params) {
    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    const int m_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;
    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;

    const BlockInfo binfo(params, bidb);
    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;

    const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
        + m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
    const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
        + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
    const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
        + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
    const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM;
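    // Global-memory views of this block's dO, O, and dQaccum tiles, and of its dP_sum row.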
    Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
                             make_stride(params.do_row_stride, _1{}));
    Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
                            Shape<Int<kBlockM>, Int<kHeadDim>>{},
                            make_stride(params.o_row_stride, _1{}));
    Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},
                                  make_stride(params.h * params.d_rounded, _1{}));
    Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
                                Shape<Int<kBlockM>>{}, Stride<_1>{});

    typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO;
    auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
    // TODO: careful, we're zeroing out dQaccum with type float4, but when
    // we do atomicAdds, we use type float. The layouts are different. Check this.
    typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum;
    auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);

    Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO);
    Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO);
    Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);

    Tensor cdO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor tdOcdO = gmem_thr_copy_dO.partition_S(cdO);

    // Allocate predicate tensors for k
    Tensor tdOpdO = make_tensor<bool>(make_shape(size<2>(tdOgdO)));
    // Set predicates for k bounds
    #pragma unroll
    for (int k = 0; k < size(tdOpdO); ++k) { tdOpdO(k) = get<1>(tdOcdO(0, 0, k)) < params.d; }
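    // tdOpdO(k) is true iff column k of this thread's copy partition lies within the actual head
    // dimension params.d, so accesses past the (unpadded) head dim are masked out.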
    Tensor tdOrdO = make_fragment_like(tdOgdO);
    Tensor tdOrO = make_fragment_like(tdOgO);
    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
        gmem_tiled_copy_dO, tdOgdO, tdOrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
    );
    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
        gmem_tiled_copy_dO, tdOgO, tdOrO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
    );
    // In principle we should scale dP up by 1/p_dropout, but instead we don't and only scale the
    // final results (dQ and dK) by 1/p_dropout. So we need to keep dP_sum scaled down by p_dropout
    // here, so that (dP - dP_sum) stays on the same scale.
    dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, dP_sum,
                                                // Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
                                                Kernel_traits::kNThreadsNonWS / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);

    if (Clear_dQaccum) {
        // We're not actually zeroing out all of dQaccum, only the part that we're going to
        // do atomicAdds on.
        Tensor zero = make_fragment_like(tdQgdQaccum);
        clear(zero);
        cute::copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum);
    }
}
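// Launch sketch (illustrative only; the real host-side launcher and kernel wrapper live elsewhere
// and may differ). The blockIdx mapping above implies a grid of
//     dim3 grid(cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM), params.b, params.h);
// with Kernel_traits::kNThreadsNonWS threads per block, where a thin __global__ wrapper simply
// calls compute_dot_do_o<true, Kernel_traits>(params).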
////////////////////////////////////////////////////////////////////////////////////////////////////
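// Zero out the fp32 dK and dV accumulators for one kBlockN tile.
// This is used in the case where we want to parallelize the backward across seqlen_q.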
template<typename Kernel_traits, typename Params>
inline __device__ void clear_dKVaccum(const Params &params) {
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    const int n_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;
    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;

    const BlockInfo binfo(params, bidb);
    if (n_block * kBlockN >= binfo.actual_seqlen_k) return;

    const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + n_block * kBlockN) * params.d_rounded;

    Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
                                  Shape<Int<kBlockN>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
    Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
                                  Shape<Int<kBlockN>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});

    typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum;
    auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);
    Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum);
    Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum);
    Tensor zero = make_fragment_like(tdKgdKaccum);
    clear(zero);
    cute::copy(gmem_tiled_copy_dKVaccum, zero, tdKgdKaccum);
    cute::copy(gmem_tiled_copy_dKVaccum, zero, tdVgdVaccum);
}
////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert dQ from dQaccum (in float) to fp16/bf16.
// This is used in the case where we want to parallelize the backward across seqlen_k.
// template<typename Kernel_traits, typename Params, typename TiledCopydQaccum>
template<typename Kernel_traits, typename TiledCopydQaccum>
// inline __device__ void convert_dQ(const Params &params,
__global__ void convert_dQ(CUTE_GRID_CONSTANT Flash_bwd_params const params,
                           CUTE_GRID_CONSTANT TiledCopydQaccum const tma_load_dQaccum) {
    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    // Shared memory.
    extern __shared__ char smem_[];

    int lane_predicate = cute::elect_one_sync();
    int warp_idx = cutlass::canonical_warp_idx_sync();
    // Issue Tma Descriptor Prefetch from a single thread
    if (warp_idx == 0 && lane_predicate) {
        cute::prefetch_tma_descriptor(tma_load_dQaccum.get_tma_descriptor());
    }

    const int m_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;
    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;
    static constexpr bool dQ_swapAB = Kernel_traits::dQ_swapAB;

    Tensor mdQaccum = tma_load_dQaccum.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b));
    Tensor gdQaccum = local_tile(mdQaccum(_, _, bidh, bidb), Shape<Int<kBlockM>, Int<kHeadDim>>{}, make_coord(m_block, _0{}));  // (M, K)

    const BlockInfo binfo(params, bidb);
    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;

    const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
        + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
    const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
        + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;

    Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
                             make_stride(params.dq_row_stride, _1{}));
    // Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
    //                               Shape<Int<kBlockM>, Int<kHeadDim>>{},
    //                               make_stride(params.h * params.d_rounded, _1{}));
    Tensor sdQTMA = make_tensor(make_smem_ptr(reinterpret_cast<ElementAccum *>(smem_)),
                                typename Kernel_traits::SmemLayoutdQaccTMA{});
    Tensor sdQaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementAccum *>(smem_)),
                                  typename Kernel_traits::SmemLayoutdQacc{});
    Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                             typename Kernel_traits::SmemLayoutdQ{});
    Tensor sdQt = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                              typename Kernel_traits::SmemLayoutdQt{});
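    // All four views above alias the same shared-memory buffer: sdQTMA / sdQaccum hold the fp32
    // accumulator tile loaded via TMA, while sdQ / sdQt later reuse the space for the converted
    // fp16/bf16 output (hence the __syncthreads() before the conversion result is stored).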
    auto &barrier_dQaccum = *reinterpret_cast<cutlass::arch::ClusterTransactionBarrier*>(smem_ + sizeof(ElementAccum) * size(sdQTMA));

    auto block_tma_dQ = tma_load_dQaccum.get_slice(_0{});
    Tensor tdQgdQaccumTMA = block_tma_dQ.partition_S(gdQaccum);  // (TMA, TMA_M, TMA_K)
    Tensor tdQsdQaccumTMA = block_tma_dQ.partition_D(sdQTMA);    // (TMA, TMA_M, TMA_K)

    typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
    auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
    // typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dQaccum;
    // typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum;
    // auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
    typename Kernel_traits::TiledMmadQ tiled_mma_dq;
    auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq);
    auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx);
    Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ);  // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
    // Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum);

    constexpr uint32_t TmaTransactionBytesdQaccum = static_cast<uint32_t>(size<0>(sdQTMA) * size<1>(sdQTMA) * cutlass::sizeof_bits_v<ElementAccum> / 8);
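    // One elected thread initializes the mbarrier, sets the expected transaction byte count, and
    // issues the TMA load of the dQaccum tile; every thread then waits on phase 0 of the barrier
    // until the whole tile has arrived in shared memory.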
    if (warp_idx == 0 && lane_predicate) {
        barrier_dQaccum.init(1 /*numThreads*/);
    }
    __syncthreads();
    if (warp_idx == 0 && lane_predicate) {
        barrier_dQaccum.arrive_and_expect_tx(TmaTransactionBytesdQaccum);
        copy(tma_load_dQaccum.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(barrier_dQaccum), 0 /*mcast_mask*/), tdQgdQaccumTMA, tdQsdQaccumTMA);
    }
    barrier_dQaccum.wait(0);
    // if (cute::thread0()) { print_tensor(sdQTMA); printf("\n"); }

    typename Kernel_traits::RmemTiledCopydQacc rmem_tiled_copy_dQaccum;
    auto rmem_thr_copy_dQaccum = rmem_tiled_copy_dQaccum.get_thread_slice(threadIdx.x);
    Tensor tdQsdQaccum = rmem_thr_copy_dQaccum.partition_S(sdQaccum);

    Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<!dQ_swapAB ? kBlockM : kHeadDim>, Int<!dQ_swapAB ? kHeadDim : kBlockM>>{});  // MMA, MMA_N, MMA_K
    CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQsdQaccum));
    Tensor tdQrdQaccum = rmem_thr_copy_dQaccum.retile_D(acc_dq);
    cute::copy(rmem_tiled_copy_dQaccum, tdQsdQaccum, tdQrdQaccum);
    // Tensor dQ_tmp = make_tensor(acc_dq.data(), flash::convert_layout_acc_rowcol(acc_dq.layout()));
    // if (blockIdx.x == 0 && threadIdx.x == 0) { print_tensor(dQ_tmp); printf("\n"); }
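    // Apply the scaling that was deferred in the main backward kernel (the softmax scale combined
    // with the dropout correction; see the dP_sum comment in compute_dot_do_o above).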
    #pragma unroll
    for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; }
    // Convert acc_dq from fp32 to fp16
    Tensor rdQ = flash::convert_type<Element>(acc_dq);
    Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);  // ((Atom,AtomNum), MMA_N, MMA_N)
    // dQaccum and dQ use the same shared memory, so wait for all threads to finish reading smem first.
    __syncthreads();
    if constexpr (!dQ_swapAB) {
        Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
        cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
    } else {
        Tensor taccdQsdQt = smem_thr_copy_dQ.partition_D(sdQt);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
        cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQt);
    }
    __syncthreads();
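    // Read the converted dQ tile back out of shared memory into registers, then write it to gmem
    // with predication on both the row count (actual_seqlen_q) and the head dimension (params.d).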
    Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));
    cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);

    Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
    Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));
    #pragma unroll
    for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; }
    // Clear_OOB_K must be false since we don't want to write zeros to gmem
    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM
    );
}
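// Note on launch (illustrative, not the definitive launcher): convert_dQ is a __global__ kernel and
// is expected to be launched over the same (m_block, batch, head) grid as compute_dot_do_o, with
// enough dynamic shared memory for the dQaccum TMA tile plus the transaction barrier placed right
// after it (see barrier_dQaccum above).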
////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert dK and dV from dKaccum and dVaccum (in float) to fp16/bf16.
// This is used in the case where we want to parallelize the backward across seqlen_q.
template<typename Kernel_traits, typename Params>
inline __device__ void convert_dKV(const Params &params) {
    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    // Shared memory.
    extern __shared__ char smem_[];

    const int n_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;
    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;
    static constexpr bool dKV_swapAB = Kernel_traits::dKV_swapAB;

    const BlockInfo binfo(params, bidb);
    if (n_block * kBlockN >= binfo.actual_seqlen_k) return;

    const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
        + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
    const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
        + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
    const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded
                                          + n_block * kBlockN) * params.d_rounded;

    Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
                             make_stride(params.dk_row_stride, _1{}));
    Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
                             make_stride(params.dv_row_stride, _1{}));
    Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
                                  Shape<Int<kBlockN>, Int<kHeadDim>>{},
                                  Stride<Int<kHeadDim>, _1>{});
    Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
                                  Shape<Int<kBlockN>, Int<kHeadDim>>{},
                                  Stride<Int<kHeadDim>, _1>{});

    Tensor sdK = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                             typename Kernel_traits::SmemLayoutdKV{});
    Tensor sdKt = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                              typename Kernel_traits::SmemLayoutdKVt{});
    Tensor sdV = make_tensor(sdK.data() + size(sdK),
                             typename Kernel_traits::SmemLayoutdKV{});  // (SMEM_N, SMEM_K)
    Tensor sdVt = make_tensor(make_smem_ptr(sdK.data() + size(sdK)),
                              typename Kernel_traits::SmemLayoutdKVt{});
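    // sdK and sdV occupy back-to-back regions of shared memory; sdKt / sdVt are transposed views of
    // the same buffers, used when dKV_swapAB is set (i.e. when the MMA accumulator is laid out as
    // (kHeadDim, kBlockN) rather than (kBlockN, kHeadDim)).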
    typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dKV;
    auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
    // typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dKVaccum;
    typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum;
    auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);
    typename Kernel_traits::TiledMmadKV tiled_mma_dkv;
    auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv);
    auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx);
    Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK);  // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
    Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV);  // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
    Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_S(gdKaccum);
    Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_S(gdVaccum);

    Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<!dKV_swapAB ? kBlockN : kHeadDim>, Int<!dKV_swapAB ? kHeadDim : kBlockN>>{});  // MMA, MMA_N, MMA_K
    Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<!dKV_swapAB ? kBlockN : kHeadDim>, Int<!dKV_swapAB ? kHeadDim : kBlockN>>{});  // MMA, MMA_N, MMA_K
    CUTE_STATIC_ASSERT_V(size(acc_dk) == size(tdKgdKaccum));
    CUTE_STATIC_ASSERT_V(size(acc_dv) == size(tdVgdVaccum));

    Tensor tdKrdKaccum = make_fragment_like(tdKgdKaccum);
    Tensor tdVrdVaccum = make_fragment_like(tdVgdVaccum);
    cute::copy(gmem_tiled_copy_dKVaccum, tdKgdKaccum, tdKrdKaccum);
    cute::copy(gmem_tiled_copy_dKVaccum, tdVgdVaccum, tdVrdVaccum);
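    // The fp32 dK/dV accumulators for this tile are loaded straight into registers; the reduction
    // across query blocks has already been performed by the main backward kernel, so this kernel
    // only rescales and down-converts them.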
    #pragma unroll
    for (int i = 0; i < size(acc_dk); ++i) {
        acc_dk(i) = tdKrdKaccum(i) * params.scale_softmax_rp_dropout;
    }
    #pragma unroll
    for (int i = 0; i < size(acc_dv); ++i) {
        acc_dv(i) = tdVrdVaccum(i) * params.rp_dropout;
    }
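    // dK picks up both the softmax scale and the dropout correction, while dV only needs the
    // dropout correction, since V is never multiplied by the softmax scale in the forward pass.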
    // Convert acc_dk from fp32 to fp16
    Tensor rdK = flash::convert_type<Element>(acc_dk);
    Tensor rdV = flash::convert_type<Element>(acc_dv);
    Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK);  // ((Atom,AtomNum), MMA_N, MMA_N)
    Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV);  // ((Atom,AtomNum), MMA_N, MMA_N)
    if constexpr (!dKV_swapAB) {
        Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
        Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
        cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
        cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
    } else {
        Tensor taccdKsdKt = smem_thr_copy_dKV.partition_D(sdKt);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
        Tensor taccdVsdVt = smem_thr_copy_dKV.partition_D(sdVt);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
        cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdKt);
        cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdVt);
    }
    __syncthreads();
    Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
    Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
    cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK);
    cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV);
    // if (cute::thread0()) { print_tensor(tdKrdK); printf("\n"); }
    // if (cute::thread0()) { print_tensor(tdVrdV); printf("\n"); }

    Tensor cdKV = make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDim>>{});  // (BLK_N,BLK_K) -> (blk_n,blk_k)
    Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
    Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
    #pragma unroll
    for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
    // Clear_OOB_K must be false since we don't want to write zeros to gmem
    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
    );
    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
    );
}

} // namespace flash