/****************************************************************************** * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. ******************************************************************************/ #pragma once #include "cute/tensor.hpp" #include #include #include #include #include "utils.h" namespace flash { using namespace cute; template class FlashAttnBwdPreprocess { public: // Type Aliases using TileShape_MK = TileShape_MK_; using ArchTag = ArchTag_; static_assert(std::is_same_v && ArchTag::kMinComputeCapability >= 75 || std::is_same_v && ArchTag::kMinComputeCapability >= 80 || std::is_same_v && ArchTag::kMinComputeCapability >= 89); static constexpr uint32_t MaxThreadsPerBlock = 256; static constexpr uint32_t MinBlocksPerMultiprocessor = 2; static constexpr int SharedStorageSize = 0; static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(get<1>(TileShape_MK{}) % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); static constexpr int kHeadDim = get<1>(TileShape_MK{}); // We want kBlockKGmem to be a power of 2 so that when we do the summing, // it's just between threads in the same warp static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow"); using GmemLayoutAtom = Layout, Int>, Stride, _1>>; using GmemTiledCopy = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtom{}, Layout>>{})); // Val layout, 8 or 16 vals per load static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum); static_assert(get<1>(TileShape_MK{}) % kGmemElemsPerLoadAccum == 0, "Headdim must be a multiple of kGmemElemsPerLoadAccum"); static constexpr int kGmemThreadsPerRowAccum = kBlockKGmem / kGmemElemsPerLoadAccum; static_assert(MaxThreadsPerBlock % kGmemThreadsPerRowAccum == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRowAccum"); using GmemLayoutAtomAccum = Layout, Int>, Stride, _1>>; using GmemTiledCopyAccum = decltype( make_tiled_copy(Copy_Atom, ElementAccum>{}, GmemLayoutAtomAccum{}, Layout>>{})); // Val layout, 4 vals per store using ShapeO = cute::Shape; // (seqlen_q, d, head, batch) using StrideO = cute::Stride; using ShapedPsum = cute::Shape; // (seqlen_q, head, batch) using StridedPsum = cute::Stride<_1, int64_t, int64_t>; // Device side arguments struct Arguments { Element const* ptr_O; ShapeO const shape_O; StrideO const stride_O; Element const* ptr_dO; StrideO const stride_dO; float* ptr_dPsum; ShapedPsum const shape_dPsum; StridedPsum const stride_dPsum; float const* ptr_LSE; StridedPsum const stride_LSE; float *ptr_LSE_log2; StridedPsum const stride_LSE_log2; ElementAccum* ptr_dQaccum; ShapeO const shape_dQaccum; StrideO const stride_dQaccum; int num_batch; // We need this to know the size of dq_semaphore in case of varlen int* dq_semaphore; int const* cu_seqlens = nullptr; int const* seqused = nullptr; }; // Kernel entry point API struct Params { Element const* ptr_O; ShapeO const shape_O; StrideO const stride_O; Element const* ptr_dO; StrideO const stride_dO; float* ptr_dPsum; ShapedPsum const shape_dPsum; StridedPsum const stride_dPsum; float const* ptr_LSE; StridedPsum const stride_LSE; float* ptr_LSE_log2; StridedPsum const stride_LSE_log2; ElementAccum* ptr_dQaccum; ShapeO const shape_dQaccum; StrideO const stride_dQaccum; int num_batch; int* dq_semaphore; int const* cu_seqlens = nullptr; int const* seqused = nullptr; }; // Convert to underlying arguments. In this case, a simple copy for the aliased type. static Params to_underlying_arguments(Arguments const& args) { return { args.ptr_O, args.shape_O, args.stride_O, args.ptr_dO, args.stride_dO, args.ptr_dPsum, args.shape_dPsum, args.stride_dPsum, args.ptr_LSE, args.stride_LSE, args.ptr_LSE_log2, args.stride_LSE_log2, args.ptr_dQaccum, args.shape_dQaccum, args.stride_dQaccum, args.num_batch, args.dq_semaphore, args.cu_seqlens, args.seqused }; } CUTLASS_DEVICE void operator()(Params const& params, [[maybe_unused]] char* smem_buf) { static constexpr int kBlockM = get<0>(TileShape_MK{}); int const thread_idx = threadIdx.x; int const m_block = blockIdx.x; int const bidh = blockIdx.y; int const bidb = blockIdx.z; bool const is_varlen = Varlen && params.cu_seqlens != nullptr; int const offset_o = !is_varlen ? 0 : params.cu_seqlens[bidb]; int const seqlen_o = !is_varlen ? get<0>(params.shape_O) : (params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - offset_o); if (is_varlen && m_block * kBlockM >= seqlen_o) { return; } Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O), params.shape_O, params.stride_O)(_, _, bidh, !is_varlen ? bidb : 0); Tensor gO = local_tile(cute::domain_offset(make_coord(offset_o, _0{}), mO), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K) Tensor mdO = make_tensor(make_gmem_ptr(params.ptr_dO), params.shape_O, params.stride_dO)(_, _, bidh, !is_varlen ? bidb : 0); Tensor gdO = local_tile(cute::domain_offset(make_coord(offset_o, _0{}), mdO), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K) auto shape_LSE = select<0, 2, 3>(params.shape_O); Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE), shape_LSE, params.stride_LSE)(_, bidh, !is_varlen ? bidb : 0); Tensor gLSE = local_tile(cute::domain_offset(make_coord(offset_o), mLSE), Shape>{}, make_coord(m_block)); static_assert(kBlockM <= MaxThreadsPerBlock); float lse = thread_idx < seqlen_o - m_block * kBlockM && thread_idx < kBlockM ? gLSE(thread_idx) : INFINITY; GmemTiledCopy gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); Tensor tOgO = gmem_thr_copy_O.partition_S(gO); Tensor tOgdO = gmem_thr_copy_O.partition_S(gdO); // Construct identity layout for gO Tensor cO = cute::make_identity_tensor(TileShape_MK{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) // Repeat the partitioning with identity layouts Tensor tOcO = gmem_thr_copy_O.partition_D(cO); Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); #pragma unroll for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); } // (8, kBlockM / 32, kHeadDim / 64) or (8, kBlockM / 16, kHeadDim / 128) Tensor tOrO = make_fragment_like(tOgO); Tensor tOrdO = make_fragment_like(tOgdO); flash::copy( gmem_tiled_copy_O, tOgO, tOrO, tOcO, tOpO, seqlen_o - m_block * kBlockM ); flash::copy( gmem_tiled_copy_O, tOgdO, tOrdO, tOcO, tOpO, seqlen_o - m_block * kBlockM ); // Reshape from e.g. (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, (8, kHeadDim / 64)) Layout l = make_layout(get<1>(tOrO.layout()), make_layout(get<0>(tOrO.layout()), get<2>(tOrO.layout()))); Tensor o_fp32 = flash::convert_type(make_tensor(tOrO.data(), l)); Tensor do_fp32 = flash::convert_type(make_tensor(tOrdO.data(), l)); // Sum across the last dimension Tensor dP_sum = make_tensor(make_shape(size<0>(o_fp32))); #pragma unroll for (int mi = 0; mi < size<0>(o_fp32); ++mi) { float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0); #pragma unroll for (int ni = 1; ni < size<1>(o_fp32); ni++) { dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni); } flash::SumOp sum_op; dP_sum(mi) = flash::Allreduce::run(dP_sum_cur, sum_op); } // If varlen, the layout for dPSum, LSE_log2, and dQaccum is that we pad each sequence in the batch // by an extra 128, so that the write for each sequence doesn't touch the next sequence. // Sequence i starts at params.cu_seqlens[i] + i * 128 and ends at params.cu_seqlens[i + 1] + i * 128 int const offset_padded = !is_varlen ? 0 : (params.cu_seqlens[bidb] + bidb * kBlockM) / kBlockM * kBlockM; Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_dPsum, params.stride_dPsum)(_, bidh, !is_varlen ? bidb : 0); Tensor gdPsum = local_tile(cute::domain_offset(make_coord(offset_padded), mdPsum), Shape>{}, make_coord(m_block)); if (thread_idx % kGmemThreadsPerRow == 0) { #pragma unroll for (int mi = 0; mi < size(dP_sum); ++mi) { int row = thread_idx / kGmemThreadsPerRow + mi * MaxThreadsPerBlock / kGmemThreadsPerRow; gdPsum(row) = row < seqlen_o - m_block * kBlockM ? dP_sum(mi) : 0; } } int const seqlen_rounded = cute::round_up(seqlen_o, kBlockM); Tensor mLSElog2 = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_dPsum, params.stride_LSE_log2)(_, bidh, !is_varlen ? bidb : 0); Tensor gLSElog2 = local_tile(cute::domain_offset(make_coord(offset_padded), mLSElog2), Shape>{}, make_coord(m_block)); if (thread_idx < seqlen_rounded - m_block * kBlockM && thread_idx < kBlockM) { gLSElog2(thread_idx) = lse == -INFINITY ? 0.f : lse * float(M_LOG2E); } if constexpr (Clear_dQaccum) { Tensor mdQaccum = make_tensor(make_gmem_ptr(params.ptr_dQaccum), params.shape_dQaccum, params.stride_dQaccum)(_, _, bidh, !is_varlen ? bidb : 0); Tensor gdQaccum = local_tile(cute::domain_offset(make_coord(offset_padded, _0{}), mdQaccum), TileShape_MK{}, make_coord(m_block, _0{})); GmemTiledCopyAccum gmem_tiled_copy_dQaccum; auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(thread_idx); Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum); Tensor zero = make_fragment_like(tdQgdQaccum); clear(zero); // cute::copy(zero, tdQgdQaccum); // Somehow this doesn't vectorize the write #pragma unroll for (int m = 0; m < size<1>(zero); ++m) { cute::copy(zero(_, m, _), tdQgdQaccum(_, m, _)); } } if (params.dq_semaphore != nullptr && thread_idx == 0) { int const num_batch = params.num_batch; int const num_head = get<2>(params.shape_dQaccum); params.dq_semaphore[bidh + bidb * num_head + m_block * num_head * num_batch] = 0; } } }; } // namespace flash