1
0
Эх сурвалжийг харах

Pass seqused_k to _flash_attn_varlen_forward

Tri Dao 8 сар өмнө
parent
commit
898dd4bbf2

+ 6 - 5
flash_attn/flash_attn_interface.py

@@ -77,12 +77,13 @@ def _flash_attn_varlen_forward(
     dropout_p,
     softmax_scale,
     causal,
-    window_size,
-    softcap,
-    alibi_slopes,
-    return_softmax,
+    window_size=(-1, -1),
+    softcap=0.0,
+    alibi_slopes=None,
+    return_softmax=False,
     block_table=None,
     leftpad_k=None,
+    seqused_k=None,
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
@@ -93,7 +94,7 @@ def _flash_attn_varlen_forward(
         None,
         cu_seqlens_q,
         cu_seqlens_k,
-        None,
+        seqused_k,
         leftpad_k,
         block_table,
         alibi_slopes,