123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216 |
- namespace flash {
- using namespace cute;
- template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
- class FlashAttnFwdSm80 {
- public:
-
- using CollectiveMainloop = CollectiveMainloop_;
- using CollectiveEpilogue = CollectiveEpilogue_;
- static constexpr bool Is_causal = CollectiveMainloop::Is_causal;
- static constexpr bool Is_local = CollectiveMainloop::Is_local;
- static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen);
- static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap;
- static constexpr bool Varlen = CollectiveMainloop::Varlen;
- static constexpr bool PagedKV = CollectiveMainloop::PagedKV;
- static constexpr bool Split = CollectiveMainloop::Split;
- static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8;
- static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V;
- static constexpr bool AppendKV = CollectiveMainloop::AppendKV;
- static constexpr bool PackGQA = CollectiveMainloop::PackGQA;
- static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads;
- using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t;
-
- using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;
- using TiledMma = typename CollectiveMainloop::TiledMma;
- using ArchTag = typename CollectiveMainloop::ArchTag;
- using MainloopArguments = typename CollectiveMainloop::Arguments;
- using MainloopParams = typename CollectiveMainloop::Params;
-
- using EpilogueArguments = typename CollectiveEpilogue::Arguments;
- using EpilogueParams = typename CollectiveEpilogue::Params;
- static_assert(ArchTag::kMinComputeCapability >= 80);
- using TileScheduler = TileScheduler_;
- using TileSchedulerArguments = typename flash::TileSchedulerArguments;
- using TileSchedulerParams = typename TileScheduler::Params;
- static constexpr uint32_t NumThreads = CUTE_STATIC_V(size(TiledMma{}));
- static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{}));
- static constexpr uint32_t MinBlocksPerMultiprocessor = NumThreads == 128 ? 2 : 1;
-
-
-
- static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage))
- - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v)))
- - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k)));
- static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 0 : mainloop_smem_padding_;
- struct SharedStorage {
- struct TensorStorage : cute::aligned_struct<128> {
- union {
- struct {
- cute::array<uint32_t, mainloop_smem_padding / sizeof(uint32_t)> padding_;
- typename CollectiveMainloop::TensorStorage mainloop;
- };
-
- typename CollectiveEpilogue::TensorStorage epilogue;
- };
- } tensors;
- alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
- };
- static constexpr int SharedStorageSize = sizeof(SharedStorage);
-
- struct Arguments {
- MainloopArguments mainloop{};
- EpilogueArguments epilogue{};
- cutlass::KernelHardwareInfo hw_info{};
- TileSchedulerArguments scheduler{};
- };
-
- struct Params {
- MainloopParams mainloop{};
- EpilogueParams epilogue{};
- cutlass::KernelHardwareInfo hw_info{};
- TileSchedulerParams scheduler{};
- };
-
-
-
-
- static
- Params
- to_underlying_arguments(Arguments const& args) {
- CUTLASS_TRACE_HOST("to_underlying_arguments():");
-
- int sm_count = args.hw_info.sm_count;
- if (sm_count <= 0) {
- CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
- " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
- sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
- }
- CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
- cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
- return {
- CollectiveMainloop::to_underlying_arguments(args.mainloop),
- CollectiveEpilogue::to_underlying_arguments(args.epilogue),
- hw_info,
- TileScheduler::to_underlying_arguments(args.scheduler)
- };
- }
-
- static dim3
- get_grid_shape(Params const& params) {
- return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count * MinBlocksPerMultiprocessor);
- }
- static dim3
- get_block_shape() {
- return dim3(MaxThreadsPerBlock, 1, 1);
- }
- CUTLASS_DEVICE
- void
- operator()(Params const& params, char* smem_buf) {
- static constexpr int kBlockM = get<0>(TileShape_MNK{});
- SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
- CollectiveMainloop collective_mainloop;
- CollectiveEpilogue collective_epilogue;
- TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.smem_scheduler));
-
- TiledMma tiled_mma;
- scheduler.init_consumer();
- int warp_idx = cutlass::canonical_warp_idx_sync();
- CUTLASS_PRAGMA_NO_UNROLL
- for (auto work_tile_info = warp_idx == 0 ? scheduler.template get_initial_work<true>(params.scheduler) : scheduler.template get_initial_work<false>(params.scheduler);
- work_tile_info.is_valid(params.scheduler);
- work_tile_info = warp_idx == 0 ? scheduler.template get_next_work<true>(params.scheduler, work_tile_info) : scheduler.template get_next_work<false>(params.scheduler, work_tile_info)) {
-
- Tensor tOrO = partition_fragment_C(tiled_mma, select<0, 2>(TileShape_MNK{}));
- float softmax_scale_log2 = params.mainloop.softmax_scale_log2;
-
- auto block_coord = work_tile_info.get_block_coord(params.scheduler);
- int const bidb = get<2>(block_coord);
- if constexpr (Is_FP8 && !Has_softcap) {
- int const bidh = get<1>(block_coord);
- int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh;
- float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)];
- float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)];
- softmax_scale_log2 *= q_descale * k_descale;
- }
- flash::Softmax<2 * (2 * kBlockM / NumThreads), !Is_FP8 ? 0 : 8> softmax(softmax_scale_log2);
- SeqlenInfo_t seqlen_info{
- bidb,
- get<0>(params.mainloop.shape_Q),
- !PagedKV ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),
- get<0>(params.mainloop.shape_K_new),
- params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
- params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
- };
- if constexpr (AppendKV) {
- bool tile_new_valid = collective_mainloop.store_kv_new(
- params.mainloop, threadIdx.x, shared_storage, seqlen_info, block_coord);
- if (tile_new_valid) { __syncthreads(); }
- }
- bool tile_valid = collective_mainloop.mma(
- params.mainloop, tOrO, softmax, threadIdx.x, seqlen_info, block_coord,
- shared_storage);
- scheduler.prefetch_next_work(params.scheduler, work_tile_info);
- if (tile_valid) {
-
- collective_epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma,
- threadIdx.x, block_coord);
- } else {
-
-
-
- collective_epilogue.template store_zero<!Split >(params.epilogue, threadIdx.x, block_coord);
- }
- }
- }
- };
- }
|