123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301 |
- /******************************************************************************
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
- ******************************************************************************/
- #pragma once
- #include "cutlass/fast_math.h"
- #include "cutlass/arch/barrier.h"
- #include "named_barrier.hpp"
- namespace flash {
- ///////////////////////////////////////////////////////////////////////////////
- struct SingleTileScheduler {
- public:
- // Host side kernel arguments
- struct Arguments {
- int const num_blocks_m, num_splits, num_head, num_batch;
- int* const tile_count_semaphore = nullptr;
- };
- // Device side kernel params
- struct Params {};
- static Params
- to_underlying_arguments(Arguments const& args) {
- return {};
- }
- static dim3
- get_grid_dim(Arguments const& args, int num_sm) {
- return {uint32_t(args.num_blocks_m), uint32_t(args.num_head), uint32_t(args.num_batch)};
- }
- struct WorkTileInfo {
- int M_idx = 0;
- int H_idx = 0;
- int B_idx = 0;
- bool is_valid_tile = false;
- CUTLASS_DEVICE
- bool
- is_valid(Params const& params) const {
- return is_valid_tile;
- }
- CUTLASS_DEVICE
- cute::tuple<int32_t, int32_t, int32_t, int32_t>
- get_block_coord(Params const& params) const {
- return {M_idx, 1, H_idx, B_idx};
- }
- };
- CUTLASS_DEVICE
- SingleTileScheduler(int* tile_count_smem_) { }
- CUTLASS_DEVICE
- WorkTileInfo
- get_initial_work() const {
- return {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), true};
- }
- CUTLASS_DEVICE
- void
- init_consumer() const {}
- CUTLASS_DEVICE
- void
- prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}
- CUTLASS_DEVICE
- void
- broadcast_next_work(WorkTileInfo& current_work) const {}
- template<bool IsProducer=false>
- CUTLASS_DEVICE
- WorkTileInfo
- get_next_work(Params const& params, WorkTileInfo const& current_work) const {
- return {-1, -1, -1, false};
- }
- };
- ///////////////////////////////////////////////////////////////////////////////
- template <bool Is_split = false>
- class StaticPersistentTileScheduler {
- public:
- // Host side kernel arguments
- struct Arguments {
- int const num_blocks_m, num_splits, num_head, num_batch;
- int* const tile_count_semaphore = nullptr;
- };
- // Device side kernel params
- struct Params {
- int const total_blocks;
- cutlass::FastDivmod const m_block_divmod, split_divmod, head_divmod;
- };
- static Params
- to_underlying_arguments(Arguments const& args) {
- // return {args.num_blocks_m * args.num_head * args.num_batch,
- // cutlass::FastDivmod(args.num_blocks_m), cutlass::FastDivmod(args.num_head)};
- return {args.num_blocks_m * args.num_splits * args.num_head * args.num_batch,
- cutlass::FastDivmod(args.num_blocks_m),
- cutlass::FastDivmod(args.num_splits),
- cutlass::FastDivmod(args.num_head)};
- }
- static dim3
- get_grid_dim(Arguments const& args, int num_sm) {
- return {uint32_t(num_sm)};
- }
- struct WorkTileInfo {
- int tile_idx;
- CUTLASS_DEVICE
- bool
- is_valid(Params const& params) const {
- return tile_idx < params.total_blocks;
- }
- CUTLASS_DEVICE
- cute::tuple<int32_t, int32_t, int32_t, int32_t>
- get_block_coord(Params const& params) const {
- int m_block, split_idx, bidh, bidb;
- if constexpr(!Is_split) {
- bidb = params.head_divmod.divmod(bidh,
- params.m_block_divmod.divmod(m_block, tile_idx));
- return {m_block, 1, bidh, bidb};
- } else {
- bidb = params.head_divmod.divmod(bidh,
- params.split_divmod.divmod(split_idx,
- params.m_block_divmod.divmod(m_block, tile_idx)));
- return {m_block, split_idx, bidh, bidb};
- }
- }
- };
- CUTLASS_DEVICE
- StaticPersistentTileScheduler(int* tile_count_smem_) {};
- CUTLASS_DEVICE
- WorkTileInfo
- get_initial_work() const {
- return {int(blockIdx.x)};
- }
- CUTLASS_DEVICE
- void
- init_consumer() const {}
- CUTLASS_DEVICE
- void
- prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}
- CUTLASS_DEVICE
- void
- broadcast_next_work(WorkTileInfo& current_work) const {}
- template<bool IsProducer=false>
- CUTLASS_DEVICE
- WorkTileInfo
- get_next_work(Params const& params, WorkTileInfo const& current_work) const {
- return {current_work.tile_idx + int(gridDim.x)};
- }
- };
- template<int NumMmaThreads = 2 * cutlass::NumThreadsPerWarpGroup,
- int NumProducerThreads = cutlass::NumThreadsPerWarp,
- bool Is_split = false>
- class DynamicPersistentTileScheduler {
- protected:
- int* const tile_count_smem;
- public:
- // Host side kernel arguments
- struct Arguments {
- int const num_blocks_m, num_splits, num_head, num_batch;
- int* const tile_count_semaphore;
- };
- // Device side kernel params
- struct Params {
- int const total_blocks;
- cutlass::FastDivmod const m_block_divmod, split_divmod, head_divmod;
- int* const tile_count_semaphore;
- };
- static Params
- to_underlying_arguments(Arguments const& args) {
- // return {args.num_blocks_m * args.num_head * args.num_batch,
- // cutlass::FastDivmod(args.num_blocks_m), cutlass::FastDivmod(args.num_head),
- // args.tile_count_semaphore};
- return {args.num_blocks_m * args.num_splits * args.num_head * args.num_batch,
- cutlass::FastDivmod(args.num_blocks_m),
- cutlass::FastDivmod(args.num_splits),
- cutlass::FastDivmod(args.num_head),
- args.tile_count_semaphore};
- }
- static dim3
- get_grid_dim(Arguments const& args, int num_sm) {
- return {uint32_t(num_sm)};
- }
- struct WorkTileInfo {
- int tile_idx;
- CUTLASS_DEVICE
- bool
- is_valid(Params const& params) const {
- return tile_idx < params.total_blocks;
- }
- CUTLASS_DEVICE
- cute::tuple<int32_t, int32_t, int32_t, int32_t>
- get_block_coord(Params const& params) const {
- int m_block, split_idx, bidh, bidb;
- if constexpr(!Is_split) {
- bidb = params.head_divmod.divmod(bidh,
- params.m_block_divmod.divmod(m_block, tile_idx));
- return {m_block, 1, bidh, bidb};
- } else {
- bidb = params.head_divmod.divmod(bidh,
- params.split_divmod.divmod(split_idx,
- params.m_block_divmod.divmod(m_block, tile_idx)));
- return {m_block, split_idx, bidh, bidb};
- }
- }
- };
- CUTLASS_DEVICE
- DynamicPersistentTileScheduler(int* tile_count_smem_) : tile_count_smem(tile_count_smem_) {};
- CUTLASS_DEVICE
- WorkTileInfo
- get_initial_work() const {
- return {int(blockIdx.x)};
- }
- CUTLASS_DEVICE
- void
- init_consumer() const {
- cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
- }
- CUTLASS_DEVICE
- void
- prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {
- if (threadIdx.x % NumProducerThreads == 0) {
- current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x);
- }
- }
- CUTLASS_DEVICE
- void
- broadcast_next_work(WorkTileInfo& current_work) const {
- cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
- if (threadIdx.x % NumProducerThreads == 0) {
- *tile_count_smem = current_work.tile_idx;
- }
- cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
- }
- template<bool IsProducer=false>
- CUTLASS_DEVICE
- WorkTileInfo
- get_next_work(Params const& params, WorkTileInfo const& current_work) const {
- if constexpr (IsProducer && NumProducerThreads == cutlass::NumThreadsPerWarp) {
- // thread 0 already has the right tile_idx, just need to broadcast to the rest of the producer threads (warp 0)
- return {__shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/)};
- } else if constexpr (IsProducer && NumProducerThreads == cutlass::NumThreadsPerWarpGroup) {
- // TODO: investigate optimal synchronize
- int tile_idx = *tile_count_smem;
- return {tile_idx};
- } else {
- cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
- int tile_idx = *tile_count_smem;
- cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
- return {tile_idx};
- }
- }
- };
- } // namespace flash
|