@@ -1668,7 +1668,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt
 @pytest.mark.parametrize("new_kv", [False, True])
 # @pytest.mark.parametrize("new_kv", [True])
 @pytest.mark.parametrize("local", [False, True])
-# @pytest.mark.parametrize("local", [True])
+# @pytest.mark.parametrize("local", [False])
 @pytest.mark.parametrize("causal", [False, True])
 # @pytest.mark.parametrize("causal", [True])
 @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
@@ -1677,6 +1677,8 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt
 # @pytest.mark.parametrize("rotary_interleaved", [False])
 @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
 # @pytest.mark.parametrize("rotary_fraction", [0.0])
+@pytest.mark.parametrize("has_batch_idx", [False, True])
+# @pytest.mark.parametrize("has_batch_idx", [True])
 @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
@@ -1703,6 +1705,7 @@ def test_flash_attn_kvcache(
     seqlen_q,
     seqlen_k,
     d,
+    has_batch_idx,
     rotary_fraction,
     rotary_interleaved,
     seqlen_new_eq_seqlen_q,
@@ -1721,6 +1724,7 @@ def test_flash_attn_kvcache(
     # set seed
     torch.random.manual_seed(0)
     batch_size = 2
+    batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
     nheads = 6
     # rotary_dim must be a multiple of 16, and must be <= d
     rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
@@ -1734,8 +1738,8 @@ def test_flash_attn_kvcache(
         v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
     else:
         k, v = None, None
-    k_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
-    v_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
+    k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
+    v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
     cache_seqlens = torch.randint(
         0,
         # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
@@ -1746,6 +1750,10 @@ def test_flash_attn_kvcache(
         dtype=torch.int32,
         device=device,
     )
+    if has_batch_idx:
+        cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[:batch_size]
+    else:
+        cache_batch_idx = None
     # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
     if rotary_dim > 0:
         angle = torch.rand(seqlen_k, rotary_dim // 2, device=device) * 2 * math.pi
@@ -1775,8 +1783,8 @@ def test_flash_attn_kvcache(
         cos, sin = None, None
         q_ro, k_ro = q, k
     # k_cache[:, 64:] = -1
-    k_cache_ref = k_cache.clone()
-    v_cache_ref = v_cache.clone()
+    k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone()
+    v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone()
     arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
     cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
     if new_kv:
@@ -1796,6 +1804,7 @@ def test_flash_attn_kvcache(
         cos,
         sin,
         cache_seqlens,
+        cache_batch_idx,
         causal=causal,
         window_size=window_size,
         rotary_interleaved=rotary_interleaved,
@@ -1844,8 +1853,10 @@ def test_flash_attn_kvcache(
     # Check that FlashAttention's numerical error is at most twice the numerical error
     # of a Pytorch implementation.
     if new_kv:
-        assert torch.allclose(k_cache, k_cache_ref, rtol=1e-3, atol=1e-3)
-        assert torch.equal(v_cache, v_cache_ref)
+        k_cache_select = k_cache if not has_batch_idx else k_cache[cache_batch_idx]
+        v_cache_select = v_cache if not has_batch_idx else v_cache[cache_batch_idx]
+        assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
+        assert torch.equal(v_cache_select, v_cache_ref)
     assert (out - out_ref).abs().max().item() <= 3 * (out_pt - out_ref).abs().max().item() + 1e-5
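
For orientation, the new has_batch_idx path tests a KV cache that has more slots than the live batch: cache_batch_idx assigns each batch entry a slot in the larger cache, and the reference tensors are built by gathering k_cache[cache_batch_idx] / v_cache[cache_batch_idx]. Below is a minimal standalone PyTorch sketch of that gather-and-compare pattern; the shapes are made up for illustration and the kernel call itself is omitted, so this is not the library API, just the indexing the test relies on.

import torch

# Hypothetical shapes mirroring the test: the cache holds more slots than the batch.
batch_size, batch_size_cache, seqlen_k, nheads_k, d = 2, 4, 16, 6, 32
k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d)

# Each batch entry is assigned a distinct slot in the larger cache,
# analogous to cache_batch_idx in the test.
cache_batch_idx = torch.randperm(batch_size_cache)[:batch_size]

# Reference copy of the slots this batch uses
# (what k_cache[cache_batch_idx].clone() captures in the test).
k_cache_ref = k_cache[cache_batch_idx].clone()
assert k_cache_ref.shape == (batch_size, seqlen_k, nheads_k, d)

# In the real test the kernel updates the cache in place between these two steps;
# here the cache is untouched, so the comparison trivially holds. The test re-gathers
# the same slots and compares against the reference (allclose for k, exact equal for v).
k_cache_select = k_cache[cache_batch_idx]
assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)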