@@ -14,6 +14,7 @@
 
 #include "cutlass/gemm/collective/collective_builder.hpp"
 
+#include "named_barrier.hpp"
 #include "utils.h"
 
 namespace flash {
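The new `named_barrier.hpp` include replaces the hard-coded named-barrier IDs used below (`1` for the query-empty barrier, `3 + warp_group_idx` for the warp-scheduler barriers) with symbolic names. A minimal sketch of the kind of enum that header presumably provides follows; the numeric values are assumptions chosen so that `QueryEmpty` and `WarpSchedulerWG1 - 1 + warp_group_idx` reproduce the raw IDs this diff swaps out, and may not match the actual header.

```cpp
// Hypothetical sketch only -- the real definitions live in named_barrier.hpp.
// cutlass::arch::NamedBarrier exposes hardware barriers 0..15; the forward kernel
// refers to each one by name instead of a magic number.
enum class FwdNamedBarriers {
    QueryEmpty       = 1,  // consumer warpgroups signal that smem_q may be overwritten
    WarpSchedulerWG1 = 4,  // per-warpgroup scheduling barriers: WG1 -> 4, WG2 -> 5, WG3 -> 6
    WarpSchedulerWG2 = 5,
    WarpSchedulerWG3 = 6,
};
```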
@@ -108,6 +109,7 @@ struct CollectiveMainloopFwd {
     struct Params {
         ShapeQKV const shape_Q;
         ShapeQKV const shape_K;
+        cutlass::FastDivmod qhead_per_khead_divmod;
         TMA_Q tma_load_Q;
         TMA_KV tma_load_K, tma_load_V;
         float const softmax_scale_log2;
@@ -137,7 +139,10 @@ struct CollectiveMainloopFwd {
             SmemLayoutV{}(_, _, _0{}),
             select<1, 2>(TileShape_MNK{}),
             size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
-        return {args.shape_Q, args.shape_K, tma_load_Q, tma_load_K, tma_load_V, args.softmax_scale_log2};
+        return {args.shape_Q, args.shape_K,
+                cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))),
+                tma_load_Q, tma_load_K, tma_load_V,
+                args.softmax_scale_log2};
     }
 
     /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
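`qhead_per_khead_divmod` stores the GQA ratio `ceil(num_heads_Q / num_heads_KV)` as a `cutlass::FastDivmod`, so the producer can later map a query-head index to its KV head with a precomputed fast division rather than a runtime `/`. A small host-side sketch of that mapping (the head counts are made up for illustration):

```cpp
#include <cstdio>
#include "cutlass/fast_math.h"  // cutlass::FastDivmod

int main() {
    // Example GQA configuration: 32 query heads sharing 8 KV heads, ratio 4.
    int const num_heads_q = 32, num_heads_kv = 8;
    cutlass::FastDivmod qhead_per_khead_divmod((num_heads_q + num_heads_kv - 1) / num_heads_kv);

    // The same mapping the mainloop applies to bidh before indexing mK / mV.
    for (int bidh : {0, 3, 4, 31}) {
        int bidh_kv = qhead_per_khead_divmod.divide(bidh);  // bidh / 4
        std::printf("Q head %2d -> KV head %d\n", bidh, bidh_kv);
    }
    return 0;
}
```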
@@ -162,46 +167,21 @@ struct CollectiveMainloopFwd {
         return n_block_max;
     }
 
-    template <typename FullParams, typename SchedulerParams, typename SharedStorage, typename WorkTileInfo>
+    template <typename Scheduler, typename SharedStorage>
     CUTLASS_DEVICE void
-    load(FullParams const& params,
-         Params const& mainloop_params,
-         SchedulerParams const& scheduler_params,
+    load(Params const& mainloop_params,
          MainloopPipeline pipeline_k,
          MainloopPipeline pipeline_v,
          PipelineState& smem_pipe_write_k,
          PipelineState& smem_pipe_write_v,
          SharedStorage &shared_storage,
-         WorkTileInfo work_tile_info,
-         int& work_idx,
-         int& tile_count_semaphore
+         Scheduler& scheduler,
+         typename Scheduler::Params const& scheduler_params,
+         typename Scheduler::WorkTileInfo& work_tile_info,
+         cute::tuple<int32_t, int32_t, int32_t> block_coord,
+         int work_idx
          ) {
 
-        static constexpr int kBlockM = get<0>(TileShape_MNK{});
-        static constexpr int kBlockN = get<1>(TileShape_MNK{});
-
-        // int const m_block = work_tile_info.M_idx;
-        // int const bidh = work_tile_info.H_idx;
-        // int const bidb = work_tile_info.B_idx;
-        // int m_block;
-        // int bidh, bidb;
-        // bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_count_semaphore));
-        auto [m_block, bidh, bidb] = work_tile_info.get_block_coord(scheduler_params);
-        // if (threadIdx.x == 0) { printf("producer, blockIdx.x = %d, bidb = %d, bidh = %d, m_block = %d\n", blockIdx.x, bidb, bidh, m_block); }
-
-        int n_block_max = get_n_block_max(mainloop_params, m_block);
-        if (Is_causal && n_block_max <= 0) {
-            // Need sync to avoid the case where the producer issues 2 arrives before the consumer can issue 1 wait
-            cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, 7 /*id*/);
-            // if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
-            //     tile_count_semaphore = atomicAdd(params.tile_count_semaphore, 1);
-            //     shared_storage.tile_count_semaphore = tile_count_semaphore;
-            // }
-            // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 2 * cutlass::NumThreadsPerWarp, 10 /*id*/);
-            cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 10 /*id*/);
-            return;
-        }
-
         Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
         Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
         Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});
@@ -210,13 +190,16 @@ struct CollectiveMainloopFwd {
         Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.shape_K);
         Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.shape_K);
 
+        auto [m_block, bidh, bidb] = block_coord;
+        int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);
+
         // Prepare the TMA loads
         uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
         constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
         uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
         Tensor gQ = local_tile(mQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K)
-        Tensor gK = local_tile(mK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
-        Tensor gV = local_tile(mV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
+        Tensor gK = local_tile(mK(_, _, bidh_kv, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
+        Tensor gV = local_tile(mV(_, _, bidh_kv, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
 
         Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
         Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
@@ -235,6 +218,7 @@ struct CollectiveMainloopFwd {
             }
         }
 
+        int n_block_max = get_n_block_max(mainloop_params, m_block);
         int n_block = n_block_max - 1;
 
         int lane_predicate = cute::elect_one_sync();
@@ -246,7 +230,7 @@ struct CollectiveMainloopFwd {
         }
 
         // Wait for the MMA warpgroups to say that smem_q is ready
-        cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, 1 /*id*/);
+        cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
 
         if (lane_predicate) {
             shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
@@ -272,22 +256,14 @@ struct CollectiveMainloopFwd {
                 ++smem_pipe_write_v;
             }
         }
-        if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
-            // tile_count_semaphore = atomicAdd(params.tile_count_semaphore, 1);
-        }
+        scheduler.prefetch_next_work(scheduler_params, work_tile_info);
         if (lane_predicate) {
             pipeline_v.producer_acquire(smem_pipe_write_v);
             copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv),
                  tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
             ++smem_pipe_write_v;
         }
-        if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
-            // printf("blockIdx.x = %d, tile_count_semaphore: %d\n", blockIdx.x, tile_count_semaphore);
-            // shared_storage.tile_count_semaphore = tile_count_semaphore;
-        }
-        // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 2 * cutlass::NumThreadsPerWarp, 10 /*id*/);
-        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 10 /*id*/);
-        ++work_idx;
+        scheduler.broadcast_next_work(work_tile_info);
     }
 
     /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
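The inline `atomicAdd` on `tile_count_semaphore` and the hand-rolled named-barrier broadcast removed above are folded into two scheduler hooks, `prefetch_next_work` and `broadcast_next_work`, so the producer no longer cares whether scheduling is static or dynamic. Purely as a hedged sketch of what a dynamic scheduler could plug into those hooks (the struct and every member other than the two hook names are invented for illustration):

```cpp
// Hypothetical dynamic tile scheduler -- a sketch, not the scheduler this diff targets.
struct DynamicTileSchedulerSketch {
    struct Params { int* tile_count_semaphore; };  // global work counter (assumed)
    struct WorkTileInfo { int tile_idx; };

    CUTLASS_DEVICE void
    prefetch_next_work(Params const& params, WorkTileInfo& work_tile_info) const {
        // One lane grabs the next tile index early so the grab overlaps the TMA loads.
        if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
            work_tile_info.tile_idx = atomicAdd(params.tile_count_semaphore, 1);
        }
    }

    CUTLASS_DEVICE void
    broadcast_next_work(WorkTileInfo& work_tile_info) const {
        // Hand the prefetched tile index to the consumer warpgroups, e.g. through
        // shared memory plus a named-barrier arrive, as the removed inline code did.
    }
};
```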
@@ -307,36 +283,36 @@ struct CollectiveMainloopFwd {
     }
 
     CUTLASS_DEVICE void
-    scheduler_barrier_sync() {
+    warp_scheduler_barrier_sync() {
         if constexpr (UseSchedulerBarrier) {
-            cutlass::arch::NamedBarrier::sync(NumMmaThreads, 3 + cutlass::canonical_warp_group_idx() /*id*/);
+            cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + cutlass::canonical_warp_group_idx() /*id*/);
         }
     }
 
     CUTLASS_DEVICE void
-    scheduler_barrier_arrive() {
+    warp_scheduler_barrier_arrive() {
         if constexpr (!UseSchedulerBarrier) { return; }
         static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
         if constexpr (NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup) {
-            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 3 + (3 - cutlass::canonical_warp_group_idx()) /*id*/);
+            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (3 - cutlass::canonical_warp_group_idx()) /*id*/);
         } else {
-            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, cutlass::canonical_warp_group_idx() <= 2 ? 3 + cutlass::canonical_warp_group_idx() + 1 : 3 + cutlass::canonical_warp_group_idx() + 1 - 3 /*id*/);
-            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, cutlass::canonical_warp_group_idx() <= 1 ? 3 + cutlass::canonical_warp_group_idx() + 2 : 3 + cutlass::canonical_warp_group_idx() + 2 - 3 /*id*/);
+            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 2 ? cutlass::canonical_warp_group_idx() + 1 : cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/);
+            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 1 ? cutlass::canonical_warp_group_idx() + 2 : cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/);
         }
     }
 
     CUTLASS_DEVICE void
     mma_init() {
         // Tell producer (warp 0) that smem_q is ready
-        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 1 /*id*/);
+        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
         if constexpr (!UseSchedulerBarrier) { return; }
         static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
         if (cutlass::canonical_warp_group_idx() > 1) {
-            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 3 + 1 /*id*/);
+            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 1 /*id*/);
         }
         if constexpr (NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup) {
             if (cutlass::canonical_warp_group_idx() > 2) {
-                cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 3 + 2 /*id*/);
+                cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 2 /*id*/);
             }
         }
 
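The ternary expressions above implement a round-robin hand-off among the MMA warpgroups: each warpgroup waits on its own barrier (`WarpSchedulerWG1 - 1 + warp_group_idx`) and, once it has issued its GEMM, arrives on the barriers of the other warpgroup(s). The snippet below only re-evaluates that arithmetic for the three-warpgroup case (`NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup`) to make the pairing explicit; `base` stands in for `static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1`, assumed here to be 3 so the IDs match the old hard-coded `3 + warp_group_idx`.

```cpp
#include <cstdio>

// Reproduces the barrier-id arithmetic of warp_scheduler_barrier_sync/_arrive.
// canonical_warp_group_idx() is 1, 2 or 3 for the MMA warpgroups (0 is the producer).
int sync_id(int base, int wg)          { return base + wg; }
int first_arrive_id(int base, int wg)  { return base + (wg <= 2 ? wg + 1 : wg + 1 - 3); }
int second_arrive_id(int base, int wg) { return base + (wg <= 1 ? wg + 2 : wg + 2 - 3); }

int main() {
    int const base = 3;  // assumption: WarpSchedulerWG1 - 1 == 3
    for (int wg = 1; wg <= 3; ++wg) {
        std::printf("WG%d waits on %d, arrives on %d and %d\n",
                    wg, sync_id(base, wg), first_arrive_id(base, wg), second_arrive_id(base, wg));
    }
    // WG1 waits on 4, arrives on 5 and 6; WG2 waits on 5, arrives on 6 and 4;
    // WG3 waits on 6, arrives on 4 and 5 -- each warpgroup releases the other two.
}
```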
@@ -393,9 +369,9 @@ struct CollectiveMainloopFwd {
 
         Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
         consumer_wait(pipeline_k, smem_pipe_read_k);
-        scheduler_barrier_sync();
+        warp_scheduler_barrier_sync();
         flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
-        scheduler_barrier_arrive();
+        warp_scheduler_barrier_arrive();
         if (work_idx != 0) {
             int lane_predicate = cute::elect_one_sync();
             if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) {
@@ -443,12 +419,12 @@ struct CollectiveMainloopFwd {
         for (int masking_step = 0; masking_step < n_masking_steps - 1 && n_block > 0; ++masking_step, --n_block) {
             Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
             consumer_wait(pipeline_k, smem_pipe_read_k);
-            scheduler_barrier_sync();
+            warp_scheduler_barrier_sync();
             flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
             if (masking_step > 0) { softmax.rescale_o(tOrO, scores_scale); }
             consumer_wait(pipeline_v, smem_pipe_read_v);
             flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
-            scheduler_barrier_arrive();
+            warp_scheduler_barrier_arrive();
             warpgroup_wait<1>();
             pipeline_k.consumer_release(smem_pipe_read_k); // release K
             Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
@@ -472,12 +448,12 @@ struct CollectiveMainloopFwd {
         for (; n_block > 0; --n_block) {
             Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
             consumer_wait(pipeline_k, smem_pipe_read_k);
-            scheduler_barrier_sync();
+            warp_scheduler_barrier_sync();
             flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
             softmax.rescale_o(tOrO, scores_scale);
             consumer_wait(pipeline_v, smem_pipe_read_v);
             flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
-            scheduler_barrier_arrive();
+            warp_scheduler_barrier_arrive();
             warpgroup_wait<1>();
             pipeline_k.consumer_release(smem_pipe_read_k); // release K
             // auto scores_scale = softmax.template max</*Is_first=*/false>(tSrS);
@@ -491,7 +467,7 @@ struct CollectiveMainloopFwd {
             cute::copy(make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout())), tOrP);
         }
         // Tell warp 0 that smem_q is ready
-        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 1 /*id*/);
+        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
         softmax.rescale_o(tOrO, scores_scale);
         consumer_wait(pipeline_v, smem_pipe_read_v);
         flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);