@@ -664,7 +664,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ
     # do_o = (g.float() * out.float()).sum(-1)
     # dv_tmp = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, :64], g[:, :64])
     # dv_tmp1 = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, 64:], g[:, 64:])
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
         (dqkv,) = torch.autograd.grad(out, qkv, g)
         (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
         (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
@@ -687,7 +687,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ
         if not alibi:
             assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
 
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
         assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
 
 
@@ -811,7 +811,7 @@ def test_flash_attn_varlen_qkvpacked(
     print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
 
     g = torch.randn_like(out)
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
         (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g)
         dqkv = dqkv_pad_fn(dqkv_unpad)
         (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
@@ -835,7 +835,7 @@ def test_flash_attn_varlen_qkvpacked(
         if not alibi:
             assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
 
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
         assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
 
 
@@ -1036,7 +1036,7 @@ def test_flash_attn_output(
 
     g = torch.randn_like(out)
     do_o = (g.float() * out.float()).sum(-1)
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
         if kvpacked:
             (
                 dq,
@@ -1092,7 +1092,7 @@ def test_flash_attn_output(
         if not alibi:
             assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
 
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
         assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
         assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
@@ -1339,7 +1339,7 @@ def test_flash_attn_varlen_output(
     print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
 
     g = torch.randn_like(out)
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
         if kvpacked:
             (
                 dq_unpad,
@@ -1398,7 +1398,7 @@ def test_flash_attn_varlen_output(
         if not alibi:
             assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
 
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
         assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
         assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
         assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
@@ -1476,7 +1476,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
 
     g = torch.randn_like(out)
     do_o = (g.float() * out.float()).sum(-1)
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
         (
             dq,
             dk,
@@ -1509,7 +1509,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
     # of a Pytorch implementation.
     assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
 
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
         assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
         assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5
@@ -1625,7 +1625,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
 
     g = torch.randn_like(out)
     do_o = (g.float() * out.float()).sum(-1)
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
         (
             dq_unpad,
             dk_unpad,
@@ -1661,7 +1661,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
     # of a Pytorch implementation.
     assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
 
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
         assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
         assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5
@@ -1755,7 +1755,7 @@ def test_flash_attn_splitkv(
 
     g = torch.randn_like(out)
     do_o = (g.float() * out.float()).sum(-1)
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
         (
             dq,
             dk,
@@ -1789,7 +1789,7 @@ def test_flash_attn_splitkv(
     assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
 
     mult = 2 if not alibi else 8
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
         assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4
         assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4
         assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4
@@ -1815,8 +1815,9 @@ def test_flash_attn_splitkv(
 # @pytest.mark.parametrize("rotary_interleaved", [False])
 @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
 # @pytest.mark.parametrize("rotary_fraction", [0.0])
-# @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
-@pytest.mark.parametrize("paged_kv_block_size", [256, 512])
+@pytest.mark.parametrize("paged_kv_block_size", [None, 256])
+# @pytest.mark.parametrize("paged_kv_block_size", [256, 512])
+# @pytest.mark.parametrize("paged_kv_block_size", [256])
 @pytest.mark.parametrize("has_batch_idx", [False, True])
 # @pytest.mark.parametrize("has_batch_idx", [False])
 @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
@@ -1900,12 +1901,13 @@ def test_flash_attn_kvcache(
             b=batch_size,
         )
         k_cache = rearrange(
-            k_cache_paged[block_table.flatten()],
+            # pytorch 1.12 doesn't have indexing with int32
+            k_cache_paged[block_table.to(dtype=torch.long).flatten()],
             "(b nblocks) block_size ... -> b (nblocks block_size) ...",
             b=batch_size,
         )[:, :seqlen_k]
         v_cache = rearrange(
-            v_cache_paged[block_table.flatten()],
+            v_cache_paged[block_table.to(dtype=torch.long).flatten()],
             "(b nblocks) block_size ... -> b (nblocks block_size) ...",
             b=batch_size,
         )[:, :seqlen_k]
@@ -1972,8 +1974,12 @@ def test_flash_attn_kvcache(
         cos, sin = None, None
         q_ro, k_ro = q, k
     # k_cache[:, 64:] = -1
-    k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone()
-    v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone()
+    k_cache_ref = (
+        k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
+    ).clone()
+    v_cache_ref = (
+        v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
+    ).clone()
     if new_kv:
         update_mask = torch.logical_and(
             cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
@@ -2044,16 +2050,20 @@ def test_flash_attn_kvcache(
     # of a Pytorch implementation.
     if new_kv:
         if paged_kv_block_size is None:
-            k_cache_select = k_cache if not has_batch_idx else k_cache[cache_batch_idx]
-            v_cache_select = v_cache if not has_batch_idx else v_cache[cache_batch_idx]
+            k_cache_select = (
+                k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
+            )
+            v_cache_select = (
+                v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
+            )
         else:
             k_cache_select = rearrange(
-                k_cache_paged[block_table.flatten()],
+                k_cache_paged[block_table.to(dtype=torch.long).flatten()],
                 "(b nblocks) block_size ... -> b (nblocks block_size) ...",
                 b=batch_size,
             )[:, :seqlen_k]
             v_cache_select = rearrange(
-                v_cache_paged[block_table.flatten()],
+                v_cache_paged[block_table.to(dtype=torch.long).flatten()],
                 "(b nblocks) block_size ... -> b (nblocks block_size) ...",
                 b=batch_size,
             )[:, :seqlen_k]
@@ -2104,7 +2114,7 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dty
     torch.random.manual_seed(42)
     out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
     g = torch.randn_like(out0)
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
         (
             dq0,
             dk0,
@@ -2119,7 +2129,7 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dty
         assert torch.equal(out, out0)
         assert torch.equal(lse, lse0)
 
-        if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+        if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
             (
                 dq,
                 dk,
@@ -2326,7 +2336,7 @@ def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, loc
     out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, deterministic=True)
 
     g = torch.randn_like(out)
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
         dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
         for _ in range(50):
            dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
@@ -2414,7 +2424,7 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus
     )
 
     g = torch.randn_like(out)
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
         dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
         for _ in range(50):
             dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)