@@ -44,7 +44,7 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal):
 def _flash_attn_forward(
-    q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
+    q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
@@ -59,6 +59,7 @@ def _flash_attn_forward(
         causal,
         window_size[0],
         window_size[1],
+        softcap,
         return_softmax,
         None,
     )
@@ -123,6 +124,7 @@ def _flash_attn_backward(
     softmax_scale,
     causal,
     window_size,
+    softcap,
     alibi_slopes,
     deterministic,
     rng_state=None,
@@ -151,6 +153,7 @@ def _flash_attn_backward(
         causal,
         window_size[0],
         window_size[1],
+        softcap,
         deterministic,
         None,
         rng_state,
@@ -176,6 +179,7 @@ def _flash_attn_varlen_backward(
     softmax_scale,
     causal,
     window_size,
+    softcap,
     alibi_slopes,
     deterministic,
     rng_state=None,
@@ -209,6 +213,7 @@ def _flash_attn_varlen_backward(
         causal,
         window_size[0],
         window_size[1],
+        softcap,
         deterministic,
         None,
         rng_state,
@@ -227,6 +232,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -241,6 +247,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
         )
@@ -249,6 +256,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -272,6 +280,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
@@ -433,6 +442,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -451,6 +461,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=None,
@@ -464,6 +475,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -492,6 +504,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
@@ -512,6 +525,7 @@ class FlashAttnFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -526,6 +540,7 @@ class FlashAttnFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
         )
@@ -534,6 +549,7 @@ class FlashAttnFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -556,6 +572,7 @@ class FlashAttnFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
@@ -581,6 +598,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -600,6 +618,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=block_table,
@@ -613,6 +632,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -639,6 +659,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
+           ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
@@ -655,6 +676,7 @@ def flash_attn_qkvpacked_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1), # -1 means infinite context window
+    softcap=0.0, # 0.0 means deactivated
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
@@ -676,6 +698,7 @@ def flash_attn_qkvpacked_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Anything > 0 activates softcapping attention.
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
             the attention score of query i and key j.
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
@@ -698,6 +721,7 @@ def flash_attn_qkvpacked_func(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_attn_probs,
@@ -711,6 +735,7 @@ def flash_attn_kvpacked_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1), # -1 means infinite context window
+    softcap=0.0, # 0.0 means deactivated
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
@@ -748,6 +773,7 @@ def flash_attn_kvpacked_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Anything > 0 activates softcapping attention.
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
             (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
             is added to the attention score of query i and key j.
@@ -772,6 +798,7 @@ def flash_attn_kvpacked_func(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_attn_probs,
@@ -786,6 +813,7 @@ def flash_attn_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1), # -1 means infinite context window
+    softcap=0.0, # 0.0 means deactivated
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
@@ -846,6 +874,7 @@ def flash_attn_func(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_attn_probs,
@@ -860,6 +889,7 @@ def flash_attn_varlen_qkvpacked_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1), # -1 means infinite context window
+    softcap=0.0, # 0.0 means deactivated
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
@@ -884,6 +914,7 @@ def flash_attn_varlen_qkvpacked_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Anything > 0 activates softcapping attention.
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
             is added to the attention score of query i and key j.
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
@@ -908,6 +939,7 @@ def flash_attn_varlen_qkvpacked_func(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_attn_probs,
@@ -925,6 +957,7 @@ def flash_attn_varlen_kvpacked_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1), # -1 means infinite context window
+    softcap=0.0, # 0.0 means deactivated
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
@@ -968,6 +1001,7 @@ def flash_attn_varlen_kvpacked_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Anything > 0 activates softcapping attention.
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
             (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
             is added to the attention score of query i and key j.
@@ -996,6 +1030,7 @@ def flash_attn_varlen_kvpacked_func(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_attn_probs,
@@ -1014,6 +1049,7 @@ def flash_attn_varlen_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1), # -1 means infinite context window
+    softcap=0.0, # 0.0 means deactivated
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
@@ -1056,6 +1092,7 @@ def flash_attn_varlen_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Anything > 0 activates softcapping attention.
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
             (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
             is added to the attention score of query i and key j.
@@ -1085,6 +1122,7 @@ def flash_attn_varlen_func(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_attn_probs,
@@ -1106,6 +1144,7 @@ def flash_attn_with_kvcache(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1), # -1 means infinite context window
+    softcap=0.0, # 0.0 means deactivated
     rotary_interleaved=True,
     alibi_slopes=None,
     num_splits=0,
@@ -1177,6 +1216,7 @@ def flash_attn_with_kvcache(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Anything > 0 activates softcapping attention.
         rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
             If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
             rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
@@ -1226,6 +1266,7 @@ def flash_attn_with_kvcache(
         causal,
         window_size[0],
         window_size[1],
+        softcap,
         rotary_interleaved,
         num_splits,
     )
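
Note (illustration only, not part of the patch): the sketch below shows what the new softcap argument is intended to do, in plain PyTorch. The tanh-capping formula (scale the scores, then apply softcap * tanh(scores / softcap) before masking and softmax), the function name softcapped_attention_ref, and the tensor shapes are assumptions made for this example; the patch itself only threads softcap through to the CUDA kernels.

# Reference semantics assumed for softcap: cap the scaled QK^T scores with
# softcap * tanh(scores / softcap) before the causal mask and softmax.
import math
import torch


def softcapped_attention_ref(q, k, v, softcap=0.0, softmax_scale=None, causal=False):
    # q, k, v: (batch, seqlen, nheads, headdim); softcap <= 0.0 disables capping.
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(q.shape[-1])
    # (batch, nheads, seqlen_q, seqlen_k) scaled attention scores.
    scores = torch.einsum("bqhd,bkhd->bhqk", q.float(), k.float()) * softmax_scale
    if softcap > 0.0:
        # Smoothly squash every score into the range (-softcap, softcap).
        scores = torch.tanh(scores / softcap) * softcap
    if causal:
        seqlen_q, seqlen_k = scores.shape[-2:]
        # Bottom-right-aligned causal mask.
        mask = torch.triu(
            torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=scores.device),
            diagonal=seqlen_k - seqlen_q + 1,
        )
        scores = scores.masked_fill(mask, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("bhqk,bkhd->bqhd", probs, v.float()).to(q.dtype)

Under that assumption, a call such as flash_attn_func(q, k, v, causal=True, softcap=30.0) would keep every attention score inside (-30, 30) before softmax, while the default softcap=0.0 preserves the previous behavior.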