/****************************************************************************** * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. ******************************************************************************/ #pragma once #include "cute/tensor.hpp" #include #include #include #include #include #include #include "cutlass/pipeline/pipeline.hpp" #include "utils.h" #include "tile_scheduler_bwd.hpp" #include "mainloop_bwd_sm90_tma_gmma_ws.hpp" #include "epilogue_bwd_sm90_tma.hpp" namespace flash { using namespace cute; template class FlashAttnBwd { public: // Type Aliases static constexpr bool Is_causal = CollectiveMainloop_::Is_causal; static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen); static constexpr bool Varlen = CollectiveMainloop_::Varlen; // Mainloop derived types using CollectiveMainloop = CollectiveMainloop_; using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK; using TiledMmaSdP = typename CollectiveMainloop::TiledMmaSdP; using TiledMmadKV = typename CollectiveMainloop::TiledMmadKV; using ArchTag = typename CollectiveMainloop::ArchTag; using ClusterShape = typename CollectiveMainloop::ClusterShape; using MainloopArguments = typename CollectiveMainloop::Arguments; using MainloopParams = typename CollectiveMainloop::Params; static constexpr bool dKV_swapAB = CollectiveMainloop::dKV_swapAB; // Epilogue derived types using CollectiveEpilogue = CollectiveEpilogue_; using EpilogueArguments = typename CollectiveEpilogue::Arguments; using EpilogueParams = typename CollectiveEpilogue::Params; static_assert(ArchTag::kMinComputeCapability >= 90); using TileScheduler = TileScheduler_; using TileSchedulerArguments = typename TileScheduler::Arguments; using TileSchedulerParams = typename TileScheduler::Params; static constexpr uint32_t NumLoadWarpGroups = 1; static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaSdP{})) / cutlass::NumThreadsPerWarpGroup; static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaSdP{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup); static constexpr uint32_t MinBlocksPerMultiprocessor = 1; static_assert(NumMmaWarpGroups == 2); /// Register requirement for Load and Math WGs static constexpr uint32_t LoadRegisterRequirement = 24; static constexpr uint32_t MmaRegisterRequirement = 240; // If you want to print from the producer warp, you'd need to increase the number of registers // Otherwise you'll get CUDA error. // static constexpr uint32_t LoadRegisterRequirement = 56; // static constexpr uint32_t MmaRegisterRequirement = 224; // Kernel level shared memory storage struct SharedStorage { struct { union { typename CollectiveMainloop::TensorStorage mainloop; typename CollectiveEpilogue::TensorStorage epilogue; }; }; struct { alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_KV; alignas(16) cutlass::arch::ClusterBarrier barrier_dKV; alignas(16) typename CollectiveMainloop::MainloopPipeline::SharedStorage pipeline_q; alignas(16) typename CollectiveMainloop::MainloopPipeline::SharedStorage pipeline_do; alignas(16) typename TileScheduler::SharedStorage smem_scheduler; }; }; static constexpr int SharedStorageSize = sizeof(SharedStorage); // Device side arguments struct Arguments { MainloopArguments mainloop{}; EpilogueArguments epilogue{}; cutlass::KernelHardwareInfo hw_info{}; TileSchedulerArguments scheduler{}; }; // Kernel entry point API struct Params { MainloopParams mainloop{}; EpilogueParams epilogue{}; cutlass::KernelHardwareInfo hw_info{}; TileSchedulerParams scheduler{}; }; // // Methods // // Convert to underlying arguments. In this case, a simple copy for the aliased type. static Params to_underlying_arguments(Arguments const& args) { CUTLASS_TRACE_HOST("to_underlying_arguments():"); // Get SM count if needed, otherwise use user supplied SM count int sm_count = args.hw_info.sm_count; if (sm_count <= 0) { CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); } CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; return { CollectiveMainloop::to_underlying_arguments(args.mainloop), CollectiveEpilogue::to_underlying_arguments(args.epilogue), hw_info, TileScheduler::to_underlying_arguments(args.scheduler) }; } // Computes the kernel launch grid shape based on runtime parameters static dim3 get_grid_shape(Params const& params) { return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count); } static dim3 get_block_shape() { return dim3(MaxThreadsPerBlock, 1, 1); } CUTLASS_DEVICE void operator()(Params const& params, char* smem_buf) { static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup; static constexpr int NumCopyThreads = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup; static constexpr int kBlockM = get<0>(TileShape_MNK{}); using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; using PipelineParams = typename MainloopPipeline::Params; using PipelineState = typename MainloopPipeline::PipelineState; SharedStorage& shared_storage = *reinterpret_cast(smem_buf); int const lane_predicate = cute::elect_one_sync(); int const warp_idx = cutlass::canonical_warp_idx_sync(); // Issue Tma Descriptor Prefetch from a single thread if (warp_idx == 0 && lane_predicate) { CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); } // Obtain warp index int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; PipelineParams pipeline_params; pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesQ + CollectiveMainloop::TmaTransactionBytesLSE; int warp_group_idx = cutlass::canonical_warp_group_idx(); pipeline_params.role = warp_group_idx == 0 ? MainloopPipeline::ThreadCategory::Producer : MainloopPipeline::ThreadCategory::Consumer; pipeline_params.is_leader = warp_group_thread_idx == 0; pipeline_params.num_consumers = NumMmaThreads; if (warp_idx == 0 && lane_predicate) { shared_storage.barrier_KV.init(1 /*numThreads*/); // shared_storage.barrier_dKV.init(size(ClusterShape{}) /*numThreads*/); } // We're counting on pipeline_q to call cutlass::arch::fence_barrier_init(); MainloopPipeline pipeline_q(shared_storage.pipeline_q, pipeline_params, ClusterShape{}); MainloopPipeline pipeline_do(shared_storage.pipeline_do, pipeline_params, ClusterShape{}); CollectiveMainloop collective_mainloop; CollectiveEpilogue collective_epilogue; // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster if constexpr (size(ClusterShape{}) > 1) { cute::cluster_arrive_relaxed(); cute::cluster_wait(); } else { __syncthreads(); } if (warp_group_idx == 0) { // Producer cutlass::arch::warpgroup_reg_dealloc(); int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); if (warp_idx_in_warpgroup == 0) { // Load K, V, and do TMA on Q and dO PipelineState smem_pipe_write = cutlass::make_producer_start_state(); int work_idx = 0; TileScheduler scheduler(reinterpret_cast(&shared_storage.smem_scheduler)); for (auto work_tile_info = scheduler.template get_initial_work(params.scheduler); work_tile_info.is_valid(params.scheduler); work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info)) { auto block_coord = work_tile_info.get_block_coord(params.scheduler); auto [n_block, bidh, bidb] = block_coord; if constexpr (Varlen) { if (n_block * kBlockM >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) { scheduler.prefetch_next_work(params.scheduler, work_tile_info); continue; } } if constexpr (Is_causal) { int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb); int const m_block_max = cute::ceil_div(collective_mainloop.get_seqlen_q(params.mainloop, bidb), kBlockM); if (m_block_min >= m_block_max) { scheduler.prefetch_next_work(params.scheduler, work_tile_info); continue; } } auto scheduler_prefetch = [&scheduler, ¶ms, &work_tile_info]() { scheduler.prefetch_next_work(params.scheduler, work_tile_info); }; collective_mainloop.load(params.mainloop, pipeline_q, pipeline_do, smem_pipe_write, shared_storage, scheduler_prefetch, block_coord, work_idx); ++work_idx; } collective_mainloop.load_tail(pipeline_q, pipeline_do, smem_pipe_write); } else if (warp_idx_in_warpgroup == 1) { TileScheduler scheduler(reinterpret_cast(&shared_storage.smem_scheduler)); for (auto work_tile_info = scheduler.template get_initial_work(params.scheduler); work_tile_info.is_valid(params.scheduler); work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info)) { auto block_coord = work_tile_info.get_block_coord(params.scheduler); auto [n_block, bidh, bidb] = block_coord; if constexpr (Varlen) { if (n_block * kBlockM >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) { continue; } } if constexpr (Is_causal) { int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb); int const m_block_max = cute::ceil_div(collective_mainloop.get_seqlen_q(params.mainloop, bidb), kBlockM); if (m_block_min >= m_block_max) { continue; } } collective_mainloop.store_dq(params.mainloop, shared_storage, block_coord); } } } else { // Consumer cutlass::arch::warpgroup_reg_alloc(); TileScheduler scheduler(reinterpret_cast(&shared_storage.smem_scheduler)); // Initialize matmul objects. TiledMmadKV tiled_mma_dKV; PipelineState smem_pipe_read; collective_mainloop.mma_init(); scheduler.init_consumer(); int work_idx = 0; CUTLASS_PRAGMA_NO_UNROLL for (auto work_tile_info = scheduler.template get_initial_work(params.scheduler); work_tile_info.is_valid(params.scheduler); work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info)) { auto block_coord = work_tile_info.get_block_coord(params.scheduler); auto [n_block, bidh, bidb] = block_coord; if constexpr (Varlen) { if (n_block * kBlockM >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) { continue; } } if constexpr (Is_causal) { int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb); int const m_block_max = cute::ceil_div(collective_mainloop.get_seqlen_q(params.mainloop, bidb), kBlockM); if (m_block_min >= m_block_max) { // We exit early and write 0 to dK and dV collective_epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord); continue; } } // dK and dV output accumulator. Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select(TileShape_MNK{})); Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select(TileShape_MNK{})); collective_mainloop.mma(params.mainloop, pipeline_q, pipeline_do, smem_pipe_read, tdKrdK, tdVrdV, threadIdx.x - NumCopyThreads, work_idx, block_coord, shared_storage); collective_epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV, threadIdx.x - NumCopyThreads, block_coord); ++work_idx; } collective_epilogue.store_tail(); } } }; } // namespace flash