/****************************************************************************** * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. ******************************************************************************/ #pragma once #include "cute/tensor.hpp" #include #include #include #include #include "cutlass/arch/barrier.h" #include "utils.h" namespace flash { using namespace cute; template class FlashAttnBwdPostprocessConvertdQ { public: // Type Aliases using TileShape_MK = TileShape_MK_; using ArchTag = ArchTag_; static_assert(ArchTag::kMinComputeCapability >= 90); static constexpr uint32_t MaxThreadsPerBlock = kNThreads; static constexpr uint32_t MinBlocksPerMultiprocessor = 2; static constexpr int kHeadDim = get<1>(TileShape_MK{}); using R2SLayoutAtomdQaccum = Layout>, Stride<_1>>; using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom{}, R2SLayoutAtomdQaccum{}, Layout>{})); // Val layout, 4 vals per read static constexpr int SmemdQaccumSize = size(TileShape_MK{}); static_assert(size(TileShape_MK{}) == size(SmemLayoutdQaccumTMA{}), "TileShape_MK and SmemLayoutdQaccumTMA must have the same size"); using SmemLayoutdQaccum = Layout>, Stride<_1>>; // We can't just use kHeadDim here. E.g. if MMA shape is 64 x 96 but split across 2 WGs, // then setting kBlockKSmem to 32 will cause "Static shape_div failure". // We want to treat it as 64 x 48, so kBlockKSmem should be 16. static constexpr int MmaShapeN = get<1>(typename TiledMma::AtomShape_MNK{}); static constexpr int kBlockKSmem = MmaShapeN % 64 == 0 ? 64 : (MmaShapeN % 32 == 0 ? 32 : 16); static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 2 : 1); using SmemLayoutAtomdQ = decltype(composition(Swizzle{}, Layout, Int>, Stride, _1>>{})); using SmemLayoutdQ = decltype(tile_to_shape(SmemLayoutAtomdQ{}, TileShape_MK{})); using SmemLayoutdQt = decltype(cute::composition(SmemLayoutdQ{}, make_layout(make_shape(get<1>(TileShape_MK{}), get<0>(TileShape_MK{})), make_stride(Int(TileShape_MK{})>{}, _1{})))); using SmemCopyAtomdQ = Copy_Atom< std::conditional_t, Element>; static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, int(MaxThreadsPerBlock)); static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow"); using GmemLayoutAtom = Layout, Int>, Stride, _1>>; using GmemTiledCopy = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>>{})); // Val layout, 8 or 16 vals per load using GmemTiledCopydQaccum = cute::SM90_TMA_LOAD; struct SharedStorage : cute::aligned_struct<128> { cute::array_aligned, 1024> smem_dqacc; cute::array_aligned> smem_dq; alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_dQaccum; }; static constexpr int SharedStorageSize = sizeof(SharedStorage); using ShapedQ = cute::Shape; // (seqlen_q, d, head, batch) using StridedQ = cute::Stride; using TMA_dQaccum = decltype(make_tma_copy( GmemTiledCopydQaccum{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapedQ{}, StridedQ{}), SmemLayoutdQaccumTMA{}, SmemLayoutdQaccumTMA{}.shape(), _1{})); // no mcast for dQ // Device side arguments struct Arguments { ElementAccum const* ptr_dQaccum; ShapedQ const shape_dQaccum; StridedQ const stride_dQaccum; Element* ptr_dQ; ShapedQ const shape_dQ; StridedQ const stride_dQ; float const softmax_scale; int const* cu_seqlens = nullptr; int const* seqused = nullptr; }; // Kernel entry point API struct Params { TMA_dQaccum tma_load_dQaccum; ShapedQ const shape_dQaccum; Element* ptr_dQ; ShapedQ const shape_dQ; StridedQ const stride_dQ; float const softmax_scale; int const* cu_seqlens = nullptr; int const* seqused = nullptr; }; // Convert to underlying arguments. In this case, a simple copy for the aliased type. static Params to_underlying_arguments(Arguments const& args) { Tensor mdQaccum = make_tensor(make_gmem_ptr(args.ptr_dQaccum), args.shape_dQaccum, args.stride_dQaccum); TMA_dQaccum tma_load_dQaccum = make_tma_copy( GmemTiledCopydQaccum{}, mdQaccum, SmemLayoutdQaccumTMA{}, SmemLayoutdQaccumTMA{}.shape(), _1{}); // no mcast for dQaccum return { tma_load_dQaccum, args.shape_dQaccum, args.ptr_dQ, args.shape_dQ, args.stride_dQ, args.softmax_scale, args.cu_seqlens, args.seqused }; } CUTLASS_DEVICE void operator()(Params const& params, char* smem_buf) { static constexpr int kBlockM = get<0>(TileShape_MK{}); SharedStorage& shared_storage = *reinterpret_cast(smem_buf); Tensor sdQaccumTMA = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccumTMA{}); // Tensor sdQaccumTMAnoswizzle = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccumTMANoSwizzle{}); Tensor sdQaccum = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccum{}); Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQ{}); Tensor sdQt = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQt{}); int const thread_idx = threadIdx.x; int const m_block = blockIdx.x; int const bidh = blockIdx.y; int const bidb = blockIdx.z; bool const is_varlen = params.cu_seqlens != nullptr; int const seqlen = !is_varlen ? get<0>(params.shape_dQ) : (params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb]); if (is_varlen && m_block * kBlockM >= seqlen) { return; } int lane_predicate = cute::elect_one_sync(); int warp_idx = cutlass::canonical_warp_idx_sync(); // Issue Tma Descriptor Prefetch from a single thread if (warp_idx == 0 && lane_predicate) { cute::prefetch_tma_descriptor(params.tma_load_dQaccum.get_tma_descriptor()); shared_storage.barrier_dQaccum.init(1 /*numThreads*/); } __syncthreads(); // Step 1: TMA to load dQaccum from gmem to smem // We reshaped dQaccum to have last dimension 32, so the offset needs to be multiplied by kHeadDim / 32 int const offset_padded = !is_varlen ? 0 : ((params.cu_seqlens[bidb] + bidb * 128) / 128 * 128) * (kHeadDim / get<1>(SmemLayoutdQaccumTMA{}.shape())); Tensor mdQaccum = params.tma_load_dQaccum.get_tma_tensor(params.shape_dQaccum)(_, _, bidh, !is_varlen ? bidb : 0); Tensor gdQaccum = local_tile(domain_offset(make_coord(offset_padded, _0{}), mdQaccum), SmemLayoutdQaccumTMA{}.shape(), make_coord(m_block, _0{})); // (M, K) auto block_tma_dQ = params.tma_load_dQaccum.get_slice(_0{}); Tensor tdQgdQaccumTMA = block_tma_dQ.partition_D(gdQaccum); // (TMA, TMA_M, TMA_K) Tensor tdQsdQaccumTMA = block_tma_dQ.partition_S(sdQaccumTMA); // (TMA, TMA_M, TMA_K) static constexpr uint32_t TmaTransactionBytesdQaccum = static_cast(size(SmemLayoutdQaccumTMA{}) * cute::sizeof_bits_v / 8); if (warp_idx == 0 && lane_predicate) { shared_storage.barrier_dQaccum.arrive_and_expect_tx(TmaTransactionBytesdQaccum); copy(params.tma_load_dQaccum.with(reinterpret_cast(shared_storage.barrier_dQaccum), 0 /*mcast_mask*/), tdQgdQaccumTMA, tdQsdQaccumTMA); } shared_storage.barrier_dQaccum.wait(0); // __syncthreads(); if (cute::thread0()) { print_tensor(sdQaccumTMA); } // __syncthreads(); if (cute::thread0()) { print_tensor(sdQaccumTMAnoswizzle); } // __syncthreads(); if (cute::thread0()) { print_tensor(sdQaccum); } // Step 2: Load dQaccum from smem to register, then convert fp32 -> fp16/bf16 R2STiledCopydQaccum s2r_tiled_copy_dQaccum; auto s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_thread_slice(thread_idx); Tensor tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum); TiledMma tiled_mma_dQ; Tensor taccdQrdQaccum = partition_fragment_C(tiled_mma_dQ, select(TileShape_MK{})); // if (cute::thread0()) { print(tiled_mma_dQ); printf("\n"); } // if (cute::thread0()) { print(tdQsdQaccum); } // if (cute::thread0()) { print(taccdQrdQaccum); } CUTE_STATIC_ASSERT_V(size(taccdQrdQaccum) == size(tdQsdQaccum)); Tensor tdQrdQaccum = s2r_thr_copy_dQaccum.retile_D(taccdQrdQaccum); cute::copy(s2r_tiled_copy_dQaccum, tdQsdQaccum, tdQrdQaccum); #pragma unroll for (int i = 0; i < size(taccdQrdQaccum); ++i) { taccdQrdQaccum(i) *= params.softmax_scale; } // Convert tdQrdQ from fp32 to fp16 Tensor rdQ = flash::convert_type(taccdQrdQaccum); // Step 3: Copy dQ from register to smem auto smem_tiled_copy_dQ = make_tiled_copy_C(SmemCopyAtomdQ{}, tiled_mma_dQ); auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(thread_idx); Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N) // if (cute::thread0()) { print(smem_tiled_copy_dQ); } // if (cute::thread0()) { print(smem_thr_copy_dQ); } // if (cute::thread0()) { print(sdQ); } if constexpr (!dQ_swapAB) { Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N) cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ); } else { Tensor taccdQsdQt = smem_thr_copy_dQ.partition_D(sdQt); // ((Atom,AtomNum),PIPE_M,PIPE_N) cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQt); } __syncthreads(); // Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem int const offset = !is_varlen ? 0 : params.cu_seqlens[bidb]; Tensor mdQ = make_tensor(make_gmem_ptr(params.ptr_dQ), params.shape_dQ, params.stride_dQ)(_, _, bidh, !is_varlen ? bidb : 0); Tensor gdQ = local_tile(domain_offset(make_coord(offset, _0{}), mdQ), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K) GmemTiledCopy gmem_tiled_copy_dQ; auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(thread_idx); Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N) Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ); Tensor tdQrdQ = make_fragment_like(tdQsdQ); cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ); // Step 5: Copy dQ from register to gmem // Construct identity layout for gdQ Tensor cdQ = cute::make_identity_tensor(TileShape_MK{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) // Repeat the partitioning with identity layouts Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ); Tensor tdQpdQ = make_tensor(make_shape(size<2>(tdQgdQ))); #pragma unroll for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(_0{}, _0{}, k)) < get<1>(params.shape_dQ); } // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, seqlen - m_block * kBlockM ); } }; } // namespace flash