Browse Source

Don't dispatch to local if window size >= seqlen_k

Tri Dao 1 year ago
parent
commit
0842ec0da4
1 changed files with 20 additions and 5 deletions
  1. 20 5
      csrc/flash_attn/flash_api.cpp

+ 20 - 5
csrc/flash_attn/flash_api.cpp

@@ -260,7 +260,7 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
         const float p_dropout,
         const float softmax_scale,
         bool is_causal,
-        const int window_size_left,
+        int window_size_left,
         int window_size_right,
         const bool return_softmax,
         c10::optional<at::Generator> gen_) {
@@ -300,6 +300,9 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
+    if (window_size_left >= seqlen_k) { window_size_left = -1; }
+    if (window_size_right >= seqlen_k) { window_size_right = -1; }
+
     // causal=true is the same as causal=false in this case
     if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
     if (is_causal) { window_size_right = 0; }
@@ -465,7 +468,7 @@ mha_varlen_fwd(const at::Tensor &q,  // total_q x num_heads x head_size, total_q
                const float softmax_scale,
                const bool zero_tensors,
                const bool is_causal,
-               const int window_size_left,
+               int window_size_left,
                int window_size_right,
                const bool return_softmax,
                c10::optional<at::Generator> gen_) {
@@ -512,6 +515,9 @@ mha_varlen_fwd(const at::Tensor &q,  // total_q x num_heads x head_size, total_q
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
+    if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
+    if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
+
     CHECK_SHAPE(q, total_q, num_heads, head_size_og);
     CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
     CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
@@ -675,7 +681,7 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
         const float p_dropout,         // probability to drop
         const float softmax_scale,
         const bool is_causal,
-        const int window_size_left,
+        int window_size_left,
         int window_size_right,
         const bool deterministic,
         c10::optional<at::Generator> gen_,
@@ -738,6 +744,9 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
 
     TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
 
+    if (window_size_left >= seqlen_k) { window_size_left = -1; }
+    if (window_size_right >= seqlen_k) { window_size_right = -1; }
+
     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
     CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
     CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
@@ -912,7 +921,7 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
                const float softmax_scale,
                const bool zero_tensors,
                const bool is_causal,
-               const int window_size_left,
+               int window_size_left,
                int window_size_right,
                const bool deterministic,
                c10::optional<at::Generator> gen_,
@@ -979,6 +988,9 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
 
     TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
 
+    if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
+    if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
+
     CHECK_SHAPE(q, total_q, num_heads, head_size);
     CHECK_SHAPE(k, total_k, num_heads_k, head_size);
     CHECK_SHAPE(v, total_k, num_heads_k, head_size);
@@ -1160,7 +1172,7 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
                 c10::optional<at::Tensor> &out_,             // batch_size x seqlen_q x num_heads x head_size
                 const float softmax_scale,
                 bool is_causal,
-                const int window_size_left,
+                int window_size_left,
                 int window_size_right,
                 bool is_rotary_interleaved,   // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
                 int num_splits
@@ -1216,6 +1228,9 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
         num_heads = num_heads_k;
     }
 
+    if (window_size_left >= seqlen_k) { window_size_left = -1; }
+    if (window_size_right >= seqlen_k) { window_size_right = -1; }
+
     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
     CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
     CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);