Browse Source

Add q, k, v descales to FA3 interface (#1210)

* add descale_q/k/v for fp8 fwd

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix .apply args

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Charlene Yang 6 months ago
parent
commit
bdf733be55
1 changed files with 18 additions and 3 deletions
  1. 18 3
      hopper/flash_attn_interface.py

+ 18 - 3
hopper/flash_attn_interface.py

@@ -153,6 +153,9 @@ class FlashAttnFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         deterministic=False,
+        descale_q=None,
+        descale_k=None,
+        descale_v=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -161,7 +164,10 @@ class FlashAttnFunc(torch.autograd.Function):
             k,
             v,
             softmax_scale,
-            causal
+            causal,
+            descale_q=descale_q,
+            descale_k=descale_k,
+            descale_v=descale_v,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
         ctx.softmax_scale = softmax_scale
@@ -190,7 +196,7 @@ class FlashAttnFunc(torch.autograd.Function):
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None
 
 
 class FlashAttnVarlenFunc(torch.autograd.Function):
@@ -265,7 +271,10 @@ def flash_attn_func(
     v,
     softmax_scale=None,
     causal=False,
-    deterministic=False
+    deterministic=False,
+    descale_q=None,
+    descale_k=None,
+    descale_v=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -303,6 +312,9 @@ def flash_attn_func(
             is added to the attention score of query i and key j.
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
             which is slightly slower and uses more memory. The forward pass is always deterministic.
+        descale_q: (1,), fp32. A de-quantization scaling factor for q in fp8 execution.
+        descale_k: (1,), fp32. A de-quantization scaling factor for k in fp8 execution.
+        descale_v: (1,), fp32. A de-quantization scaling factor for v in fp8 execution.
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
@@ -322,6 +334,9 @@ def flash_attn_func(
         softmax_scale,
         causal,
         deterministic,
+        descale_q,
+        descale_k,
+        descale_v,
     )