123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254 |
- /******************************************************************************
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
- ******************************************************************************/
- #pragma once
- #include "cute/tensor.hpp"
- #include <cutlass/cutlass.h>
- #include <cutlass/array.h>
- #include <cutlass/numeric_types.h>
- #include <cutlass/numeric_conversion.h>
- #include "cutlass/arch/barrier.h"
- #include "utils.h"
- namespace flash {
- using namespace cute;
- template <class TileShape_MK_, class Element, class ElementAccum, class ArchTag_, int kNThreads, class SmemLayoutdQaccumTMA,
- class TiledMma, bool dQ_swapAB>
- class FlashAttnBwdPostprocessConvertdQ {
- public:
- // Type Aliases
- using TileShape_MK = TileShape_MK_;
- using ArchTag = ArchTag_;
- static_assert(ArchTag::kMinComputeCapability >= 90);
- static constexpr uint32_t MaxThreadsPerBlock = kNThreads;
- static constexpr uint32_t MinBlocksPerMultiprocessor = 2;
- static constexpr int kHeadDim = get<1>(TileShape_MK{});
- using R2SLayoutAtomdQaccum = Layout<Shape<Int<kNThreads>>, Stride<_1>>;
- using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{}, R2SLayoutAtomdQaccum{},
- Layout<Shape < _4>>{})); // Val layout, 4 vals per read
- static constexpr int SmemdQaccumSize = size(TileShape_MK{});
- static_assert(size(TileShape_MK{}) == size(SmemLayoutdQaccumTMA{}), "TileShape_MK and SmemLayoutdQaccumTMA must have the same size");
- using SmemLayoutdQaccum = Layout<Shape<Int<SmemdQaccumSize>>, Stride<_1>>;
- // We can't just use kHeadDim here. E.g. if MMA shape is 64 x 96 but split across 2 WGs,
- // then setting kBlockKSmem to 32 will cause "Static shape_div failure".
- // We want to treat it as 64 x 48, so kBlockKSmem should be 16.
- static constexpr int MmaShapeN = get<1>(typename TiledMma::AtomShape_MNK{});
- static constexpr int kBlockKSmem = MmaShapeN % 64 == 0 ? 64 : (MmaShapeN % 32 == 0 ? 32 : 16);
- static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 2 : 1);
- using SmemLayoutAtomdQ =
- decltype(composition(Swizzle<kSwizzle, 3, 3>{},
- Layout<Shape<Int<8>, Int<kBlockKSmem>>,
- Stride<Int<kBlockKSmem>, _1>>{}));
- using SmemLayoutdQ = decltype(tile_to_shape(SmemLayoutAtomdQ{}, TileShape_MK{}));
- using SmemLayoutdQt =
- decltype(cute::composition(SmemLayoutdQ{},
- make_layout(make_shape(get<1>(TileShape_MK{}), get<0>(TileShape_MK{})),
- make_stride(Int<get<0>(TileShape_MK{})>{}, _1{}))));
- using SmemCopyAtomdQ = Copy_Atom<
- std::conditional_t<!dQ_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
- Element>;
- static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
- static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
- static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, int(MaxThreadsPerBlock));
- static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
- using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
- Stride<Int<kGmemThreadsPerRow>, _1>>;
- using GmemTiledCopy = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
- GmemLayoutAtom{},
- Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per load
- using GmemTiledCopydQaccum = cute::SM90_TMA_LOAD;
- struct SharedStorage : cute::aligned_struct<128> {
- cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdQaccumTMA>, 1024> smem_dqacc;
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutdQ>> smem_dq;
- alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_dQaccum;
- };
- static constexpr int SharedStorageSize = sizeof(SharedStorage);
- using ShapedQ = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch)
- using StridedQ = cute::Stride<int64_t, _1, int64_t, int64_t>;
- using TMA_dQaccum = decltype(make_tma_copy(
- GmemTiledCopydQaccum{},
- make_tensor(make_gmem_ptr(static_cast<ElementAccum*>(nullptr)), ShapedQ{}, StridedQ{}),
- SmemLayoutdQaccumTMA{},
- SmemLayoutdQaccumTMA{}.shape(),
- _1{})); // no mcast for dQ
- // Device side arguments
- struct Arguments {
- ElementAccum const* ptr_dQaccum;
- ShapedQ const shape_dQaccum;
- StridedQ const stride_dQaccum;
- Element* ptr_dQ;
- ShapedQ const shape_dQ;
- StridedQ const stride_dQ;
- float const softmax_scale;
- int const* cu_seqlens = nullptr;
- int const* seqused = nullptr;
- };
- // Kernel entry point API
- struct Params {
- TMA_dQaccum tma_load_dQaccum;
- ShapedQ const shape_dQaccum;
- Element* ptr_dQ;
- ShapedQ const shape_dQ;
- StridedQ const stride_dQ;
- float const softmax_scale;
- int const* cu_seqlens = nullptr;
- int const* seqused = nullptr;
- };
- // Convert to underlying arguments. In this case, a simple copy for the aliased type.
- static
- Params
- to_underlying_arguments(Arguments const& args) {
- Tensor mdQaccum = make_tensor(make_gmem_ptr(args.ptr_dQaccum), args.shape_dQaccum, args.stride_dQaccum);
- TMA_dQaccum tma_load_dQaccum = make_tma_copy(
- GmemTiledCopydQaccum{},
- mdQaccum,
- SmemLayoutdQaccumTMA{},
- SmemLayoutdQaccumTMA{}.shape(),
- _1{}); // no mcast for dQaccum
- return {
- tma_load_dQaccum,
- args.shape_dQaccum,
- args.ptr_dQ,
- args.shape_dQ,
- args.stride_dQ,
- args.softmax_scale,
- args.cu_seqlens,
- args.seqused
- };
- }
- CUTLASS_DEVICE
- void
- operator()(Params const& params, char* smem_buf) {
- static constexpr int kBlockM = get<0>(TileShape_MK{});
- SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
- Tensor sdQaccumTMA = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccumTMA{});
- // Tensor sdQaccumTMAnoswizzle = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccumTMANoSwizzle{});
- Tensor sdQaccum = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccum{});
- Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQ{});
- Tensor sdQt = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQt{});
- int const thread_idx = threadIdx.x;
- int const m_block = blockIdx.x;
- int const bidh = blockIdx.y;
- int const bidb = blockIdx.z;
- bool const is_varlen = params.cu_seqlens != nullptr;
- int const seqlen = !is_varlen ? get<0>(params.shape_dQ) : (params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb]);
- if (is_varlen && m_block * kBlockM >= seqlen) { return; }
- int lane_predicate = cute::elect_one_sync();
- int warp_idx = cutlass::canonical_warp_idx_sync();
- // Issue Tma Descriptor Prefetch from a single thread
- if (warp_idx == 0 && lane_predicate) {
- cute::prefetch_tma_descriptor(params.tma_load_dQaccum.get_tma_descriptor());
- shared_storage.barrier_dQaccum.init(1 /*numThreads*/);
- }
- __syncthreads();
- // Step 1: TMA to load dQaccum from gmem to smem
- // We reshaped dQaccum to have last dimension 32, so the offset needs to be multiplied by kHeadDim / 32
- int const offset_padded = !is_varlen ? 0 : ((params.cu_seqlens[bidb] + bidb * 128) / 128 * 128) * (kHeadDim / get<1>(SmemLayoutdQaccumTMA{}.shape()));
- Tensor mdQaccum = params.tma_load_dQaccum.get_tma_tensor(params.shape_dQaccum)(_, _, bidh, !is_varlen ? bidb : 0);
- Tensor gdQaccum = local_tile(domain_offset(make_coord(offset_padded, _0{}), mdQaccum), SmemLayoutdQaccumTMA{}.shape(), make_coord(m_block, _0{})); // (M, K)
- auto block_tma_dQ = params.tma_load_dQaccum.get_slice(_0{});
- Tensor tdQgdQaccumTMA = block_tma_dQ.partition_D(gdQaccum); // (TMA, TMA_M, TMA_K)
- Tensor tdQsdQaccumTMA = block_tma_dQ.partition_S(sdQaccumTMA); // (TMA, TMA_M, TMA_K)
- static constexpr uint32_t TmaTransactionBytesdQaccum = static_cast<uint32_t>(size(SmemLayoutdQaccumTMA{}) * cute::sizeof_bits_v<ElementAccum> / 8);
- if (warp_idx == 0 && lane_predicate) {
- shared_storage.barrier_dQaccum.arrive_and_expect_tx(TmaTransactionBytesdQaccum);
- copy(params.tma_load_dQaccum.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_dQaccum), 0 /*mcast_mask*/), tdQgdQaccumTMA, tdQsdQaccumTMA);
- }
- shared_storage.barrier_dQaccum.wait(0);
- // __syncthreads(); if (cute::thread0()) { print_tensor(sdQaccumTMA); }
- // __syncthreads(); if (cute::thread0()) { print_tensor(sdQaccumTMAnoswizzle); }
- // __syncthreads(); if (cute::thread0()) { print_tensor(sdQaccum); }
- // Step 2: Load dQaccum from smem to register, then convert fp32 -> fp16/bf16
- R2STiledCopydQaccum s2r_tiled_copy_dQaccum;
- auto s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_thread_slice(thread_idx);
- Tensor tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum);
- TiledMma tiled_mma_dQ;
- Tensor taccdQrdQaccum = partition_fragment_C(tiled_mma_dQ, select<!dQ_swapAB ? 0 : 1, !dQ_swapAB ? 1 : 0>(TileShape_MK{}));
- // if (cute::thread0()) { print(tiled_mma_dQ); printf("\n"); }
- // if (cute::thread0()) { print(tdQsdQaccum); }
- // if (cute::thread0()) { print(taccdQrdQaccum); }
- CUTE_STATIC_ASSERT_V(size(taccdQrdQaccum) == size(tdQsdQaccum));
- Tensor tdQrdQaccum = s2r_thr_copy_dQaccum.retile_D(taccdQrdQaccum);
- cute::copy(s2r_tiled_copy_dQaccum, tdQsdQaccum, tdQrdQaccum);
- #pragma unroll
- for (int i = 0; i < size(taccdQrdQaccum); ++i) { taccdQrdQaccum(i) *= params.softmax_scale; }
- // Convert tdQrdQ from fp32 to fp16
- Tensor rdQ = flash::convert_type<Element>(taccdQrdQaccum);
- // Step 3: Copy dQ from register to smem
- auto smem_tiled_copy_dQ = make_tiled_copy_C(SmemCopyAtomdQ{}, tiled_mma_dQ);
- auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(thread_idx);
- Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N)
- // if (cute::thread0()) { print(smem_tiled_copy_dQ); }
- // if (cute::thread0()) { print(smem_thr_copy_dQ); }
- // if (cute::thread0()) { print(sdQ); }
- if constexpr (!dQ_swapAB) {
- Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N)
- cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
- } else {
- Tensor taccdQsdQt = smem_thr_copy_dQ.partition_D(sdQt); // ((Atom,AtomNum),PIPE_M,PIPE_N)
- cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQt);
- }
- __syncthreads();
- // Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem
- int const offset = !is_varlen ? 0 : params.cu_seqlens[bidb];
- Tensor mdQ = make_tensor(make_gmem_ptr(params.ptr_dQ), params.shape_dQ, params.stride_dQ)(_, _, bidh, !is_varlen ? bidb : 0);
- Tensor gdQ = local_tile(domain_offset(make_coord(offset, _0{}), mdQ), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K)
- GmemTiledCopy gmem_tiled_copy_dQ;
- auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(thread_idx);
- Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N)
- Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
- Tensor tdQrdQ = make_fragment_like(tdQsdQ);
- cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);
- // Step 5: Copy dQ from register to gmem
- // Construct identity layout for gdQ
- Tensor cdQ = cute::make_identity_tensor(TileShape_MK{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
- // Repeat the partitioning with identity layouts
- Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
- Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));
- #pragma unroll
- for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(_0{}, _0{}, k)) < get<1>(params.shape_dQ); }
- // Clear_OOB_K must be false since we don't want to write zeros to gmem
- flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
- gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, seqlen - m_block * kBlockM
- );
- }
- };
- } // namespace flash
|