123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254 |
- /******************************************************************************
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
- ******************************************************************************/
- #pragma once
- #include "cute/tensor.hpp"
- #include <cutlass/cutlass.h>
- #include <cutlass/array.h>
- #include <cutlass/numeric_types.h>
- #include <cutlass/numeric_conversion.h>
- #include "utils.h"
- namespace flash {
- using namespace cute;
- template <class TileShape_MK_, class Element, class ElementAccum, class ArchTag_, bool Clear_dQaccum, bool Varlen>
- class FlashAttnBwdPreprocess {
- public:
- // Type Aliases
- using TileShape_MK = TileShape_MK_;
- using ArchTag = ArchTag_;
- static_assert(std::is_same_v<Element, cutlass::half_t> && ArchTag::kMinComputeCapability >= 75 ||
- std::is_same_v<Element, cutlass::bfloat16_t> && ArchTag::kMinComputeCapability >= 80 ||
- std::is_same_v<Element, cutlass::float_e4m3_t> && ArchTag::kMinComputeCapability >= 89);
- static constexpr uint32_t MaxThreadsPerBlock = 256;
- static constexpr uint32_t MinBlocksPerMultiprocessor = 2;
- static constexpr int SharedStorageSize = 0;
- static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
- static_assert(get<1>(TileShape_MK{}) % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
- static constexpr int kHeadDim = get<1>(TileShape_MK{});
- // We want kBlockKGmem to be a power of 2 so that when we do the summing,
- // it's just between threads in the same warp
- static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
- static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;
- static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
- using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
- Stride<Int<kGmemThreadsPerRow>, _1>>;
- using GmemTiledCopy = decltype(
- make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
- GmemLayoutAtom{},
- Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per load
- static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum);
- static_assert(get<1>(TileShape_MK{}) % kGmemElemsPerLoadAccum == 0, "Headdim must be a multiple of kGmemElemsPerLoadAccum");
- static constexpr int kGmemThreadsPerRowAccum = kBlockKGmem / kGmemElemsPerLoadAccum;
- static_assert(MaxThreadsPerBlock % kGmemThreadsPerRowAccum == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRowAccum");
- using GmemLayoutAtomAccum = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRowAccum>, Int<kGmemThreadsPerRowAccum>>,
- Stride<Int<kGmemThreadsPerRowAccum>, _1>>;
- using GmemTiledCopyAccum = decltype(
- make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
- GmemLayoutAtomAccum{},
- Layout<Shape<_1, Int<kGmemElemsPerLoadAccum>>>{})); // Val layout, 4 vals per store
- using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch)
- using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t>;
- using ShapedPsum = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_q, head, batch)
- using StridedPsum = cute::Stride<_1, int64_t, int64_t>;
- // Device side arguments
- struct Arguments {
- Element const* ptr_O;
- ShapeO const shape_O;
- StrideO const stride_O;
- Element const* ptr_dO;
- StrideO const stride_dO;
- float* ptr_dPsum;
- ShapedPsum const shape_dPsum;
- StridedPsum const stride_dPsum;
- float const* ptr_LSE;
- StridedPsum const stride_LSE;
- float *ptr_LSE_log2;
- StridedPsum const stride_LSE_log2;
- ElementAccum* ptr_dQaccum;
- ShapeO const shape_dQaccum;
- StrideO const stride_dQaccum;
- int num_batch; // We need this to know the size of dq_semaphore in case of varlen
- int* dq_semaphore;
- int const* cu_seqlens = nullptr;
- int const* seqused = nullptr;
- };
- // Kernel entry point API
- struct Params {
- Element const* ptr_O;
- ShapeO const shape_O;
- StrideO const stride_O;
- Element const* ptr_dO;
- StrideO const stride_dO;
- float* ptr_dPsum;
- ShapedPsum const shape_dPsum;
- StridedPsum const stride_dPsum;
- float const* ptr_LSE;
- StridedPsum const stride_LSE;
- float* ptr_LSE_log2;
- StridedPsum const stride_LSE_log2;
- ElementAccum* ptr_dQaccum;
- ShapeO const shape_dQaccum;
- StrideO const stride_dQaccum;
- int num_batch;
- int* dq_semaphore;
- int const* cu_seqlens = nullptr;
- int const* seqused = nullptr;
- };
- // Convert to underlying arguments. In this case, a simple copy for the aliased type.
- static
- Params
- to_underlying_arguments(Arguments const& args) {
- return {
- args.ptr_O,
- args.shape_O,
- args.stride_O,
- args.ptr_dO,
- args.stride_dO,
- args.ptr_dPsum,
- args.shape_dPsum,
- args.stride_dPsum,
- args.ptr_LSE,
- args.stride_LSE,
- args.ptr_LSE_log2,
- args.stride_LSE_log2,
- args.ptr_dQaccum,
- args.shape_dQaccum,
- args.stride_dQaccum,
- args.num_batch,
- args.dq_semaphore,
- args.cu_seqlens,
- args.seqused
- };
- }
- CUTLASS_DEVICE
- void
- operator()(Params const& params, [[maybe_unused]] char* smem_buf) {
- static constexpr int kBlockM = get<0>(TileShape_MK{});
- int const thread_idx = threadIdx.x;
- int const m_block = blockIdx.x;
- int const bidh = blockIdx.y;
- int const bidb = blockIdx.z;
- bool const is_varlen = Varlen && params.cu_seqlens != nullptr;
- int const offset_o = !is_varlen ? 0 : params.cu_seqlens[bidb];
- int const seqlen_o = !is_varlen ? get<0>(params.shape_O) : (params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - offset_o);
- if (is_varlen && m_block * kBlockM >= seqlen_o) { return; }
- Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O), params.shape_O, params.stride_O)(_, _, bidh, !is_varlen ? bidb : 0);
- Tensor gO = local_tile(cute::domain_offset(make_coord(offset_o, _0{}), mO), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K)
- Tensor mdO = make_tensor(make_gmem_ptr(params.ptr_dO), params.shape_O, params.stride_dO)(_, _, bidh, !is_varlen ? bidb : 0);
- Tensor gdO = local_tile(cute::domain_offset(make_coord(offset_o, _0{}), mdO), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K)
- auto shape_LSE = select<0, 2, 3>(params.shape_O);
- Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE), shape_LSE, params.stride_LSE)(_, bidh, !is_varlen ? bidb : 0);
- Tensor gLSE = local_tile(cute::domain_offset(make_coord(offset_o), mLSE), Shape<Int<kBlockM>>{}, make_coord(m_block));
- static_assert(kBlockM <= MaxThreadsPerBlock);
- float lse = thread_idx < seqlen_o - m_block * kBlockM && thread_idx < kBlockM ? gLSE(thread_idx) : INFINITY;
- GmemTiledCopy gmem_tiled_copy_O;
- auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
- Tensor tOgO = gmem_thr_copy_O.partition_S(gO);
- Tensor tOgdO = gmem_thr_copy_O.partition_S(gdO);
- // Construct identity layout for gO
- Tensor cO = cute::make_identity_tensor(TileShape_MK{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
- // Repeat the partitioning with identity layouts
- Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
- Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
- #pragma unroll
- for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
- // (8, kBlockM / 32, kHeadDim / 64) or (8, kBlockM / 16, kHeadDim / 128)
- Tensor tOrO = make_fragment_like(tOgO);
- Tensor tOrdO = make_fragment_like(tOgdO);
- flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
- gmem_tiled_copy_O, tOgO, tOrO, tOcO, tOpO, seqlen_o - m_block * kBlockM
- );
- flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
- gmem_tiled_copy_O, tOgdO, tOrdO, tOcO, tOpO, seqlen_o - m_block * kBlockM
- );
- // Reshape from e.g. (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, (8, kHeadDim / 64))
- Layout l = make_layout(get<1>(tOrO.layout()), make_layout(get<0>(tOrO.layout()), get<2>(tOrO.layout())));
- Tensor o_fp32 = flash::convert_type<float>(make_tensor(tOrO.data(), l));
- Tensor do_fp32 = flash::convert_type<float>(make_tensor(tOrdO.data(), l));
- // Sum across the last dimension
- Tensor dP_sum = make_tensor<float>(make_shape(size<0>(o_fp32)));
- #pragma unroll
- for (int mi = 0; mi < size<0>(o_fp32); ++mi) {
- float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0);
- #pragma unroll
- for (int ni = 1; ni < size<1>(o_fp32); ni++) {
- dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni);
- }
- flash::SumOp<float> sum_op;
- dP_sum(mi) = flash::Allreduce<kGmemThreadsPerRow>::run(dP_sum_cur, sum_op);
- }
- // If varlen, the layout for dPSum, LSE_log2, and dQaccum is that we pad each sequence in the batch
- // by an extra 128, so that the write for each sequence doesn't touch the next sequence.
- // Sequence i starts at params.cu_seqlens[i] + i * 128 and ends at params.cu_seqlens[i + 1] + i * 128
- int const offset_padded = !is_varlen ? 0 : (params.cu_seqlens[bidb] + bidb * kBlockM) / kBlockM * kBlockM;
- Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_dPsum, params.stride_dPsum)(_, bidh, !is_varlen ? bidb : 0);
- Tensor gdPsum = local_tile(cute::domain_offset(make_coord(offset_padded), mdPsum), Shape<Int<kBlockM>>{}, make_coord(m_block));
- if (thread_idx % kGmemThreadsPerRow == 0) {
- #pragma unroll
- for (int mi = 0; mi < size(dP_sum); ++mi) {
- int row = thread_idx / kGmemThreadsPerRow + mi * MaxThreadsPerBlock / kGmemThreadsPerRow;
- gdPsum(row) = row < seqlen_o - m_block * kBlockM ? dP_sum(mi) : 0;
- }
- }
- int const seqlen_rounded = cute::round_up(seqlen_o, kBlockM);
- Tensor mLSElog2 = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_dPsum, params.stride_LSE_log2)(_, bidh, !is_varlen ? bidb : 0);
- Tensor gLSElog2 = local_tile(cute::domain_offset(make_coord(offset_padded), mLSElog2), Shape<Int<kBlockM>>{}, make_coord(m_block));
- if (thread_idx < seqlen_rounded - m_block * kBlockM && thread_idx < kBlockM) {
- gLSElog2(thread_idx) = lse == -INFINITY ? 0.f : lse * float(M_LOG2E);
- }
- if constexpr (Clear_dQaccum) {
- Tensor mdQaccum = make_tensor(make_gmem_ptr(params.ptr_dQaccum), params.shape_dQaccum, params.stride_dQaccum)(_, _, bidh, !is_varlen ? bidb : 0);
- Tensor gdQaccum = local_tile(cute::domain_offset(make_coord(offset_padded, _0{}), mdQaccum), TileShape_MK{}, make_coord(m_block, _0{}));
- GmemTiledCopyAccum gmem_tiled_copy_dQaccum;
- auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(thread_idx);
- Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
- Tensor zero = make_fragment_like(tdQgdQaccum);
- clear(zero);
- // cute::copy(zero, tdQgdQaccum); // Somehow this doesn't vectorize the write
- #pragma unroll
- for (int m = 0; m < size<1>(zero); ++m) {
- cute::copy(zero(_, m, _), tdQgdQaccum(_, m, _));
- }
- }
- if (params.dq_semaphore != nullptr && thread_idx == 0) {
- int const num_batch = params.num_batch;
- int const num_head = get<2>(params.shape_dQaccum);
- params.dq_semaphore[bidh + bidb * num_head + m_block * num_head * num_batch] = 0;
- }
- }
- };
- } // namespace flash
|