/***************************************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cute/tensor.hpp>

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include "cutlass/pipeline/pipeline.hpp"

#include "flash.h"
#include "block_info.h"
#include "kernel_traits.h"
#include "utils.h"

namespace flash {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int THREADS_PER_ROW, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engine0, Layout0> const &o,
                                Tensor<Engine1, Layout1> &dP_sum, const int gdP_col_stride, const float scale) {
    static_assert(Layout0::rank == 3, "Only support 3D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(do_.layout() == o.layout());
    // Reshape do_ and o from (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, 8 * kHeadDim / 64)
    // The last coordinate is the "page".
    Tensor do_reshaped = make_tensor(do_.data(), make_layout(get<1>(do_.layout()),
                                                             make_layout(get<0>(do_.layout()),
                                                                         get<2>(do_.layout()))));
    Tensor o_reshaped = make_tensor(o.data(), do_reshaped.layout());
    Tensor do_fp32 = flash::convert_type<float>(do_reshaped);
    Tensor o_fp32 = flash::convert_type<float>(o_reshaped);
    #pragma unroll
    for (int mi = 0; mi < size<0>(do_reshaped); ++mi) {
        float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0);
        #pragma unroll
        for (int ni = 1; ni < size<1>(do_reshaped); ni++) {
            dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni);
        }
        flash::SumOp<float> sum_op;
        dP_sum_cur = flash::Allreduce<THREADS_PER_ROW>::run(dP_sum_cur, sum_op) * scale;
        if (threadIdx.x % THREADS_PER_ROW == 0) {
            dP_sum(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum_cur;
        }
    }
}
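
// For reference, per query row i the reduction above produces
//     softmax_d[i] = scale * sum_j dO[i][j] * O[i][j],
// with the sum taken across the THREADS_PER_ROW threads that own row i.
// A scalar sketch of the same quantity (illustrative only; `rows`, `cols`,
// `dO`, `O`, and `softmax_d` are hypothetical names, not part of this file):
//
//     for (int i = 0; i < rows; ++i) {
//         float acc = 0.f;
//         for (int j = 0; j < cols; ++j) { acc += float(dO[i][j]) * float(O[i][j]); }
//         softmax_d[i] = acc * scale;
//     }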

////////////////////////////////////////////////////////////////////////////////////////////////////

// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.
// This is used in the case where we want to parallelize the backward across seqlen_k.
template <bool Clear_dQaccum=true, typename Kernel_traits, typename Params>
inline __device__ void compute_dot_do_o(const Params &params) {
    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    const int m_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;
    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;

    const BlockInfo binfo(params, bidb);
    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;

    const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
        + m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
    const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
        + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
    const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
        + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
    const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM;

    Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
                             make_stride(params.do_row_stride, _1{}));
    Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
                            Shape<Int<kBlockM>, Int<kHeadDim>>{},
                            make_stride(params.o_row_stride, _1{}));
    Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},
                                  make_stride(params.h * params.d_rounded, _1{}));
    Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
                                Shape<Int<kBlockM>>{}, Stride<_1>{});

    typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO;
    auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
    // TODO: careful, we're zeroing out dQaccum with type float4, but when
    // we do atomicAdds, we use type float. The layouts are different. Check this.
    typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum;
    auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);

    Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO);
    Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO);
    Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);

    Tensor cdO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor tdOcdO = gmem_thr_copy_dO.partition_S(cdO);

    // Allocate predicate tensors for k
    Tensor tdOpdO = make_tensor<bool>(make_shape(size<2>(tdOgdO)));
    // Set predicates for k bounds
    #pragma unroll
    for (int k = 0; k < size(tdOpdO); ++k) { tdOpdO(k) = get<1>(tdOcdO(0, 0, k)) < params.d; }

    Tensor tdOrdO = make_fragment_like(tdOgdO);
    Tensor tdOrO = make_fragment_like(tdOgO);
    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
        gmem_tiled_copy_dO, tdOgdO, tdOrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
    );
    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
        gmem_tiled_copy_dO, tdOgO, tdOrO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
    );
    // By right we need to scale dP up by 1/p_dropout, but instead we don't and only scale the final
    // results (dQ and dK) by 1/p_dropout. So we need to keep dP_sum scaled down by p_dropout here,
    // so that (dP - dP_sum) is on the same scale.
    dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, dP_sum,
                                                // Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
                                                Kernel_traits::kNThreadsNonWS / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
    if (Clear_dQaccum) {
        // We're actually not zero'ing out all of dQaccum, but only the part that we're going to
        // do atomicAdds on.
        Tensor zero = make_fragment_like(tdQgdQaccum);
        clear(zero);
        cute::copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum);
    }
}
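
// How this device function gets launched is defined elsewhere in the repo (the bwd launch
// template); the sketch below only illustrates the intended grid mapping
// (seqlen_q blocks x batch x heads) and is an assumption, not the actual launcher:
//
//     template <typename Kernel_traits>
//     __global__ void dot_do_o_kernel(CUTE_GRID_CONSTANT Flash_bwd_params const params) {
//         flash::compute_dot_do_o</*Clear_dQaccum=*/true, Kernel_traits>(params);
//     }
//     // dim3 grid(cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM), params.b, params.h);
//     // dot_do_o_kernel<Kernel_traits><<<grid, Kernel_traits::kNThreadsNonWS, 0, stream>>>(params);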

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Kernel_traits, typename Params>
inline __device__ void clear_dKVaccum(const Params &params) {
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    const int n_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;
    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;

    const BlockInfo binfo(params, bidb);
    if (n_block * kBlockN >= binfo.actual_seqlen_k) return;

    const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + n_block * kBlockN) * params.d_rounded;

    Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
                                  Shape<Int<kBlockN>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
    Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
                                  Shape<Int<kBlockN>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});

    typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum;
    auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);
    Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum);
    Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum);
    Tensor zero = make_fragment_like(tdKgdKaccum);
    clear(zero);
    cute::copy(gmem_tiled_copy_dKVaccum, zero, tdKgdKaccum);
    cute::copy(gmem_tiled_copy_dKVaccum, zero, tdVgdVaccum);
}
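
// Note: dk_accum_ptr / dv_accum_ptr are fp32 scratch buffers that the seqlen_q-parallel
// backward path accumulates into (e.g. via atomicAdds). As with dQaccum above, only the
// tiles that will actually be accumulated need to be zeroed, which is what the tiled copy
// of a cleared fragment does here.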

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert dQ from dQaccum (in float) to fp16/bf16.
// This is used in the case where we want to parallelize the backward across seqlen_k.
// template<typename Kernel_traits, typename Params>
template <typename Kernel_traits, typename TiledCopydQaccum>
// inline __device__ void convert_dQ(const Params &params,
__global__ void convert_dQ(CUTE_GRID_CONSTANT Flash_bwd_params const params,
                           CUTE_GRID_CONSTANT TiledCopydQaccum const tma_load_dQaccum) {
    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    // Shared memory.
    extern __shared__ char smem_[];

    int lane_predicate = cute::elect_one_sync();
    int warp_idx = cutlass::canonical_warp_idx_sync();
    // Issue Tma Descriptor Prefetch from a single thread
    if (warp_idx == 0 && lane_predicate) {
        cute::prefetch_tma_descriptor(tma_load_dQaccum.get_tma_descriptor());
    }

    const int m_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;
    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;
    static constexpr bool dQ_swapAB = Kernel_traits::dQ_swapAB;

    Tensor mdQaccum = tma_load_dQaccum.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b));
    Tensor gdQaccum = local_tile(mdQaccum(_, _, bidh, bidb), Shape<Int<kBlockM>, Int<kHeadDim>>{}, make_coord(m_block, _0{}));  // (M, K)

    const BlockInfo binfo(params, bidb);
    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;

    const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
        + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
    const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
        + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;

    Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
                             make_stride(params.dq_row_stride, _1{}));
    // Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
    //                               Shape<Int<kBlockM>, Int<kHeadDim>>{},
    //                               make_stride(params.h * params.d_rounded, _1{}));
    Tensor sdQTMA = make_tensor(make_smem_ptr(reinterpret_cast<ElementAccum *>(smem_)), typename Kernel_traits::SmemLayoutdQaccTMA{});
    Tensor sdQaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementAccum *>(smem_)), typename Kernel_traits::SmemLayoutdQacc{});
    Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutdQ{});
    Tensor sdQt = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutdQt{});
    auto &barrier_dQaccum = *reinterpret_cast<cutlass::arch::ClusterTransactionBarrier *>(smem_ + sizeof(ElementAccum) * size(sdQTMA));

    auto block_tma_dQ = tma_load_dQaccum.get_slice(_0{});
    Tensor tdQgdQaccumTMA = block_tma_dQ.partition_S(gdQaccum);  // (TMA, TMA_M, TMA_K)
    Tensor tdQsdQaccumTMA = block_tma_dQ.partition_D(sdQTMA);    // (TMA, TMA_M, TMA_K)

    typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
    auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
    // typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dQaccum;
    // typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum;
    // auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);

    typename Kernel_traits::TiledMmadQ tiled_mma_dq;
    auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq);
    auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx);
    Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ);    // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
    // Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum);

    constexpr uint32_t TmaTransactionBytesdQaccum = static_cast<uint32_t>(size<0>(sdQTMA) * size<1>(sdQTMA) * cutlass::sizeof_bits_v<ElementAccum> / 8);
    if (warp_idx == 0 && lane_predicate) {
        barrier_dQaccum.init(1 /*numThreads*/);
    }
    __syncthreads();
    if (warp_idx == 0 && lane_predicate) {
        barrier_dQaccum.arrive_and_expect_tx(TmaTransactionBytesdQaccum);
        copy(tma_load_dQaccum.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType &>(barrier_dQaccum), 0 /*mcast_mask*/), tdQgdQaccumTMA, tdQsdQaccumTMA);
    }
    barrier_dQaccum.wait(0);
    // if (cute::thread0()) { print_tensor(sdQTMA); printf("\n"); }

    typename Kernel_traits::RmemTiledCopydQacc rmem_tiled_copy_dQaccum;
    auto rmem_thr_copy_dQaccum = rmem_tiled_copy_dQaccum.get_thread_slice(threadIdx.x);
    Tensor tdQsdQaccum = rmem_thr_copy_dQaccum.partition_S(sdQaccum);

    Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // MMA, MMA_N, MMA_K
    CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQsdQaccum));

    Tensor tdQrdQaccum = rmem_thr_copy_dQaccum.retile_D(acc_dq);
    cute::copy(rmem_tiled_copy_dQaccum, tdQsdQaccum, tdQrdQaccum);
    // Tensor dQ_tmp = make_tensor(acc_dq.data(), flash::convert_layout_acc_rowcol(acc_dq.layout()));
    // if (blockIdx.x == 0 && threadIdx.x == 0) { print_tensor(dQ_tmp); printf("\n"); }
    #pragma unroll
    for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; }
    // Convert acc_dq from fp32 to fp16
    Tensor rdQ = flash::convert_type<Element>(acc_dq);
    Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);  // ((Atom,AtomNum), MMA_N, MMA_N)
    // dQacc and dQ uses the same shared memory, need to wait for all threads to finish reading smem first
    __syncthreads();
    if constexpr (!dQ_swapAB) {
        Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
        cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
    } else {
        Tensor taccdQsdQt = smem_thr_copy_dQ.partition_D(sdQt);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
        cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQt);
    }
    __syncthreads();
    Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));
    cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);

    Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
    Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));
    #pragma unroll
    for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; }
    // Clear_OOB_K must be false since we don't want to write zeros to gmem
    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM
    );
}
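
// Notes on the epilogue above (informational only):
//   * The fp32 dQaccum tile is brought into shared memory with the usual CUTLASS
//     mbarrier + TMA handshake: one elected thread announces the expected transaction
//     bytes and issues the TMA load, then every thread waits on the barrier before
//     touching smem. Schematically (names are placeholders, not the exact code):
//
//         if (leader) {
//             barrier.arrive_and_expect_tx(bytes);                 // bytes the TMA will deposit
//             copy(tma.with(barrier_value, /*mcast_mask=*/0), gmem_tile, smem_tile);
//         }
//         barrier.wait(/*phase=*/0);                               // blocks until the load lands
//
//   * params.scale_softmax_rp_dropout is expected to fold together the softmax scale and the
//     1/(1 - dropout) rescale that was deliberately deferred in compute_dot_do_o above.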

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert dK and dV from dKaccum and dVaccum (in float) to fp16/bf16.
// This is used in the case where we want to parallelize the backward across seqlen_q.
template <typename Kernel_traits, typename Params>
inline __device__ void convert_dKV(const Params &params) {
    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    // Shared memory.
    extern __shared__ char smem_[];

    const int n_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;
    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;
    static constexpr bool dKV_swapAB = Kernel_traits::dKV_swapAB;

    const BlockInfo binfo(params, bidb);
    if (n_block * kBlockN >= binfo.actual_seqlen_k) return;

    const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
        + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
    const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
        + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
    const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + n_block * kBlockN) * params.d_rounded;

    Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
                             make_stride(params.dk_row_stride, _1{}));
    Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
                             make_stride(params.dv_row_stride, _1{}));
    Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
                                  Shape<Int<kBlockN>, Int<kHeadDim>>{},
                                  Stride<Int<kHeadDim>, _1>{});
    Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
                                  Shape<Int<kBlockN>, Int<kHeadDim>>{},
                                  Stride<Int<kHeadDim>, _1>{});

    Tensor sdK = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutdKV{});
    Tensor sdKt = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutdKVt{});
    Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{});  // (SMEM_N, SMEM_K)
    Tensor sdVt = make_tensor(make_smem_ptr(sdK.data() + size(sdK)), typename Kernel_traits::SmemLayoutdKVt{});

    typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dKV;
    auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
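    // The fp32 dK/dV accumulators (written by the seqlen_q-parallel backward kernel) are read
    // back with the same tiled copy type used for dQaccum; each thread then rescales its own
    // fragment and converts it to the output element type before staging through smem.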
    // typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dKVaccum;
    typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum;
    auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);

    typename Kernel_traits::TiledMmadKV tiled_mma_dkv;
    auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv);
    auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx);
    Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK);   // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
    Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV);   // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
    Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_S(gdKaccum);
    Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_S(gdVaccum);

    Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});  // MMA, MMA_N, MMA_K
    Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});  // MMA, MMA_N, MMA_K
    CUTE_STATIC_ASSERT_V(size(acc_dk) == size(tdKgdKaccum));
    CUTE_STATIC_ASSERT_V(size(acc_dv) == size(tdVgdVaccum));

    Tensor tdKrdKaccum = make_fragment_like(tdKgdKaccum);
    Tensor tdVrdVaccum = make_fragment_like(tdVgdVaccum);
    cute::copy(gmem_tiled_copy_dKVaccum, tdKgdKaccum, tdKrdKaccum);
    cute::copy(gmem_tiled_copy_dKVaccum, tdVgdVaccum, tdVrdVaccum);
    #pragma unroll
    for (int i = 0; i < size(acc_dk); ++i) { acc_dk(i) = tdKrdKaccum(i) * params.scale_softmax_rp_dropout; }
    #pragma unroll
    for (int i = 0; i < size(acc_dv); ++i) { acc_dv(i) = tdVrdVaccum(i) * params.rp_dropout; }
    // Convert acc_dk from fp32 to fp16
    Tensor rdK = flash::convert_type<Element>(acc_dk);
    Tensor rdV = flash::convert_type<Element>(acc_dv);
    Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK);  // ((Atom,AtomNum), MMA_N, MMA_N)
    Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV);  // ((Atom,AtomNum), MMA_N, MMA_N)
    if constexpr (!dKV_swapAB) {
        Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
        Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
        cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
        cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
    } else {
        Tensor taccdKsdKt = smem_thr_copy_dKV.partition_D(sdKt);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
        Tensor taccdVsdVt = smem_thr_copy_dKV.partition_D(sdVt);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
        cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdKt);
        cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdVt);
    }
    __syncthreads();
    Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
    Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
    cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK);
    cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV);
    // if (cute::thread0()) { print_tensor(tdKrdK); printf("\n"); }
    // if (cute::thread0()) { print_tensor(tdVrdV); printf("\n"); }

    Tensor cdKV = make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDim>>{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
    Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
    #pragma unroll
    for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
    // Clear_OOB_K must be false since we don't want to write zeros to gmem
    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
    );
    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
    );
}

}  // namespace flash