@@ -153,6 +153,9 @@ class FlashAttnFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         deterministic=False,
+        descale_q=None,
+        descale_k=None,
+        descale_v=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -161,7 +164,10 @@ class FlashAttnFunc(torch.autograd.Function):
             k,
             v,
             softmax_scale,
-            causal
+            causal,
+            descale_q=descale_q,
+            descale_k=descale_k,
+            descale_v=descale_v,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
         ctx.softmax_scale = softmax_scale
@@ -190,7 +196,7 @@ class FlashAttnFunc(torch.autograd.Function):
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None
 
 
 class FlashAttnVarlenFunc(torch.autograd.Function):
@@ -265,7 +271,10 @@ def flash_attn_func(
     v,
     softmax_scale=None,
     causal=False,
-    deterministic=False
+    deterministic=False,
+    descale_q=None,
+    descale_k=None,
+    descale_v=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -303,6 +312,9 @@ def flash_attn_func(
             is added to the attention score of query i and key j.
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
             which is slightly slower and uses more memory. The forward pass is always deterministic.
+        descale_q: (1,), fp32. A de-quantization scaling factor for q in fp8 execution.
+        descale_k: (1,), fp32. A de-quantization scaling factor for k in fp8 execution.
+        descale_v: (1,), fp32. A de-quantization scaling factor for v in fp8 execution.
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
             testing only. The returned probabilities are not guaranteed to be correct
             (they might not have the right scaling).
@@ -322,6 +334,9 @@ def flash_attn_func(
         softmax_scale,
         causal,
         deterministic,
+        descale_q,
+        descale_k,
+        descale_v,
     )
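
Usage sketch (not part of the patch): one way a caller might drive the new parameters, assuming per-tensor fp8 quantization with torch.float8_e4m3fn and the (1,)-shaped fp32 descale tensors documented above. The to_fp8 helper, the tensor shapes, and the import path are illustrative assumptions, not something this diff provides.

import torch
from flash_attn import flash_attn_func  # assumed import path; depends on how this build is packaged

def to_fp8(x, dtype=torch.float8_e4m3fn):
    # Per-tensor scale mapping max|x| onto the fp8 range; the same factor,
    # shaped (1,) in fp32, serves as the de-quantization ("descale") factor.
    scale = x.abs().max().clamp(min=1e-12) / torch.finfo(dtype).max
    return (x / scale).to(dtype), scale.reshape(1).to(torch.float32)

# (batch, seqlen, nheads, headdim) -- shapes chosen for illustration only.
q = torch.randn(2, 4096, 16, 128, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

q8, descale_q = to_fp8(q)
k8, descale_k = to_fp8(k)
v8, descale_v = to_fp8(v)

out = flash_attn_func(
    q8, k8, v8,
    causal=True,
    descale_q=descale_q,
    descale_k=descale_k,
    descale_v=descale_v,
)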