123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464 |
- /******************************************************************************
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
- ******************************************************************************/
- #pragma once
- #include "cutlass/fast_math.h"
- #include "cutlass/arch/barrier.h"
- #include "named_barrier.hpp"
- namespace flash {
- ///////////////////////////////////////////////////////////////////////////////
- template<bool Varlen=false, int kBlock=128>
- class SingleTileScheduler {
- public:
- using SharedStorage = int;
- // Host side kernel arguments
- struct Arguments {
- int const num_blocks, num_head, num_batch;
- int* const tile_count_semaphore = nullptr;
- int* const cu_seqlens = nullptr;
- int* const seqused = nullptr;
- };
- // Device side kernel params
- struct Params {
- int const num_blocks, num_head, num_batch;
- int* const cu_seqlens;
- int* const seqused;
- };
- static Params
- to_underlying_arguments(Arguments const& args) {
- return {args.num_blocks, args.num_head, args.num_batch,
- !Varlen ? nullptr : args.cu_seqlens, !Varlen ? nullptr : args.seqused};
- }
- static dim3
- get_grid_shape(Params const& params, int num_sm) {
- return {uint32_t(params.num_blocks), uint32_t(params.num_head), uint32_t(params.num_batch)};
- }
- struct WorkTileInfo {
- int block_idx = 0;
- int bidh = 0;
- int bidb = 0;
- bool is_valid_tile = false;
- CUTLASS_DEVICE
- bool
- is_valid(Params const& params) const {
- return is_valid_tile;
- }
- CUTLASS_DEVICE
- cute::tuple<int32_t, int32_t, int32_t>
- get_block_coord(Params const& params) const {
- return {block_idx, bidh, bidb};
- }
- };
- CUTLASS_DEVICE
- SingleTileScheduler(SharedStorage* const smem_scheduler) { }
- template<bool IsProducerWarp=false>
- CUTLASS_DEVICE
- WorkTileInfo
- get_initial_work(Params const& params) const {
- WorkTileInfo work_info {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), true};
- if constexpr (Varlen) {
- work_info.is_valid_tile = work_info.block_idx * kBlock < (params.seqused ? params.seqused[work_info.bidb] : params.cu_seqlens[work_info.bidb + 1] - params.cu_seqlens[work_info.bidb]);
- }
- return work_info;
- }
- CUTLASS_DEVICE
- void
- init_consumer() const {}
- CUTLASS_DEVICE
- void
- prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}
- template<bool IsProducerWarp=false>
- CUTLASS_DEVICE
- WorkTileInfo
- get_next_work(Params const& params, WorkTileInfo const& current_work) const {
- return {-1, -1, -1, false};
- }
- };
- ///////////////////////////////////////////////////////////////////////////////
- class StaticPersistentTileScheduler {
- public:
- using SharedStorage = int;
- // Host side kernel arguments
- struct Arguments {
- int const num_blocks, num_head, num_batch;
- int* const tile_count_semaphore = nullptr;
- int* const cu_seqlens = nullptr;
- int* const seqused = nullptr;
- };
- // Device side kernel params
- struct Params {
- int total_blocks;
- cutlass::FastDivmod m_block_divmod, head_divmod;
- };
- static Params
- to_underlying_arguments(Arguments const& args) {
- return {args.num_blocks * args.num_head * args.num_batch,
- cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head)};
- }
- static dim3
- get_grid_shape(Params const& params, int num_sm) {
- return {uint32_t(num_sm)};
- }
- struct WorkTileInfo {
- int tile_idx;
- CUTLASS_DEVICE
- bool
- is_valid(Params const& params) const {
- return tile_idx < params.total_blocks;
- }
- CUTLASS_DEVICE
- cute::tuple<int32_t, int32_t, int32_t>
- get_block_coord(Params const& params) const {
- int block, bidh, bidb;
- bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(block, tile_idx));
- return {block, bidh, bidb};
- }
- };
- CUTLASS_DEVICE
- StaticPersistentTileScheduler(SharedStorage* const smem_scheduler) {};
- template<bool IsProducerWarp=false>
- CUTLASS_DEVICE
- WorkTileInfo
- get_initial_work(Params const& params) const {
- return {int(blockIdx.x)};
- }
- CUTLASS_DEVICE
- void
- init_consumer() const {}
- CUTLASS_DEVICE
- void
- prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}
- template<bool IsProducerWarp=false>
- CUTLASS_DEVICE
- WorkTileInfo
- get_next_work(Params const& params, WorkTileInfo const& current_work) const {
- return {current_work.tile_idx + int(gridDim.x)};
- }
- };
- template<int NumMmaThreads=2 * cutlass::NumThreadsPerWarpGroup, int NumProducerThreads=cutlass::NumThreadsPerWarp>
- class DynamicPersistentTileScheduler {
- public:
- using SharedStorage = int;
- protected:
- SharedStorage* const tile_count_smem;
- public:
- // Host side kernel arguments
- struct Arguments {
- int const num_blocks, num_head, num_batch;
- int* const tile_count_semaphore;
- int* const cu_seqlens = nullptr;
- int* const seqused = nullptr;
- };
- // Device side kernel params
- struct Params {
- int const total_blocks;
- cutlass::FastDivmod const m_block_divmod, head_divmod;
- int* const tile_count_semaphore;
- };
- static Params
- to_underlying_arguments(Arguments const& args) {
- return {args.num_blocks * args.num_head * args.num_batch,
- cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head),
- args.tile_count_semaphore};
- }
- static dim3
- get_grid_shape(Params const& params, int num_sm) {
- return {uint32_t(num_sm)};
- }
- struct WorkTileInfo {
- int tile_idx;
- CUTLASS_DEVICE
- bool
- is_valid(Params const& params) const {
- return tile_idx < params.total_blocks;
- }
- CUTLASS_DEVICE
- cute::tuple<int32_t, int32_t, int32_t>
- get_block_coord(Params const& params) const {
- int block, bidh, bidb;
- bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(block, tile_idx));
- return {block, bidh, bidb};
- }
- };
- CUTLASS_DEVICE
- DynamicPersistentTileScheduler(SharedStorage* const smem_scheduler) : tile_count_smem(smem_scheduler) {};
- template<bool IsProducerWarp=false>
- CUTLASS_DEVICE
- WorkTileInfo
- get_initial_work(Params const& params) const {
- return {int(blockIdx.x)};
- }
- CUTLASS_DEVICE
- void
- init_consumer() const {
- cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
- }
- CUTLASS_DEVICE
- void
- prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {
- if (threadIdx.x % NumProducerThreads == 0) {
- current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x);
- }
- }
- template<bool IsProducerWarp=false>
- CUTLASS_DEVICE
- WorkTileInfo
- get_next_work(Params const& params, WorkTileInfo const& current_work) const {
- if constexpr (IsProducerWarp) {
- // thread 0 already has the right tile_idx, just need to broadcast to the rest of warp 0
- int new_tile_idx = __shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/);
- cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
- if (threadIdx.x % NumProducerThreads == 0) {
- *tile_count_smem = current_work.tile_idx;
- }
- cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
- return {new_tile_idx};
- } else {
- cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
- int tile_idx = *tile_count_smem;
- cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
- return {tile_idx};
- }
- }
- };
- template<int kBlock, int NumMmaThreads=2 * cutlass::NumThreadsPerWarpGroup, int NumProducerThreads=cutlass::NumThreadsPerWarp>
- class VarlenDynamicPersistentTileScheduler {
- public:
- using SharedStorage = int4;
- protected:
- SharedStorage* const work_info_smem;
- public:
- // Host side kernel arguments
- struct Arguments {
- int const num_blocks, num_head, num_batch;
- int* const tile_count_semaphore;
- int* const cu_seqlens;
- int* const seqused;
- };
- // Device side kernel params
- struct Params {
- int num_head, num_batch;
- int* const tile_count_semaphore;
- int* const cu_seqlens;
- int* const seqused;
- };
- static Params
- to_underlying_arguments(Arguments const& args) {
- return {args.num_head, args.num_batch,
- args.tile_count_semaphore, args.cu_seqlens, args.seqused};
- }
- static dim3
- get_grid_shape(Params const& params, int num_sm) {
- return {uint32_t(num_sm)};
- }
- struct WorkTileInfo {
- int tile_idx, block, bidh, bidb;
- CUTLASS_DEVICE
- bool
- is_valid(Params const& params) const {
- // if (blockIdx.x >= 0 && (threadIdx.x == 128 || threadIdx.x == 0)) { printf("blockIdx.x = %d, threadIdx.x = %d, checking valid, bidb = %d, params.num_batch = %d\n", blockIdx.x, threadIdx.x, bidb, params.num_batch); }
- return bidb < params.num_batch;
- }
- CUTLASS_DEVICE
- cute::tuple<int32_t, int32_t, int32_t>
- get_block_coord(Params const& params) const {
- return {block, bidh, bidb};
- }
- };
- CUTLASS_DEVICE
- VarlenDynamicPersistentTileScheduler(SharedStorage* const smem_scheduler) : work_info_smem(smem_scheduler) {};
- CUTLASS_DEVICE
- WorkTileInfo
- tile_idx_to_work_tile(Params const& params, int next_tile_idx, WorkTileInfo const& current_work) const {
- auto prefix_sum = [](int val) {
- int lane = threadIdx.x % cutlass::NumThreadsPerWarp;
- CUTLASS_PRAGMA_UNROLL
- for (int i = 1; i < cutlass::NumThreadsPerWarp; i <<= 1) {
- int32_t partial_sum = __shfl_up_sync(0xffffffff, val, i);
- if (lane >= i) { val += partial_sum; }
- }
- return val;
- };
- auto get_num_m_blocks = [&](int bidb) {
- int lane = threadIdx.x % cutlass::NumThreadsPerWarp;
- int seqlen;
- if (params.seqused) {
- seqlen = lane + bidb < params.num_batch ? params.seqused[lane + bidb] : 0;
- } else {
- int cur_cu_seqlen = lane + bidb <= params.num_batch ? params.cu_seqlens[lane + bidb] : 0;
- int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1);
- seqlen = next_cu_seqlen - cur_cu_seqlen;
- }
- return lane + bidb < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1
- ? cute::ceil_div(seqlen, kBlock) : 0;
- };
- int num_m_blocks = get_num_m_blocks(current_work.bidb); // Different for each lane
- // Cumulative number of blocks for the next 31 batches
- int num_m_blocks_cumulative = prefix_sum(num_m_blocks);
- // Total number of blocks for the next 31 batches
- int m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1);
- int group_end_tile = current_work.tile_idx - current_work.block - current_work.bidh * __shfl_sync(0xffffffff, num_m_blocks, 0 /*lane*/) + m_blocks_in_group * params.num_head; // Same for all lanes
- int bidb = current_work.bidb;
- // if (blockIdx.x <= 9 && threadIdx.x == 128) {
- // printf("Before while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, mh_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, bidb, num_m_blocks, next_tile_idx, group_end_tile, mh_blocks_in_group);
- // }
- while (group_end_tile <= next_tile_idx) {
- bidb += cutlass::NumThreadsPerWarp - 1;
- if (bidb >= params.num_batch) {
- // if (blockIdx.x <= 9 && threadIdx.x == 128) {
- // printf("Returning early, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, mh_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, bidb, num_m_blocks, next_tile_idx, group_end_tile, mh_blocks_in_group);
- // }
- return {next_tile_idx, 0, 0, params.num_batch};
- }
- num_m_blocks = get_num_m_blocks(bidb);
- num_m_blocks_cumulative = prefix_sum(num_m_blocks);
- m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1);
- group_end_tile += m_blocks_in_group * params.num_head;
- // if (blockIdx.x <= 9 && threadIdx.x == 128) {
- // printf("Bottom of while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, mh_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, bidb, num_m_blocks, next_tile_idx, group_end_tile, mh_blocks_in_group);
- // }
- }
- int group_start_tile = group_end_tile - m_blocks_in_group * params.num_head;
- // The next problem to process is the first one that does not have ending tile position
- // that is greater than or equal to tile index.
- int batch_idx_in_group = __popc(__ballot_sync(0xffffffff, group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx));
- bidb += batch_idx_in_group;
- num_m_blocks = __shfl_sync(0xffffffff, num_m_blocks, batch_idx_in_group);
- int mh_block = next_tile_idx - group_start_tile - (batch_idx_in_group == 0 ? 0 : __shfl_sync(0xffffffff, num_m_blocks_cumulative, batch_idx_in_group - 1)) * params.num_head;
- int bidh = mh_block / num_m_blocks;
- int block = mh_block - bidh * num_m_blocks;
- // if (blockIdx.x <= 9 && threadIdx.x == 128) {
- // printf("blockIdx.x = %d, threadIdx.x = %d, num_mh_blocks = %d, batch_idx_in_group = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, mh_blocks_in_group = %d, mh_block = %d, bidh = %d, block = %d\n", blockIdx.x, threadIdx.x, num_mh_blocks, batch_idx_in_group, bidb, num_m_blocks, next_tile_idx, group_end_tile, mh_blocks_in_group, mh_block, bidh, block);
- // }
- return {next_tile_idx, block, bidh, bidb};
- }
- template<bool IsProducerWarp=false>
- CUTLASS_DEVICE
- WorkTileInfo
- get_initial_work(Params const& params) const {
- if constexpr (IsProducerWarp) {
- WorkTileInfo work_info = tile_idx_to_work_tile(params, int(blockIdx.x), {0, 0, 0, 0});
- if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
- *work_info_smem = make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb);
- }
- cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
- return work_info;
- } else {
- return get_next_work<false>(params, {0, 0, 0, 0});
- }
- }
- CUTLASS_DEVICE
- void
- init_consumer() const {
- // Don't arrive at the TileCountSmemEmpty barrier here, because get_initial_work will do that
- }
- CUTLASS_DEVICE
- void
- prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {
- if (threadIdx.x % NumProducerThreads == 0) {
- current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x);
- }
- }
- template<bool IsProducerWarp=false>
- CUTLASS_DEVICE
- WorkTileInfo
- get_next_work(Params const& params, WorkTileInfo const& current_work) const {
- if constexpr (IsProducerWarp) {
- // thread 0 has the next tile_idx, just need to broadcast to the rest of warp 0
- int new_tile_idx = __shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/);
- WorkTileInfo work_info = {__shfl_sync(0xffffffff, current_work.tile_idx, 1 /*lane*/), current_work.block, current_work.bidh, current_work.bidb};
- work_info = tile_idx_to_work_tile(params, new_tile_idx, work_info);
- cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
- if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
- *work_info_smem = make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb);
- }
- cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
- return work_info;
- } else {
- cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
- int4 work_info = *work_info_smem;
- cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
- return WorkTileInfo{work_info.x, work_info.y, work_info.z, work_info.w};
- }
- }
- };
- } // flash
|