@@ -77,12 +77,13 @@ def _flash_attn_varlen_forward(
dropout_p,
softmax_scale,
causal,
- window_size,
- softcap,
- alibi_slopes,
- return_softmax,
+ window_size=(-1, -1),
+ softcap=0.0,
+ alibi_slopes=None,
+ return_softmax=False,
block_table=None,
leftpad_k=None,
+ seqused_k=None,
):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
@@ -93,7 +94,7 @@ def _flash_attn_varlen_forward(
None,
cu_seqlens_q,
cu_seqlens_k,
- None,
+ seqused_k,
leftpad_k,
block_table,
alibi_slopes,