Browse Source

Merge pull request #1233 from Dao-AILab/ipiszy/local_attn

Add local attention in Hopper FAv3
Ying Zhang 6 months ago
parent
commit
9cafd4ae14

+ 1 - 0
hopper/flash.h

@@ -118,6 +118,7 @@ struct Flash_fwd_params : public Qkv_params {
     bool is_bf16;
     bool is_e4m3;
     bool is_causal;
+    bool is_local;
 
     // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
     // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.

+ 33 - 19
hopper/flash_api.cpp

@@ -130,13 +130,18 @@ void set_params_fprop(Flash_fwd_params &params,
 
     // Causal is the special case where window_size_right == 0 and window_size_left < 0.
     // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
-    params.is_causal = window_size_left < 0 && window_size_right == 0;
-
-    if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; }
-    if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; }
+    window_size_left = std::min(int(seqlen_k), window_size_left);
+    window_size_right = std::min(int(seqlen_k), window_size_right);
+    if (window_size_left < 0) { window_size_left = seqlen_k; }
+    if (window_size_right < 0) { window_size_right = seqlen_k; }
     params.window_size_left = window_size_left;
     params.window_size_right = window_size_right;
 
+    params.is_causal = window_size_left == seqlen_k && window_size_right == 0;
+    if ((window_size_left < seqlen_k || window_size_right < seqlen_k) && !params.is_causal) {
+        params.is_local = true;
+    }
+
     #ifdef FLASHATTENTION_DISABLE_LOCAL
         TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0),
             "This flash attention build does not support local attention.");
@@ -273,7 +278,9 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
         c10::optional<at::Tensor> &descale_q_, // 1
         c10::optional<at::Tensor> &descale_k_, // 1
         c10::optional<at::Tensor> &descale_v_, // 1
-        bool is_causal) {
+        bool is_causal,
+        int window_size_left,
+        int window_size_right) {
 
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
@@ -350,6 +357,8 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
     const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
     const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
 
+    if (is_causal) { window_size_right = 0; }
+
     // Otherwise the kernel will be launched from cuda:0 device
     // Cast to char to avoid compiler warning about narrowing
     at::cuda::CUDAGuard device_guard{(char)q.get_device()};
@@ -375,8 +384,8 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
                      softmax_lse.data_ptr(),
                      /*p_dropout=*/0.f,
                      softmax_scale,
-                     /*window_size_left=*/-1,
-                     /*window_size_right=*/is_causal ? 0 : -1);
+                     /*window_size_left=*/window_size_left,
+                     /*window_size_right=*/window_size_right);
 
     auto tile_count_semaphore = is_causal ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
     params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
@@ -437,7 +446,9 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
                int max_seqlen_q,
                const int max_seqlen_k,
                const float softmax_scale,
-               bool is_causal) {
+               bool is_causal,
+               int window_size_left,
+               int window_size_right) {
 
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
@@ -468,10 +479,6 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
     const int head_size_og = sizes[2];
     const int num_heads_k = k.size(1);
 
-    int window_size_left = -1;
-    int window_size_right = -1;
-    if (is_causal) { window_size_right = 0; }
-
     void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();
 
     const int total_q = q.sizes()[0];
@@ -480,9 +487,6 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
-    if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
-    if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
-
     CHECK_SHAPE(q, total_q, num_heads, head_size_og);
     const int total_k = k.size(0);
     CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
@@ -535,6 +539,8 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
     const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
     const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
 
+    if (is_causal) { window_size_right = 0; }
+
     // Otherwise the kernel will be launched from cuda:0 device
     // Cast to char to avoid compiler warning about narrowing
     at::cuda::CUDAGuard device_guard{(char)q.get_device()};
@@ -620,6 +626,8 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
         c10::optional<at::Tensor> &dv_,   // batch_size x seqlen_k x num_heads_k x head_size
         const float softmax_scale,
         const bool is_causal,
+        int window_size_left,
+        int window_size_right,
         const bool deterministic) {
 
     #ifdef FLASHATTENTION_DISABLE_BACKWARD
@@ -736,6 +744,8 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
         dv_expanded = dv;
     }
 
+    if (is_causal) { window_size_right = 0; }
+
     Flash_bwd_params params;
 
     set_params_dgrad(params,
@@ -759,8 +769,8 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
                      softmax_d.data_ptr(),
                      /*p_dropout=*/0.f,
                      softmax_scale,
-                     /*window_size_left=*/-1,
-                     /*window_size_right=*/is_causal ? 0 : -1,
+                     /*window_size_left=*/window_size_left,
+                     /*window_size_right=*/window_size_right,
                      deterministic);
     params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();
 
@@ -811,6 +821,8 @@ mha_varlen_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x
                const int max_seqlen_k,          // max sequence length to choose the kernel
                const float softmax_scale,
                const bool is_causal,
+               int window_size_left,
+               int window_size_right,
                const bool deterministic) {
 
     #ifdef FLASHATTENTION_DISABLE_BACKWARD
@@ -927,6 +939,8 @@ mha_varlen_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x
         dout_padded = dout;
     }
 
+    if (is_causal) { window_size_right = 0; }
+
     // Otherwise the kernel will be launched from cuda:0 device
     // Cast to char to avoid compiler warning about narrowing
     at::cuda::CUDAGuard device_guard{(char)q.get_device()};
@@ -973,8 +987,8 @@ mha_varlen_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x
                      softmax_d.data_ptr(),
                      /*p_dropout=*/0.f,
                      softmax_scale,
-                     /*window_size_left=*/-1,
-                     /*window_size_right=*/is_causal ? 0 : -1,
+                     /*window_size_left=*/window_size_left,
+                     /*window_size_right=*/window_size_right,
                      deterministic);
     params.total_q = total_q;
     params.total_k = total_k;

+ 29 - 5
hopper/flash_attn_interface.py

@@ -14,7 +14,7 @@ import flashattn_hopper_cuda
 def maybe_contiguous(x):
     return x.contiguous() if x is not None and x.stride(-1) != 1 else x
 
-def _flash_attn_forward(q, k, v, softmax_scale, causal, descale_q = None, descale_k = None, descale_v = None):
+def _flash_attn_forward(q, k, v, softmax_scale, causal, window_size, descale_q = None, descale_k = None, descale_v = None):
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask = flashattn_hopper_cuda.fwd(
         q,
@@ -26,6 +26,8 @@ def _flash_attn_forward(q, k, v, softmax_scale, causal, descale_q = None, descal
         descale_k,
         descale_v,
         causal,
+        window_size[0],
+        window_size[1],
     )
     return out, q, k, v, out_padded, softmax_lse, S_dmask
 
@@ -42,6 +44,7 @@ def _flash_attn_backward(
     dv,
     softmax_scale,
     causal,
+    window_size,
     deterministic=False
 ):
     # dq, dk, dv are allocated by us so they should already be contiguous
@@ -58,6 +61,8 @@ def _flash_attn_backward(
         dv,
         softmax_scale,
         causal,
+        window_size[0],
+        window_size[1],
         deterministic,
     )
     return dq, dk, dv, softmax_d
@@ -72,6 +77,7 @@ def _flash_attn_varlen_forward(
     max_seqlen_k,
     softmax_scale,
     causal,
+    window_size=(-1, -1),
     seqused_q=None,
     seqused_k=None,
 ):
@@ -90,6 +96,8 @@ def _flash_attn_varlen_forward(
         max_seqlen_k,
         softmax_scale,
         causal,
+        window_size[0],
+        window_size[1],
     )
     # if out.isnan().any() or softmax_lse.isnan().any():
     #     breakpoint()
@@ -112,6 +120,7 @@ def _flash_attn_varlen_backward(
     max_seqlen_k,
     softmax_scale,
     causal,
+    window_size,
     deterministic=False,
     seqused_q=None,
     seqused_k=None,
@@ -143,6 +152,8 @@ def _flash_attn_varlen_backward(
         max_seqlen_k,
         softmax_scale,
         causal,
+        window_size[0],
+        window_size[1],
         deterministic,
     )
     # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
@@ -159,6 +170,7 @@ class FlashAttnFunc(torch.autograd.Function):
         v,
         softmax_scale,
         causal,
+        window_size,
         deterministic=False,
         descale_q=None,
         descale_k=None,
@@ -172,6 +184,7 @@ class FlashAttnFunc(torch.autograd.Function):
             v,
             softmax_scale,
             causal,
+            window_size,
             descale_q=descale_q,
             descale_k=descale_k,
             descale_v=descale_v,
@@ -179,6 +192,7 @@ class FlashAttnFunc(torch.autograd.Function):
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
+        ctx.window_size = window_size
         ctx.deterministic = deterministic
         return out, softmax_lse
 
@@ -198,12 +212,13 @@ class FlashAttnFunc(torch.autograd.Function):
             dv,
             ctx.softmax_scale,
             ctx.causal,
+            ctx.window_size,
             ctx.deterministic,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None
 
 
 class FlashAttnVarlenFunc(torch.autograd.Function):
@@ -219,6 +234,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         max_seqlen_k,
         softmax_scale,
         causal,
+        window_size,
         deterministic=False,
         seqused_q=None,
         seqused_k=None,
@@ -235,6 +251,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
             max_seqlen_k,
             softmax_scale,
             causal=causal,
+            window_size=window_size,
             seqused_q=seqused_q,
             seqused_k=seqused_k,
         )
@@ -246,6 +263,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         ctx.max_seqlen_k = max_seqlen_k
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
+        ctx.window_size = window_size
         ctx.deterministic = deterministic
         return out, softmax_lse
 
@@ -269,6 +287,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
             ctx.max_seqlen_k,
             ctx.softmax_scale,
             ctx.causal,
+            ctx.window_size,
             ctx.deterministic,
             seqused_q,
             seqused_k,
@@ -276,7 +295,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None
 
 
 def flash_attn_func(
@@ -285,6 +304,7 @@ def flash_attn_func(
     v,
     softmax_scale=None,
     causal=False,
+    window_size=(-1, -1),
     deterministic=False,
     descale_q=None,
     descale_k=None,
@@ -347,6 +367,7 @@ def flash_attn_func(
         v,
         softmax_scale,
         causal,
+        window_size,
         deterministic,
         descale_q,
         descale_k,
@@ -364,6 +385,7 @@ def flash_attn_varlen_func(
     max_seqlen_k,
     softmax_scale=None,
     causal=False,
+    window_size=(-1, -1),
     deterministic=False,
     seqused_q=None,
     seqused_k=None,
@@ -397,9 +419,10 @@ def flash_attn_varlen_func(
         softmax_scale: float. The scaling of QK^T before applying softmax.
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
-        seqused_q: (batch_size,), dtype torch.int32. If not None, it defines the actual number of 
+        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        seqused_q: (batch_size,), dtype torch.int32. If not None, it defines the actual number of
             query and output tokens in each sequence.
-        seqused_k: (batch_size,), dtype torch.int32. If not None, it defines the actual number of 
+        seqused_k: (batch_size,), dtype torch.int32. If not None, it defines the actual number of
             key and value tokens in each sequence.
     Return:
         out: (total, nheads, headdim).
@@ -417,6 +440,7 @@ def flash_attn_varlen_func(
         max_seqlen_k,
         softmax_scale,
         causal,
+        window_size,
         deterministic,
         seqused_q,
         seqused_k,

+ 10 - 9
hopper/flash_bwd_kernel.h

@@ -31,6 +31,7 @@ public:
 
     // Type Aliases
     static constexpr bool Is_causal = CollectiveMainloop_::Is_causal;
+    static constexpr bool Is_local = CollectiveMainloop_::Is_local;
     static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen);
     static constexpr bool Varlen = CollectiveMainloop_::Varlen;
 
@@ -155,6 +156,7 @@ public:
         static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
         static constexpr int NumCopyThreads = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup;
         static constexpr int kBlockM = get<0>(TileShape_MNK{});
+        static constexpr int kBlockN = get<1>(TileShape_MNK{});
 
         using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
         using PipelineParams = typename MainloopPipeline::Params;
@@ -218,14 +220,14 @@ public:
                     auto block_coord = work_tile_info.get_block_coord(params.scheduler);
                     auto [n_block, bidh, bidb] = block_coord;
                     if constexpr (Varlen) {
-                        if (n_block * kBlockM >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) {
+                        if (n_block * kBlockN >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) {
                             scheduler.prefetch_next_work(params.scheduler, work_tile_info);
                             continue;
                         }
                     }
-                    if constexpr (Is_causal) {
+                    if constexpr (Is_causal || Is_local) {
                         int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb);
-                        int const m_block_max = cute::ceil_div(collective_mainloop.get_seqlen_q(params.mainloop, bidb), kBlockM);
+                        int const m_block_max = collective_mainloop.get_m_block_max(params.mainloop, n_block, bidb);
                         if (m_block_min >= m_block_max) {
                             scheduler.prefetch_next_work(params.scheduler, work_tile_info);
                             continue;
@@ -247,11 +249,11 @@ public:
                     auto block_coord = work_tile_info.get_block_coord(params.scheduler);
                     auto [n_block, bidh, bidb] = block_coord;
                     if constexpr (Varlen) {
-                        if (n_block * kBlockM >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) { continue; }
+                        if (n_block * kBlockN >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) { continue; }
                     }
                     if constexpr (Is_causal) {
                         int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb);
-                        int const m_block_max = cute::ceil_div(collective_mainloop.get_seqlen_q(params.mainloop, bidb), kBlockM);
+                        int const m_block_max = collective_mainloop.get_m_block_max(params.mainloop, n_block, bidb);
                         if (m_block_min >= m_block_max) { continue; }
                     }
                     collective_mainloop.store_dq(params.mainloop, shared_storage, block_coord);
@@ -277,11 +279,11 @@ public:
                 auto block_coord = work_tile_info.get_block_coord(params.scheduler);
                 auto [n_block, bidh, bidb] = block_coord;
                 if constexpr (Varlen) {
-                    if (n_block * kBlockM >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) { continue; }
+                    if (n_block * kBlockN >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) { continue; }
                 }
-                if constexpr (Is_causal) {
+                if constexpr (Is_causal || Is_local) {
                     int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb);
-                    int const m_block_max = cute::ceil_div(collective_mainloop.get_seqlen_q(params.mainloop, bidb), kBlockM);
+                    int const m_block_max = collective_mainloop.get_m_block_max(params.mainloop, n_block, bidb);
                     if (m_block_min >= m_block_max) {  // We exit early and write 0 to dK and dV
                         collective_epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord);
                         continue;
@@ -300,7 +302,6 @@ public:
             }
             collective_epilogue.store_tail();
         }
-
     }
 
 };

+ 20 - 12
hopper/flash_bwd_launch_template.h

@@ -20,9 +20,10 @@
 
 using namespace cute;
 
-template <int kHeadDim, int kBlockM, int kBlockN, typename Element, bool Is_causal, bool Varlen, bool Deterministic,
+template <int kHeadDim, int kBlockM, int kBlockN, typename Element, bool Is_causal, bool Is_local, bool Varlen, bool Deterministic,
           bool dKV_swapAB, bool dQ_swapAB, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1>
 void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
+    static_assert(!(Is_causal && Is_local), "Is_causal and Is_local cannot be true at the same time.");
     using TileShape_MK = cute::Shape<Int<kBlockM>, Int<kHeadDim>>;
     using ElementAccum = float;
     using PreprocessKernel = flash::FlashAttnBwdPreprocess<TileShape_MK, Element, ElementAccum, cutlass::arch::Sm90, /*Clear_dQaccum=*/true, Varlen>;
@@ -57,7 +58,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
     using ClusterShape = cute::Shape<_1, Int<1>, _1>;
     static constexpr int Stages = 2;
     using CollectiveMainloop = flash::CollectiveMainloopBwd<Stages, ClusterShape, TileShape_MNK, Element, ElementAccum, cutlass::arch::Sm90,
-            Is_causal, Varlen, Deterministic,
+            Is_causal, Is_local, Varlen, Deterministic,
             dKV_swapAB, dQ_swapAB, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ>;
     using CollectiveEpilogue = flash::CollectiveEpilogueBwd<TileShape_MNK, Element, CollectiveMainloop::NumMmaThreads, Varlen>;
     using Scheduler = flash::SingleTileSchedulerBwd;
@@ -88,7 +89,8 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
         params.b,
         params.dq_semaphore,
         params.cu_seqlens_q, params.cu_seqlens_k,
-        params.seqused_q, params.seqused_k
+        params.seqused_q, params.seqused_k,
+        params.window_size_left, params.window_size_right
     };
     typename CollectiveEpilogue::Arguments epilogue_args {
         static_cast<Element*>(params.dk_ptr),
@@ -170,9 +172,11 @@ template<typename T>
 void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 64;
     BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-        BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
-            BOOL_SWITCH(params.deterministic, Deterministic, [&] {
-                run_flash_bwd<Headdim, 128, 128, T, Is_causal, Varlen, Deterministic, false, false, 1, 2, 2>(params, stream);
+        BOOL_SWITCH(params.is_local, Is_local, [&] {
+            BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
+                BOOL_SWITCH(params.deterministic, Deterministic, [&] {
+                    run_flash_bwd<Headdim, 128, 128, T, Is_causal, Is_local && !Is_causal, Varlen, Deterministic, false, false, 1, 2, 2>(params, stream);
+                });
             });
         });
     });
@@ -182,9 +186,11 @@ template<typename T>
 void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 96;
     BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-        BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
-            BOOL_SWITCH(params.deterministic, Deterministic, [&] {
-                run_flash_bwd<Headdim, 64, 128, T, Is_causal, Varlen, Deterministic, false, false, 1, 2, 1>(params, stream);
+        BOOL_SWITCH(params.is_local, Is_local, [&] {
+            BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
+                BOOL_SWITCH(params.deterministic, Deterministic, [&] {
+                    run_flash_bwd<Headdim, 64, 128, T, Is_causal, Is_local && !Is_causal, Varlen, Deterministic, false, false, 1, 2, 1>(params, stream);
+                });
             });
         });
     });
@@ -194,9 +200,11 @@ template<typename T>
 void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 128;
     BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-        BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
-            BOOL_SWITCH(params.deterministic, Deterministic, [&] {
-                run_flash_bwd<Headdim, 64, 128, T, Is_causal, Varlen, Deterministic, false, false, 1, 2, 1>(params, stream);
+        BOOL_SWITCH(params.is_local, Is_local, [&] {
+            BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
+                BOOL_SWITCH(params.deterministic, Deterministic, [&] {
+                    run_flash_bwd<Headdim, 64, 128, T, Is_causal, Is_local && !Is_causal, Varlen, Deterministic, false, false, 1, 2, 1>(params, stream);
+                });
             });
         });
     });

+ 14 - 10
hopper/flash_fwd_kernel.h

@@ -24,9 +24,9 @@ namespace flash {
 
 using namespace cute;
 
-template <typename Ktraits, bool Is_causal, typename TileScheduler, typename Seqlen_traits>
+template <typename Ktraits, bool Is_causal, bool Is_local, typename TileScheduler, typename Seqlen_traits>
 __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
-    compute_attn_ws(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits, Is_causal, Seqlen_traits>::Params const mainloop_params,
+    compute_attn_ws(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits, Is_causal, Is_local, Seqlen_traits>::Params const mainloop_params,
                     CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd<Ktraits, Seqlen_traits>::Params const epilogue_params,
                     CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params,
                     Seqlen_traits seqlen_traits_q, Seqlen_traits seqlen_traits_k
@@ -47,7 +47,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
     // static constexpr int kBlockN = Ktraits::kBlockN;
     // constexpr int kHeadDim = Ktraits::kHeadDim;
 
-    using CollectiveMainloop = CollectiveMainloopFwd<Ktraits, Is_causal, Seqlen_traits>;
+    using CollectiveMainloop = CollectiveMainloopFwd<Ktraits, Is_causal, Is_local, Seqlen_traits>;
     using CollectiveEpilogue = CollectiveEpilogueFwd<Ktraits, Seqlen_traits>;
 
     using MainloopPipeline = typename Ktraits::MainloopPipeline;
@@ -121,9 +121,11 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
                 if (m_block * kBlockM >= seqlen_traits_q.actual_seq_len) {
                     continue;
                 }
-                int n_block_max = collective_mainloop.get_n_block_max(
+                const int n_block_max = collective_mainloop.get_n_block_max(
+                    mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
+                const int n_block_min = collective_mainloop.get_n_block_min(
                     mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
-                if ((Is_causal || seqlen_traits_k.kUseVarSeqLen) && n_block_max <= 0) {
+                if ((Is_causal || Is_local || seqlen_traits_k.kUseVarSeqLen) && n_block_max <= n_block_min) {
                     scheduler.prefetch_next_work(scheduler_params, work_tile_info);
                     scheduler.broadcast_next_work(work_tile_info);
                     continue;
@@ -167,15 +169,17 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
             if (m_block * kBlockM >= seqlen_traits_q.actual_seq_len) {
                 continue;
             }
-            int n_block_max = collective_mainloop.get_n_block_max(
+            const int n_block_max = collective_mainloop.get_n_block_max(
+                mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
+            const int n_block_min = collective_mainloop.get_n_block_min(
                 mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
-            if ((Is_causal || seqlen_traits_k.kUseVarSeqLen) && n_block_max <= 0) {  // We exit early and write 0 to gO and -inf to gLSE.
+            if ((Is_causal || Is_local || seqlen_traits_k.kUseVarSeqLen) && n_block_max <= n_block_min) {  // We exit early and write 0 to gO and -inf to gLSE.
                 collective_epilogue.store_zero(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q);
                 continue;
             }
 
             collective_mainloop.mma(mainloop_params, pipeline_k, pipeline_v, smem_pipe_read_k, smem_pipe_read_v,
-                                    tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads, work_idx, m_block, shared_storage,
+                                    tOrO, softmax, n_block_max, n_block_min, threadIdx.x - NumCopyThreads, work_idx, m_block, shared_storage,
                                     seqlen_traits_q, seqlen_traits_k);
                                     // tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads + (work_idx >> 30), work_idx, shared_storage);
             collective_epilogue.store(epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1,
@@ -190,7 +194,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
 
 template <typename Ktraits, bool Is_causal, typename TileScheduler, typename Seqlen_traits>
 __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
-    compute_attn_ws_fp8(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits, Is_causal, Seqlen_traits>::Params const mainloop_params,
+    compute_attn_ws_fp8(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits, Is_causal, /*Is_local=*/false, Seqlen_traits>::Params const mainloop_params,
                         CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd<Ktraits, Seqlen_traits>::Params const epilogue_params,
                         CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params,
                         Seqlen_traits seqlen_traits_q, Seqlen_traits seqlen_traits_k
@@ -215,7 +219,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
     static constexpr bool Delay_V_release = Is_causal && Ktraits::kHeadDim == 128;    
     static constexpr bool Use_max_offset = true;
 
-    using CollectiveMainloop = CollectiveMainloopFwd<Ktraits, Is_causal, Seqlen_traits>;
+    using CollectiveMainloop = CollectiveMainloopFwd<Ktraits, Is_causal, /*Is_local=*/false, Seqlen_traits>;
     using CollectiveEpilogue = CollectiveEpilogueFwd<Ktraits, Seqlen_traits>;
 
     using MainloopPipeline = typename Ktraits::MainloopPipeline;

+ 39 - 30
hopper/flash_fwd_launch_template.h

@@ -18,18 +18,19 @@
 #include "utils.h"
 
 
-template<typename Kernel_traits, bool Is_causal, typename Seqlen_traits>
+template<typename Kernel_traits, bool Is_causal, bool Is_local, typename Seqlen_traits>
 void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+    static_assert(!(Is_causal && Is_local), "Is_causal and Is_local cannot be true at the same time.");
     using Element = typename Kernel_traits::Element;
     using OutputType = typename Kernel_traits::OutputType;
     using TileShape_MNK = typename Kernel_traits::TileShape_MNK;
     using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
 
     // print(typename Kernel_traits::SmemLayoutVt{}); printf("\n"); print(typename Kernel_traits::SmemLayoutVt_tmp{});
-    using CollectiveMainloop = flash::CollectiveMainloopFwd<Kernel_traits, Is_causal, Seqlen_traits>;
+    using CollectiveMainloop = flash::CollectiveMainloopFwd<Kernel_traits, Is_causal, Is_local, Seqlen_traits>;
     using CollectiveEpilogue = flash::CollectiveEpilogueFwd<Kernel_traits, Seqlen_traits>;
     using Scheduler = std::conditional_t<
-        Seqlen_traits::kUseVarSeqLen, 
+        Seqlen_traits::kUseVarSeqLen || Is_local, 
         flash::SingleTileScheduler,
         std::conditional_t<!Is_causal,
             flash::StaticPersistentTileScheduler,
@@ -60,7 +61,9 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
             params.scale_softmax_log2,
             params.descale_q_ptr,
             params.descale_k_ptr,
-            params.descale_v_ptr
+            params.descale_v_ptr,
+            params.window_size_left,
+            params.window_size_right
         });
     typename CollectiveEpilogue::Params epilogue_params =
         CollectiveEpilogue::to_underlying_arguments({
@@ -85,7 +88,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     if constexpr(cutlass::sizeof_bits_v<Element> == 8)
         kernel = (void *)flash::compute_attn_ws_fp8<Kernel_traits, Is_causal, Scheduler, Seqlen_traits>;
     else
-        kernel = (void *)flash::compute_attn_ws<Kernel_traits, Is_causal, Scheduler, Seqlen_traits>;
+        kernel = (void *)flash::compute_attn_ws<Kernel_traits, Is_causal, Is_local, Scheduler, Seqlen_traits>;
     int smem_size = sizeof(typename Kernel_traits::SharedStorage);
     // int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q));
     // int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k));
@@ -115,11 +118,13 @@ template<typename T>
 void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 64;
     BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-        SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
-            run_flash_fwd<
-                Flash_fwd_kernel_traits<Headdim, 192, 128, 16, 2, false, 1, T>, 
-                Is_causal, Seqlen_traits
-            >(params, stream);
+        BOOL_SWITCH(params.is_local, Is_local, [&] {
+            SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
+                run_flash_fwd<
+                    Flash_fwd_kernel_traits<Headdim, 192, 128, 16, 2, false, 1, T>, 
+                    Is_causal, Is_local && !Is_causal, Seqlen_traits
+                >(params, stream);
+            });
         });
     });
 }
@@ -128,13 +133,15 @@ template<typename T>
 void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 128;
     BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-        SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
-            // Only use Cluster if number of tiles along seqlen_q is even and not Is_causal
-            BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
-                run_flash_fwd<
-                    Flash_fwd_kernel_traits<Headdim, 128, Is_causal ? 128 : 176, 12, 2, false, UseCluster ? 2 : 1, T>, 
-                    Is_causal, Seqlen_traits
-                >(params, stream);
+        BOOL_SWITCH(params.is_local, Is_local, [&] {
+            SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
+                // Only use Cluster if number of tiles along seqlen_q is even and not Is_causal
+                BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
+                    run_flash_fwd<
+                        Flash_fwd_kernel_traits<Headdim, 128, (Is_causal || Is_local) ? 128 : 176, 12, 2, false, UseCluster ? 2 : 1, T>, 
+                        Is_causal, Is_local && !Is_causal, Seqlen_traits
+                    >(params, stream);
+                });
             });
         });
     });
@@ -144,13 +151,15 @@ template<typename T>
 void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 256;
     BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-        SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
-            // Only use Cluster if number of tiles along seqlen_q is even
-            BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
-                run_flash_fwd<
-                    Flash_fwd_kernel_traits<Headdim, 128, 80, 12, 2, false, UseCluster ? 2 : 1, T>, 
-                    Is_causal, Seqlen_traits
-                >(params, stream);
+        BOOL_SWITCH(params.is_local, Is_local, [&] {
+            SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
+                // Only use Cluster if number of tiles along seqlen_q is even
+                BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
+                    run_flash_fwd<
+                        Flash_fwd_kernel_traits<Headdim, 128, 80, 12, 2, false, UseCluster ? 2 : 1, T>, 
+                        Is_causal, Is_local && !Is_causal, Seqlen_traits
+                    >(params, stream);
+                });
             });
         });
     });
@@ -166,11 +175,11 @@ void run_mha_fwd_hdim64_fp8(Flash_fwd_params &params, cudaStream_t stream) {
     using Seqlen_traits = flash::FixedSeqLenTraits;
     if(params.is_causal) {
         run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
-                        false, 1, T>, /*Is_causal=*/true, Seqlen_traits>(params, stream);
+                        false, 1, T>, /*Is_causal=*/true, /*Is_local=*/false, Seqlen_traits>(params, stream);
     } else {
         BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0, UseCluster, [&] {
             run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
-                            false, UseCluster ? 2 : 1, T>, /*Is_causal=*/false, Seqlen_traits>(params, stream);
+                            false, UseCluster ? 2 : 1, T>, /*Is_causal=*/false, /*Is_local=*/false, Seqlen_traits>(params, stream);
         });
     }
     // BOOL_SWITCH(params.is_causal, Is_causal, [&] {
@@ -195,11 +204,11 @@ void run_mha_fwd_hdim128_fp8(Flash_fwd_params &params, cudaStream_t stream) {
     using Seqlen_traits = flash::FixedSeqLenTraits;
     if(params.is_causal) {
         run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
-                        false, 1, T>, /*Is_causal=*/true, Seqlen_traits>(params, stream);
+                        false, 1, T>, /*Is_causal=*/true, /*Is_local=*/false, Seqlen_traits>(params, stream);
     } else {
         BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0, UseCluster, [&] {
             run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
-                            false, UseCluster ? 2 : 1, T>, /*Is_causal=*/false, Seqlen_traits>(params, stream);
+                            false, UseCluster ? 2 : 1, T>, /*Is_causal=*/false, /*Is_local=*/false, Seqlen_traits>(params, stream);
         });
     }
     // BOOL_SWITCH(params.is_causal, Is_causal, [&] {
@@ -224,11 +233,11 @@ void run_mha_fwd_hdim256_fp8(Flash_fwd_params &params, cudaStream_t stream) {
     using Seqlen_traits = flash::FixedSeqLenTraits;
     if(params.is_causal) {
         run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
-                        false, 1, T>, /*Is_causal=*/true, Seqlen_traits>(params, stream);
+                        false, 1, T>, /*Is_causal=*/true, /*Is_local=*/false, Seqlen_traits>(params, stream);
     } else {
         BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0, UseCluster, [&] {
             run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
-                            false, UseCluster ? 2 : 1, T>, /*Is_causal=*/false, Seqlen_traits>(params, stream);
+                            false, UseCluster ? 2 : 1, T>, /*Is_causal=*/false, /*Is_local=*/false, Seqlen_traits>(params, stream);
         });
     }
     // BOOL_SWITCH(params.is_causal, Is_causal, [&] {

+ 55 - 12
hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp

@@ -24,7 +24,7 @@ namespace flash {
 using namespace cute;
 
 template <int Stages, class ClusterShape_, class TileShape_MNK_, class Element_, class ElementAccum_, class ArchTag_,
-        bool Is_causal_, bool Varlen_, bool Deterministic,
+        bool Is_causal_, bool Is_local_, bool Varlen_, bool Deterministic,
         bool dKV_swapAB_, bool dQ_swapAB_,
         int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1>
 struct CollectiveMainloopBwd {
@@ -36,6 +36,7 @@ struct CollectiveMainloopBwd {
     using ElementAccum = ElementAccum_;
     using ArchTag = ArchTag_;
     static constexpr bool Is_causal = Is_causal_;
+    static constexpr bool Is_local = Is_local_;
     static constexpr bool Varlen = Varlen_;
     static constexpr bool SdP_swapAB = true;
     static constexpr bool dKV_swapAB = dKV_swapAB_;
@@ -281,6 +282,8 @@ struct CollectiveMainloopBwd {
         int const* cu_seqlens_k = nullptr;
         int const* seqused_k = nullptr;
         int const* seqused_v = nullptr;
+        int window_size_left;
+        int window_size_right;
     };
 
     // Device side kernel params
@@ -307,6 +310,8 @@ struct CollectiveMainloopBwd {
         int const* cu_seqlens_k = nullptr;
         int const* seqused_q = nullptr;
         int const* seqused_k = nullptr;
+        int window_size_left;
+        int window_size_right;
     };
 
     static Params
@@ -367,7 +372,7 @@ struct CollectiveMainloopBwd {
                 args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum,
                 args.softmax_scale, float(args.softmax_scale * M_LOG2E),
                 args.num_batch, args.dq_semaphore, args.cu_seqlens_q, args.cu_seqlens_k,
-                args.seqused_k, args.seqused_v};
+                args.seqused_k, args.seqused_v, args.window_size_left, args.window_size_right};
     }
 
     /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
@@ -412,15 +417,31 @@ struct CollectiveMainloopBwd {
 
     CUTLASS_DEVICE
     int get_m_block_min(Params const& params, int n_block, int bidb) {
-        if constexpr (Is_causal) {
+        static constexpr int kBlockM = get<0>(TileShape_MNK{});
+        static constexpr int kBlockN = get<1>(TileShape_MNK{});        
+        if constexpr (Is_causal || Is_local) {
             int const seqlen_q = get_seqlen_q(params, bidb);
             int const seqlen_k = get_seqlen_k(params, bidb);
-            return std::max(0, (n_block * kBlockN + seqlen_q - seqlen_k) / kBlockM);
+            return std::max(0, (n_block * kBlockN + seqlen_q - seqlen_k - params.window_size_right) / kBlockM);
         } else {
             return 0;
         }
     }
 
+    CUTLASS_DEVICE
+    int get_m_block_max(Params const& params, int n_block, int bidb) {
+        static constexpr int kBlockM = get<0>(TileShape_MNK{});
+        static constexpr int kBlockN = get<1>(TileShape_MNK{});        
+        int const seqlen_q = get_seqlen_q(params, bidb);
+        int const seqlen_k = get_seqlen_k(params, bidb);
+        int m_block_max = cute::ceil_div(seqlen_q, kBlockM);
+        if constexpr (Is_local) {
+            return std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + params.window_size_left, kBlockM));
+        } else {
+            return m_block_max;
+        }
+    }
+
     template <typename SchedulerPrefetch, typename SharedStorage>
     CUTLASS_DEVICE void
     load(Params const& params,
@@ -491,7 +512,7 @@ struct CollectiveMainloopBwd {
             }
         }
 
-        int m_block_max = cute::ceil_div(get_seqlen_q(params, bidb), get<0>(TileShape_MNK{}));
+        int m_block_max = get_m_block_max(params, n_block, bidb);
         int m_block_min = get_m_block_min(params, n_block, bidb);
         int m_block = m_block_min;
 
@@ -568,7 +589,7 @@ struct CollectiveMainloopBwd {
         Tensor tdQgdQ = block_tma_dQ.partition_D(gdQaccum);  // (TMA, TMA_M, TMA_K)
         Tensor tdQsdQ = block_tma_dQ.partition_S(sdQ); // (TMA, TMA_M, TMA_K)
 
-        int m_block_max = cute::ceil_div(get_seqlen_q(params, bidb), get<0>(TileShape_MNK{}));
+        int m_block_max = get_m_block_max(params, n_block, bidb);
         int m_block_min = get_m_block_min(params, n_block, bidb);
         int m_block = m_block_min;
         int const num_batch = params.num_batch;
@@ -592,6 +613,15 @@ struct CollectiveMainloopBwd {
             }
             cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::dQEmpty) /*id*/);  // sdQ empty, ready to be written to
         }
+        if constexpr (Is_local && Deterministic) {
+            constexpr int kBlockM = get<0>(TileShape_MNK{});        
+            int const seqlen_q = get_seqlen_q(params, bidb);
+            int const m_block_global_max = cute::ceil_div(seqlen_q, kBlockM);
+            #pragma unroll 2
+            for (; m_block < m_block_global_max; ++m_block) {
+                Barrier::arrive_inc(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head);
+            }
+        }
     }
 
     CUTLASS_DEVICE void
@@ -678,7 +708,7 @@ struct CollectiveMainloopBwd {
         int const seqlen_q = get_seqlen_q(params, bidb);
         int const seqlen_k = get_seqlen_k(params, bidb);
 
-        int m_block_max = cute::ceil_div(get_seqlen_q(params, bidb), get<0>(TileShape_MNK{}));
+        int m_block_max = get_m_block_max(params, n_block, bidb);
         int m_block_min = get_m_block_min(params, n_block, bidb);
         int m_block = m_block_min;
 
@@ -743,8 +773,8 @@ struct CollectiveMainloopBwd {
                 int causal_row_offset = 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM;
                 #pragma unroll
                 for (int i = 0; i < size(tSrS); ++i) {
-                    if (int(get<0>(taccScS(i))) >= std::min(int(get<1>(taccScS(i))) + causal_row_offset,
-                                                            seqlen_k - n_block * kBlockN)) {
+                    if (int(get<0>(taccScS(i))) >= 
+                        std::min(int(get<1>(taccScS(i))) + causal_row_offset, seqlen_k - n_block * kBlockN)) {
                         tSrS(i) = -INFINITY;
                     }
                 }
@@ -800,10 +830,23 @@ struct CollectiveMainloopBwd {
             warpgroup_wait<1>();
             Tensor cS = cute::make_identity_tensor(select<1, 0>(TileShape_MNK{}));
             Tensor taccScS = thread_mma_SdP.partition_C(cS);
-            #pragma unroll
-            for (int i = 0; i < size(tSrS); ++i) {
-                if (int(get<0>(taccScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; }
+            if constexpr (!Is_local) {
+                #pragma unroll
+                for (int i = 0; i < size(tSrS); ++i) {
+                    if (int(get<0>(taccScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; }
+                }
+            } else {
+                int local_row_offset_right = 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM + params.window_size_right;
+                int local_row_offset_left = seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM - params.window_size_left;
+                #pragma unroll
+                for (int i = 0; i < size(tSrS); ++i) {
+                    if ((int(get<0>(taccScS(i))) >= std::min(int(get<1>(taccScS(i))) + local_row_offset_right, seqlen_k - n_block * kBlockN)) || 
+                        (int(get<0>(taccScS(i))) < std::max(int(get<1>(taccScS(i))) + local_row_offset_left, 0))) {
+                        tSrS(i) = -INFINITY;
+                    }
+                }
             }
+ 
             // Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
             Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_transposed_rowcol(tSrS.layout()));
             // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(tLSErLSE); }

+ 75 - 20
hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp

@@ -79,7 +79,7 @@ struct SmemTransposeFp8_64x64 {
   }
 };
 
-template <typename Ktraits, bool Is_causal, typename Seqlen_traits>
+template <typename Ktraits, bool Is_causal, bool Is_local, typename Seqlen_traits>
 struct CollectiveMainloopFwd {
 
     using Element = typename Ktraits::Element;
@@ -158,6 +158,8 @@ struct CollectiveMainloopFwd {
         float const* descale_q_ptr;
         float const* descale_k_ptr;
         float const* descale_v_ptr;
+        int window_size_left;
+        int window_size_right;
     };
 
     // Device side kernel params
@@ -173,6 +175,8 @@ struct CollectiveMainloopFwd {
         float const* descale_q_ptr;
         float const* descale_k_ptr;
         float const* descale_v_ptr;
+        int window_size_left;
+        int window_size_right;
     };
 
 
@@ -203,7 +207,8 @@ struct CollectiveMainloopFwd {
                 cutlass::FastDivmod(cute::ceil_div(get<2>(args.layout_Q.shape()), get<2>(args.layout_K.shape()))),
                 tma_load_Q, tma_load_K, tma_load_V,
                 args.softmax_scale_log2,
-                args.descale_q_ptr, args.descale_k_ptr, args.descale_v_ptr};
+                args.descale_q_ptr, args.descale_k_ptr, args.descale_v_ptr,
+                args.window_size_left, args.window_size_right};
     }
 
     /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
@@ -225,13 +230,34 @@ struct CollectiveMainloopFwd {
         int const seqlen_q = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_q.actual_seq_len : shape<0>(mainloop_params.layout_Q);
         int const seqlen_k = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_k.actual_seq_len : shape<0>(mainloop_params.layout_K);        
         int n_block_max = cute::ceil_div(seqlen_k, kBlockN);
-        if constexpr (Is_causal) {
-            n_block_max = std::min(n_block_max,
-                                   cute::ceil_div((m_block + 1) * kBlockM + seqlen_k - seqlen_q, kBlockN));
+        if constexpr (Is_causal || Is_local) {
+            n_block_max = std::min(
+                n_block_max,
+                cute::ceil_div((m_block + 1) * kBlockM + seqlen_k - seqlen_q + mainloop_params.window_size_right, kBlockN));
         }
         return n_block_max;
     }
 
+    CUTLASS_DEVICE
+    int get_n_block_min(
+          Params const& mainloop_params, int m_block, 
+          const Seqlen_traits& seqlen_traits_q,
+          const Seqlen_traits& seqlen_traits_k
+        ) {
+        static constexpr int kBlockM = get<0>(TileShape_MNK{});
+        static constexpr int kBlockN = get<1>(TileShape_MNK{});        
+        int const seqlen_q = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_q.actual_seq_len : shape<0>(mainloop_params.layout_Q);
+        int const seqlen_k = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_k.actual_seq_len : shape<0>(mainloop_params.layout_K);        
+        if constexpr (!Is_local) {
+            return 0;
+        } else {
+            return std::max(
+                0, 
+                (m_block * kBlockM + seqlen_k - seqlen_q - mainloop_params.window_size_left) / kBlockN
+            );
+        }
+    }
+
     template <typename Scheduler, typename SharedStorage>
     CUTLASS_DEVICE void
     load(Params const& mainloop_params,
@@ -288,7 +314,8 @@ struct CollectiveMainloopFwd {
             }
         }
 
-        int n_block_max = get_n_block_max(mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
+        const int n_block_min = get_n_block_min(mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
+        const int n_block_max = get_n_block_max(mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
         int n_block = n_block_max - 1;
 
         int lane_predicate = cute::elect_one_sync();
@@ -315,7 +342,7 @@ struct CollectiveMainloopFwd {
         if (lane_predicate) {
             // CUTLASS_PRAGMA_NO_UNROLL
             #pragma unroll 2
-            for (; n_block > 0; --n_block) {
+            for (; n_block > n_block_min; --n_block) {
                 pipeline_k.producer_acquire(smem_pipe_write_k);
                 copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv),
                     tKgK(_, n_block - 1), tKsK(_, smem_pipe_write_k.index()));
@@ -645,6 +672,7 @@ struct CollectiveMainloopFwd {
         FrgTensorO& tOrO,
         Softmax& softmax,
         int n_block_count,
+        int n_block_min,
         int thread_idx,
         int work_idx,
         int m_block,
@@ -706,37 +734,50 @@ struct CollectiveMainloopFwd {
         pipeline_k.consumer_release(smem_pipe_read_k);
         ++smem_pipe_read_k;
 
-        auto col_limit_causal = [&](int row, int n_block) {
-            return row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM;
+        auto col_limit_right = [&](int row, int n_block) {
+            return std::min(
+                seqlen_k - n_block * kBlockN, 
+                row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM + mainloop_params.window_size_right
+            );
+        };
+        auto col_limit_left = [&](int row, int n_block) {
+            return std::max(
+                0,
+                row + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM - mainloop_params.window_size_left
+            );
         };
         {
             Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
             Tensor tScS = threadMma0.partition_C(cS);
             #pragma unroll
             for (int i = 0; i < size(tSrS); ++i) {
-                if constexpr (!Is_causal) {  // Just masking based on col
+                if constexpr (!Is_causal && !Is_local) {  // Just masking based on col
                     if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; }
                 } else {  // mask based on both row and col
                     // using std::min is faster than doing col >= limit0 or col >= limit1
                     // Need to cast get<1>(tScS(i)) to (signed) int since by default it's unsigned, and the
                     // right hand side can be negative and might be converted to a very large unsigned integer.
-                    if (int(get<1>(tScS(i))) >= std::min(seqlen_k - n_block * kBlockN,
-                                                        col_limit_causal(int(get<0>(tScS(i))), n_block))) {
+                    if (int(get<1>(tScS(i))) >= col_limit_right(int(get<0>(tScS(i))), n_block)) {
                         tSrS(i) = -INFINITY;
+                    } else if constexpr (Is_local) {
+                        if (int(get<1>(tScS(i))) < col_limit_left(int(get<0>(tScS(i))), n_block)) {
+                            tSrS(i) = -INFINITY;
+                        }
                     }
                 }
             }
         }
 
         softmax.template online_softmax</*Is_first=*/true>(tSrS);
+ 
         Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout()));
         Tensor scores_scale = make_fragment_like(softmax.row_max);
         clear(scores_scale);
 
-        constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1;
+        constexpr int n_masking_steps = (!Is_causal) ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1;
         // Only go through these if Is_causal, since n_masking_steps = 1 when !Is_causal
         #pragma unroll
-        for (int masking_step = 0; masking_step < n_masking_steps - 1 && n_block > 0; ++masking_step, --n_block) {
+        for (int masking_step = 0; masking_step < n_masking_steps - 1 && n_block > n_block_min; ++masking_step, --n_block) {
             Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
             consumer_wait(pipeline_k, smem_pipe_read_k);
             warp_scheduler_barrier_sync();
@@ -751,7 +792,7 @@ struct CollectiveMainloopFwd {
             Tensor tScS = threadMma0.partition_C(cS);
             #pragma unroll
             for (int i = 0; i < size(tSrS); ++i) {
-                if (int(get<1>(tScS(i))) >= col_limit_causal(int(get<0>(tScS(i))), n_block - 1)) {
+                if (int(get<1>(tScS(i))) >= col_limit_right(int(get<0>(tScS(i))), n_block - 1)) {
                     tSrS(i) = -INFINITY;
                 }
             }
@@ -765,7 +806,7 @@ struct CollectiveMainloopFwd {
         }
 
         #pragma unroll 1
-        for (; n_block > 0; --n_block) {
+        for (; n_block > n_block_min; --n_block) {
             Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
             consumer_wait(pipeline_k, smem_pipe_read_k);
             warp_scheduler_barrier_sync();
@@ -776,9 +817,24 @@ struct CollectiveMainloopFwd {
             warp_scheduler_barrier_arrive();
             warpgroup_wait<1>();
             pipeline_k.consumer_release(smem_pipe_read_k);  // release K
+
+            if constexpr(Is_local) {
+                Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
+                Tensor tScS = threadMma0.partition_C(cS);
+                #pragma unroll
+                for (int i = 0; i < size(tSrS); ++i) {
+                    if (
+                        int(get<1>(tScS(i))) >= col_limit_right(int(get<0>(tScS(i))), n_block - 1) ||
+                        int(get<1>(tScS(i))) < col_limit_left(int(get<0>(tScS(i))), n_block - 1)
+                    ) {
+                        tSrS(i) = -INFINITY;
+                    }
+                }
+            }
             // auto scores_scale = softmax.template max</*Is_first=*/false>(tSrS);
-            cute::copy(softmax.template max</*Is_first=*/false>(tSrS), scores_scale);
-            softmax.template online_softmax</*Is_first=*/false>(tSrS);
+            cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/Is_local>(tSrS), scores_scale);
+            softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/Is_local>(tSrS);
+
             warpgroup_wait<0>();
             pipeline_v.consumer_release(smem_pipe_read_v);  // release V
             ++smem_pipe_read_k;
@@ -791,11 +847,10 @@ struct CollectiveMainloopFwd {
         softmax.rescale_o(tOrO, scores_scale);
         consumer_wait(pipeline_v, smem_pipe_read_v);
         flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
-        cute::copy(softmax.template finalize</*Check_inf=*/Is_causal>(tSrS), scores_scale);
+        cute::copy(softmax.template finalize</*Check_inf=*/Is_causal || Is_local>(tSrS), scores_scale);
         warpgroup_wait<0>();
         pipeline_v.consumer_release(smem_pipe_read_v);  // release V, otherwise producers will hang
         ++smem_pipe_read_v;
-
         softmax.rescale_o(tOrO, scores_scale);
         return;
     }

+ 1 - 1
hopper/setup.py

@@ -144,7 +144,7 @@ if not SKIP_CUDA_BUILD:
         "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage",  # printing out number of registers
         "-lineinfo",
         "-DCUTLASS_DEBUG_TRACE_LEVEL=0",  # Can toggle for debugging
-        "-DNDEBUG",  # Important, otherwise performance is severely impacted             
+        "-DNDEBUG",  # Important, otherwise performance is severely impacted
     ]
     include_dirs = [
         # Path(this_dir) / "fmha-pipeline",

+ 26 - 12
hopper/test_flash_attn.py

@@ -24,11 +24,14 @@ def print_diffs(out, out_ref):
 
 
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+# @pytest.mark.parametrize("dtype", [torch.float16])
 # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
 @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
 # @pytest.mark.parametrize("mha_type", ["mha"])
 @pytest.mark.parametrize("causal", [False, True])
 # @pytest.mark.parametrize("causal", [True])
+@pytest.mark.parametrize("local", [False, True])
+# @pytest.mark.parametrize("local", [True])
 @pytest.mark.parametrize("deterministic", [False, True])
 # @pytest.mark.parametrize("deterministic", [True])
 # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
@@ -37,7 +40,7 @@ def print_diffs(out, out_ref):
 # @pytest.mark.parametrize('d', [56, 80])
 # @pytest.mark.parametrize("d", [64, 128, 256])
 # @pytest.mark.parametrize("d", [64, 96, 128])
-# @pytest.mark.parametrize("d", [64, 128])
+# @pytest.mark.parametrize("d", [256])
 @pytest.mark.parametrize("d", [64, 128, 256])
 @pytest.mark.parametrize("descale", [1.0])
 # @pytest.mark.parametrize("descale", [1.0, 2.0, 3.0, 4.0])
@@ -65,13 +68,13 @@ def print_diffs(out, out_ref):
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
 def test_flash_attn_output(
-    seqlen_q, seqlen_k, d, causal, deterministic, mha_type, dtype, descale
+    seqlen_q, seqlen_k, d, causal, local, deterministic, mha_type, dtype, descale
 ):
     device = "cuda"
     if(dtype == torch.float8_e4m3fn):
         dtype_init = torch.float16
     else:
-        dtype_init = dtype    
+        dtype_init = dtype
     print(dtype)
     # set seed
     torch.random.manual_seed(0)
@@ -83,6 +86,7 @@ def test_flash_attn_output(
     # nheads_kv = 2
     # batch_size = 9
     # nheads = 6
+    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_init, requires_grad=True)
     k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True)
     v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True)
@@ -96,7 +100,7 @@ def test_flash_attn_output(
     descale_k = torch.tensor([descale], dtype=torch.float32, device='cuda')
     descale_v = torch.tensor([descale], dtype=torch.float32, device='cuda')
     if(dtype != torch.float8_e4m3fn):
-        out, lse = flash_attn_func(q, k, v, causal=causal, deterministic=deterministic)
+        out, lse = flash_attn_func(q, k, v, causal=causal, window_size=window_size, deterministic=deterministic)
     else:
         out, q, k, v, out_padded, lse, S_dmask = _flash_attn_forward(
             q, k, v, softmax_scale, causal, descale_q=descale_q, descale_k=descale_k, descale_v=descale_v
@@ -113,7 +117,7 @@ def test_flash_attn_output(
         q = q * descale_q
         k = k * descale_k
         v = v * descale_v
-        
+
     out_ref, attn_ref = attention_ref(
         q,
         k,
@@ -121,6 +125,7 @@ def test_flash_attn_output(
         None,
         None,
         causal=causal,
+        window_size=window_size,
     )
     out_pt, attn_pt = attention_ref(
         q,
@@ -129,6 +134,7 @@ def test_flash_attn_output(
         None,
         None,
         causal=causal,
+        window_size=window_size,
         upcast=False,
         reorder_ops=True,
     )
@@ -144,9 +150,9 @@ def test_flash_attn_output(
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
     print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
     print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
-    
+
     # if not causal:
-    #     print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")                
+    #     print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
     # breakpoint()
 
     if d <= 128 and dtype != torch.float8_e4m3fn:
@@ -181,7 +187,7 @@ def test_flash_attn_output(
     # breakpoint()
     if(dtype != torch.float8_e4m3fn):
         assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 3e-5
-    else:       
+    else:
         # just test correctness of fp8 kernel w/o further quantization techniques
         assert (out - out_ref).abs().max().item() <= 40 * (out_pt - out_ref).abs().max().item()
 
@@ -196,14 +202,16 @@ def test_flash_attn_output(
 @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
 # @pytest.mark.parametrize("mha_type", ["mha"])
 @pytest.mark.parametrize("causal", [False, True])
-# @pytest.mark.parametrize("causal", [False])
+# @pytest.mark.parametrize("causal", [True])
+@pytest.mark.parametrize("local", [False, True])
+# @pytest.mark.parametrize("local", [False])
 @pytest.mark.parametrize("deterministic", [False, True])
 # @pytest.mark.parametrize("deterministic", [False])
 @pytest.mark.parametrize("add_unused_qkv", [False, True])
 # @pytest.mark.parametrize("add_unused_qkv", [True])
 # @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
-# @pytest.mark.parametrize('d', [128])
+# @pytest.mark.parametrize('d', [256])
 # @pytest.mark.parametrize("d", [64, 128, 256])
 @pytest.mark.parametrize("d", [64, 128])
 # @pytest.mark.parametrize("d", [128])
@@ -233,7 +241,7 @@ def test_flash_attn_output(
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
 def test_flash_attn_varlen_output(
-    seqlen_q, seqlen_k, d, causal, deterministic, add_unused_qkv, mha_type, dtype
+    seqlen_q, seqlen_k, d, causal, local, deterministic, add_unused_qkv, mha_type, dtype
 ):
     if (
         max(seqlen_q, seqlen_k) >= 2048
@@ -245,10 +253,13 @@ def test_flash_attn_varlen_output(
     torch.random.manual_seed(0)
     # batch_size = 1
     # nheads = 1
+    # nheads_kv = 1
     batch_size = 9
     nheads = 6
     nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
- 
+
+    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
+
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
     k = torch.randn(
         batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True
@@ -308,6 +319,7 @@ def test_flash_attn_varlen_output(
         deterministic=deterministic,
         seqused_q=seqused_q,
         seqused_k=seqused_k,
+        window_size=window_size,
     )
     out = output_pad_fn(out_unpad)
     if query_unused_mask is not None:
@@ -322,6 +334,7 @@ def test_flash_attn_varlen_output(
         query_padding_mask,
         key_padding_mask,
         causal=causal,
+        window_size=window_size,
     )
     out_pt, attn_pt = attention_ref(
         q,
@@ -330,6 +343,7 @@ def test_flash_attn_varlen_output(
         query_padding_mask,
         key_padding_mask,
         causal=causal,
+        window_size=window_size,
         upcast=False,
         reorder_ops=True,
     )

+ 1 - 1
hopper/tile_scheduler.hpp

@@ -270,4 +270,4 @@ public:
 
 };
 
-} // flash
+} // flash