Forráskód Böngészése

[FA3] BF16 forward

Tri Dao 8 hónapja
szülő
commit
74b0761ff7

+ 1 - 1
csrc/cutlass

@@ -1 +1 @@
-Subproject commit fa4f6359069bd4dd6fabd0cda2476dd8e72b3837
+Subproject commit 756c351b4994854b2f8c6dded3821ebbb580876b

+ 2 - 1
hopper/epilogue_fwd_sm90_tma.hpp

@@ -9,6 +9,7 @@
 
 #include "cutlass/gemm/collective/collective_builder.hpp"
 
+#include "named_barrier.hpp"
 #include "utils.h"
 
 namespace flash {
@@ -127,7 +128,7 @@ struct CollectiveEpilogueFwd {
         Tensor taccOsO = smem_thr_copy_O.partition_D(sO);     // ((Atom,AtomNum),PIPE_M,PIPE_N)
 
         // Make sure all WGs have finished reading V
-        cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0 /*id*/);
+        cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(FwdNamedBarriers::ValueEmpty) /*id*/);
         cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
         cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
         cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp,

+ 0 - 2
hopper/flash.h

@@ -66,8 +66,6 @@ struct Flash_fwd_params : public Qkv_params {
 
     // The dimensions.
     int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
-    cutlass::FastDivmod head_divmod, m_block_divmod;
-    int total_blocks;
 
     // The scaling factors for the kernel.
     float scale_softmax;

+ 19 - 13
hopper/flash_api.cpp

@@ -99,8 +99,6 @@ void set_params_fprop(Flash_fwd_params &params,
     params.d = d;
     params.d_rounded = d_rounded;
 
-    params.head_divmod = cutlass::FastDivmod(int(h));
-
     // Set the different scale values.
     params.scale_softmax = softmax_scale;
     params.scale_softmax_log2 = softmax_scale * M_LOG2E;
@@ -225,12 +223,22 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split
     //     run_mha_fwd_<cutlass::half_t, kHeadSize>(params, stream);
     // });
     if (!params.is_e4m3) {
-        if (params.d == 64) {
-            run_mha_fwd_<cutlass::half_t, 64>(params, stream);
-        } else if (params.d == 128) {
-            run_mha_fwd_<cutlass::half_t, 128>(params, stream);
+        if (params.is_bf16) {
+            if (params.d == 64) {
+                run_mha_fwd_<cutlass::bfloat16_t, 64>(params, stream);
+            } else if (params.d == 128) {
+                run_mha_fwd_<cutlass::bfloat16_t, 128>(params, stream);
+            } else {
+                run_mha_fwd_<cutlass::bfloat16_t, 256>(params, stream);
+            }
         } else {
-            run_mha_fwd_<cutlass::half_t, 256>(params, stream);
+            if (params.d == 64) {
+                run_mha_fwd_<cutlass::half_t, 64>(params, stream);
+            } else if (params.d == 128) {
+                run_mha_fwd_<cutlass::half_t, 128>(params, stream);
+            } else {
+                run_mha_fwd_<cutlass::half_t, 256>(params, stream);
+            }
         }
     } else {
         // run_mha_fwd_<cutlass::float_e4m3_t, 128>(params, stream);
@@ -250,9 +258,8 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
     TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer.");
 
     auto q_dtype = q.dtype();
-    // TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
-    TORCH_CHECK(q_dtype == torch::kFloat16,
-                "FlashAttention only support fp16 data type for now");
+    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
+                "FlashAttention only support fp16 and bf16 data type for now");
     // TODO: will add e4m3 later
     // TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kFloat8_e4m3fn,
                 // "FlashAttention only support fp16 and bf16 data type");
@@ -278,10 +285,9 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
     const int head_size_og = sizes[3];
     const int seqlen_k = k.size(1);
     const int num_heads_k = k.size(2);
-    TORCH_CHECK(batch_size > 0, "batch size must be postive");
+    TORCH_CHECK(batch_size > 0, "batch size must be positive");
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
-    TORCH_CHECK(num_heads == num_heads_k, "We do not support MQA/GQA yet");
 
     TORCH_CHECK(head_size_og == 64 || head_size_og == 128 || head_size_og == 256, "Only support head size 64, 128, and 256 for now");
 
@@ -345,7 +351,7 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
                      /*window_size_left=*/-1,
                      /*window_size_right=*/is_causal ? 0 : -1);
 
-    auto tile_count_semaphore = is_causal ? torch::full({1}, 132, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
+    auto tile_count_semaphore = is_causal ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
     params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
 
     if (seqlen_k > 0) {

+ 9 - 0
hopper/flash_fwd_hdim128_bf16_sm90.cu

@@ -0,0 +1,9 @@
+// Copyright (c) 2024, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
+}

+ 9 - 0
hopper/flash_fwd_hdim256_bf16_sm90.cu

@@ -0,0 +1,9 @@
+// Copyright (c) 2024, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
+}

+ 9 - 0
hopper/flash_fwd_hdim64_bf16_sm90.cu

@@ -0,0 +1,9 @@
+// Copyright (c) 2024, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream);
+}

+ 24 - 40
hopper/flash_fwd_kernel.h

@@ -26,8 +26,7 @@ using namespace cute;
 
 template <typename Ktraits, bool Is_causal, typename TileScheduler>
 __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
-    compute_attn_ws(CUTE_GRID_CONSTANT Flash_fwd_params const params,
-                    CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits, Is_causal>::Params const mainloop_params,
+    compute_attn_ws(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits, Is_causal>::Params const mainloop_params,
                     CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd<Ktraits>::Params const epilogue_params,
                     CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params
                     ) {
@@ -101,9 +100,6 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
     if (warp_group_idx == 0) {  // Producer
         cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 12 ? 24 : 32>();
         // cutlass::arch::warpgroup_reg_dealloc<56>();
-        // StaticPersistentTileScheduler scheduler{params.m_block_divmod, params.head_divmod, params.total_blocks};
-        // auto work_tile_info = scheduler.get_current_work();
-        TileScheduler scheduler;
 
         int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
         if (warp_idx_in_warpgroup == 0) {  // Load Q, K, V
@@ -112,20 +108,22 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
 
             int work_idx = 0;
 
-            // auto get_tile_count = [&] () {
-            //     cutlass::arch::NamedBarrier::sync(NumMmaThreads + 2 * cutlass::NumThreadsPerWarp, 10 /*id*/);
-            //     return shared_storage.tile_count_semaphore;
-            // };
-
-            // while (work_tile_info.is_valid()) {
-            // for (int tile_count = blockIdx.x; tile_count < params.total_blocks; tile_count = get_tile_count()) {
-            // for (int tile_count_semaphore = blockIdx.x; tile_count_semaphore < params.total_blocks; tile_count_semaphore = __shfl_sync(0xffffffff, tile_count_semaphore, 0)) {
-            for (auto work_tile_info = scheduler.get_initial_work(); work_tile_info.is_valid(scheduler_params); work_tile_info = scheduler.get_next_work(scheduler_params, work_tile_info)) {
-                int tile_count_semaphore = 0;
-                collective_mainloop.load(params, mainloop_params, scheduler_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v,
-                                         shared_storage, work_tile_info, work_idx, tile_count_semaphore);
-                // ++work_idx;
-                // work_tile_info = scheduler.fetch_next_work();
+            TileScheduler scheduler(&shared_storage.tile_count_semaphore);
+            for (auto work_tile_info = scheduler.get_initial_work();
+                 work_tile_info.is_valid(scheduler_params);
+                 work_tile_info = scheduler.template get_next_work</*IsProducer=*/true>(scheduler_params, work_tile_info)) {
+                auto block_coord = work_tile_info.get_block_coord(scheduler_params);
+                auto [m_block, bidh, bidb] = block_coord;
+
+                int n_block_max = collective_mainloop.get_n_block_max(mainloop_params, m_block);
+                if (Is_causal && n_block_max <= 0) {
+                    scheduler.prefetch_next_work(scheduler_params, work_tile_info);
+                    scheduler.broadcast_next_work(work_tile_info);
+                    continue;
+                }
+                collective_mainloop.load(mainloop_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v,
+                                         shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx);
+                ++work_idx;
             }
             collective_mainloop.load_tail(pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v);
         }
@@ -133,44 +131,31 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
         cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 12 ? 240 : 160>();
         // cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 12 ? 224 : 160>();
 
+        TileScheduler scheduler(&shared_storage.tile_count_semaphore);
         // Initialize matmul objects.
         typename Ktraits::TiledMma1 tiled_mma1;
 
-        TileScheduler scheduler{};
-
         PipelineState smem_pipe_read_k, smem_pipe_read_v;
-        // We don't need separate variables smem_pip_release_k and smem_pipe_release_v
+        // We don't need separate variables smem_pipe_release_k and smem_pipe_release_v
         // (like in Cutlass's gemm) because the read and release pipeline states are always the same.
 
-        auto get_tile_count = [&] () {
-            // cutlass::arch::NamedBarrier::sync(NumMmaThreads + 2 * cutlass::NumThreadsPerWarp, 10 /*id*/);
-            cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, 10 /*id*/);
-            return shared_storage.tile_count_semaphore;
-        };
-
         collective_mainloop.mma_init();
+        scheduler.init_consumer();
 
         int work_idx = 0;
         CUTLASS_PRAGMA_NO_UNROLL
-        // for (int work_idx = 0; work_idx * gridDim.x + blockIdx.x < params.total_blocks; ++work_idx) {
-        // for (int tile_count_semaphore = blockIdx.x, work_idx = 0; tile_count_semaphore < params.total_blocks; tile_count_semaphore = get_tile_count()) {
-        for (auto work_tile_info = scheduler.get_initial_work(); work_tile_info.is_valid(scheduler_params); work_tile_info = scheduler.get_next_work(scheduler_params, work_tile_info)) {
+        for (auto work_tile_info = scheduler.get_initial_work();
+             work_tile_info.is_valid(scheduler_params);
+             work_tile_info = scheduler.template get_next_work</*IsProducer=*/false>(scheduler_params, work_tile_info)) {
             // Attention output (GEMM-II) accumulator.
             Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{}));
             flash::Softmax<2 * (2 * kBlockM / NumMmaThreads)> softmax;
 
-            // int m_block;
-            // int bidh, bidb;
-            // // bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, work_idx * gridDim.x + blockIdx.x));
-            // bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_count_semaphore));
-            // cute::tuple<int32_t, int32_t, int32_t> block_coord = {m_block, bidh, bidb};
             auto block_coord = work_tile_info.get_block_coord(scheduler_params);
             auto [m_block, bidh, bidb] = block_coord;
 
             int n_block_max = collective_mainloop.get_n_block_max(mainloop_params, m_block);
             if (Is_causal && n_block_max <= 0) {  // We exit early and write 0 to gO and -inf to gLSE.
-                // Need sync to avoid the case where the producer issues 2 arrives before the consumer can issue 1 wait
-                cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 7 /*id*/);
                 collective_epilogue.store_zero(epilogue_params, threadIdx.x - NumCopyThreads, block_coord);
                 continue;
             }
@@ -178,15 +163,14 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
             collective_mainloop.mma(mainloop_params, pipeline_k, pipeline_v, smem_pipe_read_k, smem_pipe_read_v,
                                     tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads, work_idx, m_block, shared_storage);
                                     // tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads + (work_idx >> 30), work_idx, shared_storage);
-                                    // tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads, 0, shared_storage);
             collective_epilogue.store(epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1,
                                       threadIdx.x - NumCopyThreads, block_coord);
 
             ++work_idx;
-            // work_tile_info = scheduler.fetch_next_work();
         }
         collective_epilogue.store_tail();
     }
+
 }
 
 } // namespace flash

+ 27 - 15
hopper/flash_fwd_launch_template.h

@@ -8,6 +8,7 @@
 
 #include "cute/tensor.hpp"
 
+#include "cutlass/cutlass.h"
 #include "cutlass/cluster_launch.hpp"
 
 #include "static_switch.h"
@@ -26,8 +27,10 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     // print(typename Kernel_traits::SmemLayoutVt{}); printf("\n"); print(typename Kernel_traits::SmemLayoutVt_tmp{});
     using CollectiveMainloop = flash::CollectiveMainloopFwd<Kernel_traits, Is_causal>;
     using CollectiveEpilogue = flash::CollectiveEpilogueFwd<Kernel_traits>;
-    // using Scheduler = flash::SingleTileScheduler;
-    using Scheduler = flash::StaticPersistentTileScheduler;
+    using Scheduler = std::conditional_t<!Is_causal,
+        flash::StaticPersistentTileScheduler,
+        flash::DynamicPersistentTileScheduler<Kernel_traits::kNThreads - cutlass::NumThreadsPerWarpGroup>>;
+        // flash::SingleTileScheduler>;
     typename CollectiveMainloop::Params mainloop_params =
         CollectiveMainloop::to_underlying_arguments({
             static_cast<Element const*>(params.q_ptr),
@@ -51,32 +54,35 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
 
     int num_blocks_m = cutlass::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
     num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{});
-    typename Scheduler::Arguments scheduler_args = {num_blocks_m, params.h, params.b};
+    typename Scheduler::Arguments scheduler_args = {num_blocks_m, params.h, params.b, params.tile_count_semaphore};
     typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args);
 
     // Get the ptr to kernel function.
     void *kernel;
     kernel = (void *)flash::compute_attn_ws<Kernel_traits, Is_causal, Scheduler>;
     int smem_size = sizeof(typename Kernel_traits::SharedStorage);
-    int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q));
-    int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k));
-    int smem_size_v = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_v));
+    // int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q));
+    // int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k));
+    // int smem_size_v = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_v));
     // printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v);
     if (smem_size >= 48 * 1024) {
        C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
     }
 
+    int device;
+    cudaGetDevice(&device);
+    int multiprocessor_count;
+    cudaError status_ = cudaDeviceGetAttribute(
+        &multiprocessor_count, cudaDevAttrMultiProcessorCount, device);
+    if (status_ != cudaSuccess) {
+      C10_CUDA_CHECK(status_);
+    }
+    dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count);
     static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
-    params.m_block_divmod = cutlass::FastDivmod(num_blocks_m);
-    params.total_blocks = num_blocks_m * params.h * params.b;
-    // dim3 grid_dims(num_blocks_m, params.h, params.b);
-    // dim3 grid_dims(132);
-    dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, 132);
     dim3 block_dims(ctaSize);
     dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
     cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
-    cutlass::launch_kernel_on_cluster(launch_params, kernel, params, mainloop_params, epilogue_params, scheduler_params);
-    // kernel<<<grid_dims, block_dims, smem_size, stream>>>(params, tma_load_Q, tma_load_K, tma_load_V, tma_store_O);
+    cutlass::launch_kernel_on_cluster(launch_params, kernel, mainloop_params, epilogue_params, scheduler_params);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
 
@@ -92,7 +98,10 @@ template<typename T>
 void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 128;
     BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, Is_causal ? 128 : 176, 12, 2, false, !Is_causal ? 2 : 1, T>, Is_causal>(params, stream);
+        // Only use Cluster if number of tiles along seqlen_q is even
+        BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0, UseCluster, [&] {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, Is_causal ? 128 : 176, 12, 2, false, !Is_causal && UseCluster ? 2 : 1, T>, Is_causal>(params, stream);
+        });
     });
 }
 
@@ -100,6 +109,9 @@ template<typename T>
 void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 256;
     BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 80, 12, 2, false, !Is_causal ? 2 : 1, T>, Is_causal>(params, stream);
+        // Only use Cluster if number of tiles along seqlen_q is even
+        BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0, UseCluster, [&] {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 80, 12, 2, false, !Is_causal && UseCluster ? 2 : 1, T>, Is_causal>(params, stream);
+        });
     });
 }

+ 38 - 62
hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp

@@ -14,6 +14,7 @@
 
 #include "cutlass/gemm/collective/collective_builder.hpp"
 
+#include "named_barrier.hpp"
 #include "utils.h"
 
 namespace flash {
@@ -108,6 +109,7 @@ struct CollectiveMainloopFwd {
     struct Params {
         ShapeQKV const shape_Q;
         ShapeQKV const shape_K;
+        cutlass::FastDivmod qhead_per_khead_divmod;
         TMA_Q tma_load_Q;
         TMA_KV tma_load_K, tma_load_V;
         float const softmax_scale_log2;
@@ -137,7 +139,10 @@ struct CollectiveMainloopFwd {
             SmemLayoutV{}(_, _, _0{}),
             select<1, 2>(TileShape_MNK{}),
             size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
-        return {args.shape_Q, args.shape_K, tma_load_Q, tma_load_K, tma_load_V, args.softmax_scale_log2};
+        return {args.shape_Q, args.shape_K,
+                cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))),
+                tma_load_Q, tma_load_K, tma_load_V,
+                args.softmax_scale_log2};
     }
 
     /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
@@ -162,46 +167,21 @@ struct CollectiveMainloopFwd {
         return n_block_max;
     }
 
-    template <typename FullParams, typename SchedulerParams, typename SharedStorage, typename WorkTileInfo>
+    template <typename Scheduler, typename SharedStorage>
     CUTLASS_DEVICE void
-    load(FullParams const& params,
-         Params const& mainloop_params,
-         SchedulerParams const& scheduler_params,
+    load(Params const& mainloop_params,
          MainloopPipeline pipeline_k,
          MainloopPipeline pipeline_v,
          PipelineState& smem_pipe_write_k,
          PipelineState& smem_pipe_write_v,
          SharedStorage &shared_storage,
-         WorkTileInfo work_tile_info,
-         int& work_idx,
-         int& tile_count_semaphore
+         Scheduler& scheduler,
+         typename Scheduler::Params const& scheduler_params,
+         typename Scheduler::WorkTileInfo& work_tile_info,
+         cute::tuple<int32_t, int32_t, int32_t> block_coord,
+         int work_idx
          ) {
 
-        static constexpr int kBlockM = get<0>(TileShape_MNK{});
-        static constexpr int kBlockN = get<1>(TileShape_MNK{});
-
-        // int const m_block = work_tile_info.M_idx;
-        // int const bidh = work_tile_info.H_idx;
-        // int const bidb = work_tile_info.B_idx;
-        // int m_block;
-        // int bidh, bidb;
-        // bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_count_semaphore));
-        auto [m_block, bidh, bidb] = work_tile_info.get_block_coord(scheduler_params);
-        // if (threadIdx.x == 0) { printf("producer, blockIdx.x = %d, bidb = %d, bidh = %d, m_block = %d\n", blockIdx.x, bidb, bidh, m_block); }
-
-        int n_block_max = get_n_block_max(mainloop_params, m_block);
-        if (Is_causal && n_block_max <= 0) {
-            // Need sync to avoid the case where the producer issues 2 arrives before the consumer can issue 1 wait
-            cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, 7 /*id*/);
-            // if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
-            //     tile_count_semaphore = atomicAdd(params.tile_count_semaphore, 1);
-            //     shared_storage.tile_count_semaphore = tile_count_semaphore;
-            // }
-            // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 2 * cutlass::NumThreadsPerWarp, 10 /*id*/);
-            cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 10 /*id*/);
-            return;
-        }
-
         Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
         Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
         Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});
@@ -210,13 +190,16 @@ struct CollectiveMainloopFwd {
         Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.shape_K);
         Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.shape_K);
 
+        auto [m_block, bidh, bidb] = block_coord;
+        int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);
+
         // Prepare the TMA loads
         uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
         constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
         uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
         Tensor gQ = local_tile(mQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{}));  // (M, K)
-        Tensor gK = local_tile(mK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (N, K, _)
-        Tensor gV = local_tile(mV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (N, K, _)
+        Tensor gK = local_tile(mK(_, _, bidh_kv, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (N, K, _)
+        Tensor gV = local_tile(mV(_, _, bidh_kv, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (N, K, _)
 
         Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
         Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
@@ -235,6 +218,7 @@ struct CollectiveMainloopFwd {
             }
         }
 
+        int n_block_max = get_n_block_max(mainloop_params, m_block);
         int n_block = n_block_max - 1;
 
         int lane_predicate = cute::elect_one_sync();
@@ -246,7 +230,7 @@ struct CollectiveMainloopFwd {
         }
 
         // Wait for the MMA warpgroups to say that smem_q is ready
-        cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, 1 /*id*/);
+        cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
 
         if (lane_predicate) {
             shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
@@ -272,22 +256,14 @@ struct CollectiveMainloopFwd {
                 ++smem_pipe_write_v;
             }
         }
-        if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
-            // tile_count_semaphore = atomicAdd(params.tile_count_semaphore, 1);
-        }
+        scheduler.prefetch_next_work(scheduler_params, work_tile_info);
         if (lane_predicate) {
             pipeline_v.producer_acquire(smem_pipe_write_v);
             copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv),
                 tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
             ++smem_pipe_write_v;
         }
-        if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
-            // printf("blockIdx.x = %d, tile_count_semaphore: %d\n", blockIdx.x, tile_count_semaphore);
-            // shared_storage.tile_count_semaphore = tile_count_semaphore;
-        }
-        // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 2 * cutlass::NumThreadsPerWarp, 10 /*id*/);
-        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 10 /*id*/);
-        ++work_idx;
+        scheduler.broadcast_next_work(work_tile_info);
     }
 
     /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
@@ -307,36 +283,36 @@ struct CollectiveMainloopFwd {
     }
 
     CUTLASS_DEVICE void
-    scheduler_barrier_sync() {
+    warp_scheduler_barrier_sync() {
         if constexpr (UseSchedulerBarrier) {
-            cutlass::arch::NamedBarrier::sync(NumMmaThreads, 3 + cutlass::canonical_warp_group_idx() /*id*/);
+            cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + cutlass::canonical_warp_group_idx() /*id*/);
         }
     }
 
     CUTLASS_DEVICE void
-    scheduler_barrier_arrive() {
+    warp_scheduler_barrier_arrive() {
         if constexpr (!UseSchedulerBarrier) { return; }
         static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
         if constexpr (NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup) {
-            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 3 + (3 - cutlass::canonical_warp_group_idx()) /*id*/);
+            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (3 - cutlass::canonical_warp_group_idx()) /*id*/);
         } else {
-            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, cutlass::canonical_warp_group_idx() <= 2 ? 3 + cutlass::canonical_warp_group_idx() + 1 : 3 + cutlass::canonical_warp_group_idx() + 1 - 3  /*id*/);
-            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, cutlass::canonical_warp_group_idx() <= 1 ? 3 + cutlass::canonical_warp_group_idx() + 2 : 3 + cutlass::canonical_warp_group_idx() + 2 - 3  /*id*/);
+            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 2 ? cutlass::canonical_warp_group_idx() + 1 : cutlass::canonical_warp_group_idx() + 1 - 3)  /*id*/);
+            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 1 ? cutlass::canonical_warp_group_idx() + 2 : cutlass::canonical_warp_group_idx() + 2 - 3)  /*id*/);
         }
     }
 
     CUTLASS_DEVICE void
     mma_init() {
         // Tell producer (warp 0) that smem_q is ready
-        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 1 /*id*/);
+        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
         if constexpr (!UseSchedulerBarrier) { return; }
         static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
         if (cutlass::canonical_warp_group_idx() > 1) {
-            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 3 + 1 /*id*/);
+            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 1 /*id*/);
         }
         if constexpr (NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup) {
             if (cutlass::canonical_warp_group_idx() > 2) {
-                cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 3 + 2 /*id*/);
+                cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 2 /*id*/);
             }
         }
 
@@ -393,9 +369,9 @@ struct CollectiveMainloopFwd {
 
         Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
         consumer_wait(pipeline_k, smem_pipe_read_k);
-        scheduler_barrier_sync();
+        warp_scheduler_barrier_sync();
         flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
-        scheduler_barrier_arrive();
+        warp_scheduler_barrier_arrive();
         if (work_idx != 0) {
             int lane_predicate = cute::elect_one_sync();
             if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) {
@@ -443,12 +419,12 @@ struct CollectiveMainloopFwd {
         for (int masking_step = 0; masking_step < n_masking_steps - 1 && n_block > 0; ++masking_step, --n_block) {
             Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
             consumer_wait(pipeline_k, smem_pipe_read_k);
-            scheduler_barrier_sync();
+            warp_scheduler_barrier_sync();
             flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
             if (masking_step > 0) { softmax.rescale_o(tOrO, scores_scale); }
             consumer_wait(pipeline_v, smem_pipe_read_v);
             flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
-            scheduler_barrier_arrive();
+            warp_scheduler_barrier_arrive();
             warpgroup_wait<1>();
             pipeline_k.consumer_release(smem_pipe_read_k);  // release K
             Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
@@ -472,12 +448,12 @@ struct CollectiveMainloopFwd {
         for (; n_block > 0; --n_block) {
             Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
             consumer_wait(pipeline_k, smem_pipe_read_k);
-            scheduler_barrier_sync();
+            warp_scheduler_barrier_sync();
             flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
             softmax.rescale_o(tOrO, scores_scale);
             consumer_wait(pipeline_v, smem_pipe_read_v);
             flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
-            scheduler_barrier_arrive();
+            warp_scheduler_barrier_arrive();
             warpgroup_wait<1>();
             pipeline_k.consumer_release(smem_pipe_read_k);  // release K
             // auto scores_scale = softmax.template max</*Is_first=*/false>(tSrS);
@@ -491,7 +467,7 @@ struct CollectiveMainloopFwd {
             cute::copy(make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout())), tOrP);
         }
         // Tell warp 0 that smem_q is ready
-        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 1 /*id*/);
+        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
         softmax.rescale_o(tOrO, scores_scale);
         consumer_wait(pipeline_v, smem_pipe_read_v);
         flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);

+ 23 - 0
hopper/named_barrier.hpp

@@ -0,0 +1,23 @@
+/******************************************************************************
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+#include "cutlass/arch/barrier.h"
+
+namespace flash {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// Enumerates the reserved named barriers to avoid potential conflicts
+enum class FwdNamedBarriers {
+    QueryEmpty = 0,
+    ValueEmpty = 1,
+    TileCountSmemEmpty = 2,
+    TileCountSmemFull = 3,
+    WarpSchedulerWG1 = 4,
+    WarpSchedulerWG2 = 5,
+    WarpSchedulerWG3 = 6,
+};
+
+} // flash

+ 3 - 0
hopper/setup.py

@@ -111,8 +111,11 @@ if not SKIP_CUDA_BUILD:
     sources = [
         "flash_api.cpp",
         "flash_fwd_hdim64_fp16_sm90.cu",
+        "flash_fwd_hdim64_bf16_sm90.cu",
         "flash_fwd_hdim128_fp16_sm90.cu",
+        "flash_fwd_hdim128_bf16_sm90.cu",
         "flash_fwd_hdim256_fp16_sm90.cu",
+        "flash_fwd_hdim256_bf16_sm90.cu",
         "flash_bwd_hdim64_fp16_sm90.cu",
         "flash_bwd_hdim128_fp16_sm90.cu",
         "flash_bwd_hdim256_fp16_sm90.cu",

+ 17 - 17
hopper/test_flash_attn.py

@@ -131,15 +131,18 @@ def attention_ref(
 
 
 
-@pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+# @pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
+# @pytest.mark.parametrize("mha_type", ["gqa"])
 @pytest.mark.parametrize("causal", [False, True])
-# @pytest.mark.parametrize("causal", [False])
+# @pytest.mark.parametrize("causal", [True])
 # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [56, 80])
 @pytest.mark.parametrize("d", [64, 128, 256])
-# @pytest.mark.parametrize("d", [128])
+# @pytest.mark.parametrize("d", [256])
 @pytest.mark.parametrize(
     "seqlen_q,seqlen_k",
     [
@@ -151,6 +154,8 @@ def attention_ref(
         (113, 211),
         (108, 256),
         (256, 512),
+        (384, 256),
+        (640, 128),
         (512, 256),
         (1024, 1024),
         (1023, 1024),
@@ -160,7 +165,7 @@ def attention_ref(
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
 def test_flash_attn_output(
-    seqlen_q, seqlen_k, d, causal, dtype
+    seqlen_q, seqlen_k, d, causal, mha_type, dtype
 ):
     device = "cuda"
     # set seed
@@ -168,16 +173,13 @@ def test_flash_attn_output(
     # batch_size = 40
     # nheads = 16
     batch_size = 9
-    nheads = 4
+    nheads = 6
+    nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
     # batch_size = 1
     # nheads = 1
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
-    k = torch.randn(
-        batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
-    )
-    v = torch.randn(
-        batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
-    )
+    k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
+    v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
     out, lse = flash_attn_func(q, k, v, causal=causal)
     out_ref, attn_ref = attention_ref(
         q,
@@ -202,15 +204,15 @@ def test_flash_attn_output(
     # m = qk.amax(-1, keepdim=True)
     # s_tmp = torch.exp((qk - m) / math.sqrt(d))
     # exp_sum = s_tmp.sum(-1)
-    qk = torch.einsum('bthd,bshd->bhts', q.float() / math.sqrt(d), k.float())
-    lse_ref = torch.logsumexp(qk, dim=-1)
+    # qk = torch.einsum('bthd,bshd->bhts', q.float() / math.sqrt(d), k.float())
+    # lse_ref = torch.logsumexp(qk, dim=-1)
 
     print(f"Output max diff: {(out - out_ref).abs().max().item()}")
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
     print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
     print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
-    if not causal:
-        print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
+    # if not causal:
+    #     print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
     # breakpoint()
 
     # if d <= 128:
@@ -248,5 +250,3 @@ def test_flash_attn_output(
     #     assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
     #     assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
     #     assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
-
-

+ 98 - 119
hopper/tile_scheduler.hpp

@@ -1,112 +1,26 @@
 /******************************************************************************
- * Copyright (c) 2024, Tri Dao.
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  ******************************************************************************/
 
 #pragma once
 
 #include "cutlass/fast_math.h"
+#include "cutlass/arch/barrier.h"
 
-namespace flash {
-
-///////////////////////////////////////////////////////////////////////////////
-
-class StaticPersistentTileSchedulerOld {
-  //
-  // Data members
-  //
-
-private:
-  int current_work_linear_idx_;
-  cutlass::FastDivmod const &m_block_divmod, &head_divmod;
-  int const total_blocks;
+#include "named_barrier.hpp"
 
-public:
-  struct WorkTileInfo {
-    int M_idx = 0;
-    int H_idx = 0;
-    int B_idx = 0;
-    bool is_valid_tile = false;
-
-    CUTLASS_HOST_DEVICE
-    bool
-    is_valid() const {
-      return is_valid_tile;
-    }
-
-    CUTLASS_HOST_DEVICE
-    static WorkTileInfo
-    invalid_work_tile() {
-      return {-1, -1, -1, false};
-    }
-
-  };
-
-public:
-
-  CUTLASS_DEVICE explicit StaticPersistentTileSchedulerOld(cutlass::FastDivmod const &m_block_divmod_,
-                                                        cutlass::FastDivmod const &head_divmod_,
-                                                        int const total_blocks_) :
-    m_block_divmod(m_block_divmod_), head_divmod(head_divmod_), total_blocks(total_blocks_) {
-
-    // MSVC requires protecting use of CUDA-specific nonstandard syntax,
-    // like blockIdx and gridDim, with __CUDA_ARCH__.
-#if defined(__CUDA_ARCH__)
-    // current_work_linear_idx_ = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
-    current_work_linear_idx_ = blockIdx.x;
-#else
-    CUTLASS_ASSERT(false && "This line should never be reached");
-#endif
-  }
-
-  CUTLASS_DEVICE
-  WorkTileInfo
-  get_current_work() const {
-    return get_current_work_for_linear_idx(current_work_linear_idx_);
-  }
-
-  CUTLASS_DEVICE
-  WorkTileInfo
-  get_current_work_for_linear_idx(int linear_idx) const {
-    if (linear_idx >= total_blocks) {
-      return WorkTileInfo::invalid_work_tile();
-    }
-
-    // Map worker's linear index into the CTA tiled problem shape to the corresponding MHB indices
-    int M_idx, H_idx, B_idx;
-    int quotient = m_block_divmod.divmod(M_idx, linear_idx);
-    B_idx = head_divmod.divmod(H_idx, quotient);
-    return {M_idx, H_idx, B_idx, true};
-  }
-
-  CUTLASS_DEVICE
-  void
-  // advance_to_next_work(int advance_count = 1) {
-  advance_to_next_work() {
-    // current_work_linear_idx_ += int(gridDim.x * gridDim.y * gridDim.z);
-    current_work_linear_idx_ += int(gridDim.x);
-  }
-
-  CUTLASS_DEVICE
-  WorkTileInfo
-  fetch_next_work() {
-    WorkTileInfo new_work_tile_info;
-    advance_to_next_work();
-    new_work_tile_info = get_current_work();
-    return new_work_tile_info;
-  }
-
-};
+namespace flash {
 
 ///////////////////////////////////////////////////////////////////////////////
 
-class SingleTileScheduler {
+struct SingleTileScheduler {
 
 public:
 
     // Host side kernel arguments
     struct Arguments {
         int const num_blocks_m, num_head, num_batch;
-        int const* tile_count_semaphore = nullptr;
+        int* const tile_count_semaphore = nullptr;
     };
 
     // Device side kernel params
@@ -140,20 +54,30 @@ public:
             return {M_idx, H_idx, B_idx};
         }
 
-        CUTLASS_DEVICE
-        WorkTileInfo
-        get_next_work(Params const& params) const {
-            return {-1, -1, -1, false};
-        }
-
     };
 
+    CUTLASS_DEVICE
+    SingleTileScheduler(int* tile_count_smem_) { }
+
     CUTLASS_DEVICE
     WorkTileInfo
     get_initial_work() const {
         return {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), true};
     }
 
+    CUTLASS_DEVICE
+    void
+    init_consumer() const {}
+
+    CUTLASS_DEVICE
+    void
+    prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}
+
+    CUTLASS_DEVICE
+    void
+    broadcast_next_work(WorkTileInfo& current_work) const {}
+
+    template<bool IsProducer=false>
     CUTLASS_DEVICE
     WorkTileInfo
     get_next_work(Params const& params, WorkTileInfo const& current_work) const {
@@ -171,7 +95,7 @@ public:
     // Host side kernel arguments
     struct Arguments {
         int const num_blocks_m, num_head, num_batch;
-        int const* tile_count_semaphore = nullptr;
+        int* const tile_count_semaphore = nullptr;
     };
 
     // Device side kernel params
@@ -210,12 +134,28 @@ public:
 
     };
 
+    CUTLASS_DEVICE
+    StaticPersistentTileScheduler(int* tile_count_smem_) {};
+
     CUTLASS_DEVICE
     WorkTileInfo
     get_initial_work() const {
         return {int(blockIdx.x)};
     }
 
+    CUTLASS_DEVICE
+    void
+    init_consumer() const {}
+
+    CUTLASS_DEVICE
+    void
+    prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}
+
+    CUTLASS_DEVICE
+    void
+    broadcast_next_work(WorkTileInfo& current_work) const {}
+
+    template<bool IsProducer=false>
     CUTLASS_DEVICE
     WorkTileInfo
     get_next_work(Params const& params, WorkTileInfo const& current_work) const {
@@ -224,21 +164,25 @@ public:
 
 };
 
+template<int NumMmaThreads=2 * cutlass::NumThreadsPerWarpGroup>
 class DynamicPersistentTileScheduler {
 
+protected:
+    int* const tile_count_smem;
+
 public:
 
     // Host side kernel arguments
     struct Arguments {
         int const num_blocks_m, num_head, num_batch;
-        int const* tile_count_semaphore;
+        int* const tile_count_semaphore;
     };
 
     // Device side kernel params
     struct Params {
         int const total_blocks;
         cutlass::FastDivmod const m_block_divmod, head_divmod;
-        int const* tile_count_semaphore;
+        int* const tile_count_semaphore;
     };
 
     static Params
@@ -253,25 +197,27 @@ public:
         return {uint32_t(num_sm)};
     }
 
-    using WorkTileInfo = StaticPersistentTileScheduler::WorkTileInfo;
-    // struct WorkTileInfo {
-    //     int tile_idx;
+    struct WorkTileInfo {
+        int tile_idx;
 
-    //     CUTLASS_DEVICE
-    //     bool
-    //     is_valid(Params const& params) const {
-    //         return tile_idx < params.total_blocks;
-    //     }
+        CUTLASS_DEVICE
+        bool
+        is_valid(Params const& params) const {
+            return tile_idx < params.total_blocks;
+        }
+
+        CUTLASS_DEVICE
+        cute::tuple<int32_t, int32_t, int32_t>
+        get_block_coord(Params const& params) const {
+            int m_block, bidh, bidb;
+            bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_idx));
+            return {m_block, bidh, bidb};
+        }
 
-    //     CUTLASS_DEVICE
-    //     cute::tuple<int32_t, int32_t, int32_t>
-    //     get_block_coord(Params const& params) const {
-    //         int m_block, bidh, bidb;
-    //         bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_idx));
-    //         return {m_block, bidh, bidb};
-    //     }
+    };
 
-    // };
+    CUTLASS_DEVICE
+    DynamicPersistentTileScheduler(int* tile_count_smem_) : tile_count_smem(tile_count_smem_) {};
 
     CUTLASS_DEVICE
     WorkTileInfo
@@ -279,12 +225,45 @@ public:
         return {int(blockIdx.x)};
     }
 
+    CUTLASS_DEVICE
+    void
+    init_consumer() const {
+        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
+    }
+
+    CUTLASS_DEVICE
+    void
+    prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {
+        if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
+            current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x);
+        }
+    }
+
+    CUTLASS_DEVICE
+    void
+    broadcast_next_work(WorkTileInfo& current_work) const {
+        cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
+        if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
+            *tile_count_smem = current_work.tile_idx;
+        }
+        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
+    }
+
+    template<bool IsProducer=false>
     CUTLASS_DEVICE
     WorkTileInfo
     get_next_work(Params const& params, WorkTileInfo const& current_work) const {
-        return {current_work.tile_idx + int(gridDim.x)};
+        if constexpr (IsProducer) {
+            // thread 0 already has the right tile_idx, just need to broadcast to the rest of warp 0
+            return {__shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/)};
+        } else {
+            cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
+            int tile_idx = *tile_count_smem;
+            cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
+            return {tile_idx};
+        }
     }
 
 };
 
-} // flash
+} // flash