@@ -682,7 +682,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ
    # do_o = (g.float() * out.float()).sum(-1)
    # dv_tmp = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, :64], g[:, :64])
    # dv_tmp1 = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, 64:], g[:, 64:])
-    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        (dqkv,) = torch.autograd.grad(out, qkv, g)
        (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
        (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
@@ -705,7 +705,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ
        if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

-    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()


@@ -829,7 +829,7 @@ def test_flash_attn_varlen_qkvpacked(
    print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

    g = torch.randn_like(out)
-    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g)
        dqkv = dqkv_pad_fn(dqkv_unpad)
        (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
@@ -853,7 +853,7 @@ def test_flash_attn_varlen_qkvpacked(
        if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

-    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()


@@ -866,9 +866,9 @@ def test_flash_attn_varlen_qkvpacked(
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("alibi", [False, True])
-# @pytest.mark.parametrize("alibi", [True])
+# @pytest.mark.parametrize("alibi", [False])
@pytest.mark.parametrize("local", [False, True])
-# @pytest.mark.parametrize("local", [True])
+# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256])
@@ -894,7 +894,7 @@ def test_flash_attn_varlen_qkvpacked(
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
-# @pytest.mark.parametrize("dropout_p", [0.17])
+# @pytest.mark.parametrize("dropout_p", [0.0])
@pytest.mark.parametrize("softcap", [0.0, 50.0])
def test_flash_attn_output(
    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
@@ -1066,7 +1066,7 @@ def test_flash_attn_output(

    g = torch.randn_like(out)
    do_o = (g.float() * out.float()).sum(-1)
-    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
+    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        if kvpacked:
            (
                dq,
@@ -1122,10 +1122,10 @@ def test_flash_attn_output(
        if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

-    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)):
-        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
-        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
-        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
+    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
+        assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
+        assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
+        assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()


@pytest.mark.parametrize("kvpacked", [True, False])
@@ -1382,7 +1382,7 @@ def test_flash_attn_varlen_output(
    print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

    g = torch.randn_like(out)
-    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)):
+    if ((d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90)):
        if kvpacked:
            (
                dq_unpad,
@@ -1441,7 +1441,7 @@ def test_flash_attn_varlen_output(
        if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.04)

-    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
+    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
        assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
        assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
@@ -1519,43 +1519,41 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):

    g = torch.randn_like(out)
    do_o = (g.float() * out.float()).sum(-1)
-    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
-        (
-            dq,
-            dk,
-            dv,
-        ) = torch.autograd.grad(out, (q, k, v), g)
-        (
-            dq_ref,
-            dk_ref,
-            dv_ref,
-        ) = torch.autograd.grad(out_ref, (q, k, v), g)
-        (
-            dq_pt,
-            dk_pt,
-            dv_pt,
-        ) = torch.autograd.grad(out_pt, (q, k, v), g)
-        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
-        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
-        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
-        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
-        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
-        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
-        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
-        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
-        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
-        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
-        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
-        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
+    (
+        dq,
+        dk,
+        dv,
+    ) = torch.autograd.grad(out, (q, k, v), g)
+    (
+        dq_ref,
+        dk_ref,
+        dv_ref,
+    ) = torch.autograd.grad(out_ref, (q, k, v), g)
+    (
+        dq_pt,
+        dk_pt,
+        dv_pt,
+    ) = torch.autograd.grad(out_pt, (q, k, v), g)
+    print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
+    print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
+    print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
+    print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
+    print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
+    print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
+    print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
+    print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
+    print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
+    print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
+    print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
+    print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5

-    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
-        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
-        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
-        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5
+    assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
+    assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
+    assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@@ -1684,7 +1682,7 @@ def test_flash_attn_varlen_causal(

    g = torch.randn_like(out)
    do_o = (g.float() * out.float()).sum(-1)
-    test_backward = (d <= MAX_HEADDIM_SM8x or d > 224 or is_sm80 or is_sm90) and block_table is None
+    test_backward = block_table is None
    if test_backward:
        (
            dq_unpad,
@@ -1815,44 +1813,42 @@ def test_flash_attn_splitkv(

    g = torch.randn_like(out)
    do_o = (g.float() * out.float()).sum(-1)
-    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
-        (
-            dq,
-            dk,
-            dv,
-        ) = torch.autograd.grad(out, (q, k, v), g)
-        (
-            dq_ref,
-            dk_ref,
-            dv_ref,
-        ) = torch.autograd.grad(out_ref, (q, k, v), g)
-        (
-            dq_pt,
-            dk_pt,
-            dv_pt,
-        ) = torch.autograd.grad(out_pt, (q, k, v), g)
-        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
-        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
-        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
-        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
-        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
-        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
-        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
-        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
-        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
-        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
-        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
-        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
+    (
+        dq,
+        dk,
+        dv,
+    ) = torch.autograd.grad(out, (q, k, v), g)
+    (
+        dq_ref,
+        dk_ref,
+        dv_ref,
+    ) = torch.autograd.grad(out_ref, (q, k, v), g)
+    (
+        dq_pt,
+        dk_pt,
+        dv_pt,
+    ) = torch.autograd.grad(out_pt, (q, k, v), g)
+    print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
+    print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
+    print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
+    print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
+    print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
+    print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
+    print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
+    print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
+    print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
+    print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
+    print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
+    print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5

    mult = 2 if not alibi else 8
-    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
-        assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4
-        assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4
-        assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4
+    assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4
+    assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4
+    assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4


# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@@ -2208,7 +2204,7 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dty
    torch.random.manual_seed(42)
    out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
    g = torch.randn_like(out0)
-    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        (
            dq0,
            dk0,
@@ -2223,7 +2219,7 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dty
        assert torch.equal(out, out0)
        assert torch.equal(lse, lse0)

-        if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+        if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
            (
                dq,
                dk,
@@ -2430,13 +2426,12 @@ def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, loc
    out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, deterministic=True)

    g = torch.randn_like(out)
-    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
-        dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
-        for _ in range(50):
-            dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
-            assert torch.equal(dv, dv0)
-            assert torch.equal(dk, dk0)
-            assert torch.equal(dq, dq0)
+    dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
+    for _ in range(50):
+        dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
+        assert torch.equal(dv, dv0)
+        assert torch.equal(dk, dk0)
+        assert torch.equal(dq, dq0)


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@@ -2518,10 +2513,9 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus
    )

    g = torch.randn_like(out)
-    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
-        dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
-        for _ in range(50):
-            dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
-            assert torch.equal(dv, dv0)
-            assert torch.equal(dk, dk0)
-            assert torch.equal(dq, dq0)
+    dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
+    for _ in range(50):
+        dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
+        assert torch.equal(dv, dv0)
+        assert torch.equal(dk, dk0)
+        assert torch.equal(dq, dq0)
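For reference, most hunks above rewrite the same backward-pass guard. Below is a minimal standalone sketch of that predicate, not part of the patch: the helper name and its parameters are illustrative, and in the test module the corresponding values come from the existing `MAX_HEADDIM_SM8x`, `is_sm80`, and `is_sm90` globals (the value 192 used in the example calls is likewise only illustrative).

```python
def backward_guard(d: int, dropout_p: float, max_headdim_sm8x: int, is_sm80: bool, is_sm90: bool) -> bool:
    # Same boolean structure as the updated `if` conditions in the diff:
    # run the backward checks when the head dim is within the SM8x limit,
    # or dropout is disabled, or the GPU is SM80/SM90.
    return (d <= max_headdim_sm8x or dropout_p == 0) or (is_sm80 or is_sm90)


# Illustrative usage: on a non-SM80/SM90 GPU, a head dim above the SM8x limit
# only gets backward coverage when dropout is turned off.
assert backward_guard(d=256, dropout_p=0.0, max_headdim_sm8x=192, is_sm80=False, is_sm90=False)
assert not backward_guard(d=256, dropout_p=0.17, max_headdim_sm8x=192, is_sm80=False, is_sm90=False)
```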