@@ -708,7 +708,9 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ
 # @pytest.mark.parametrize('seqlen', [128])
 @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
 # @pytest.mark.parametrize('dropout_p', [0.0])
-def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
+def test_flash_attn_varlen_qkvpacked(
+    seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype
+):
     if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
         pytest.skip()  # Reference implementation OOM
     device = "cuda"
@@ -1698,7 +1700,9 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
     ],
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
-def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, deterministic, dtype):
+def test_flash_attn_splitkv(
+    seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, deterministic, dtype
+):
     if swap_sq_sk:
         seqlen_q, seqlen_k = seqlen_k, seqlen_q
     device = "cuda"
@@ -1800,7 +1804,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, al
 @pytest.mark.parametrize("new_kv", [False, True])
 # @pytest.mark.parametrize("new_kv", [False])
 @pytest.mark.parametrize("alibi", [False, True])
-# @pytest.mark.parametrize("alibi", [True])
+# @pytest.mark.parametrize("alibi", [False])
 @pytest.mark.parametrize("local", [False, True])
 # @pytest.mark.parametrize("local", [False])
 @pytest.mark.parametrize("causal", [False, True])
@@ -1811,10 +1815,12 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, al
 # @pytest.mark.parametrize("rotary_interleaved", [False])
 @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
 # @pytest.mark.parametrize("rotary_fraction", [0.0])
+# @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
+@pytest.mark.parametrize("paged_kv_block_size", [256, 512])
 @pytest.mark.parametrize("has_batch_idx", [False, True])
 # @pytest.mark.parametrize("has_batch_idx", [False])
-@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
-# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
+@pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
+# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [56, 80])
 # @pytest.mark.parametrize("d", [128])
@@ -1840,6 +1846,7 @@ def test_flash_attn_kvcache(
     seqlen_k,
     d,
     has_batch_idx,
+    paged_kv_block_size,
     rotary_fraction,
     rotary_interleaved,
     seqlen_new_eq_seqlen_q,
@@ -1855,6 +1862,8 @@ def test_flash_attn_kvcache(
         pytest.skip()
     if not new_kv and rotary_fraction > 0.0:
         pytest.skip()
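+    # Paged KV tests don't exercise cache_batch_idx, so skip that combination.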
+    if has_batch_idx and paged_kv_block_size is not None:
+        pytest.skip()
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
@@ -1873,10 +1882,35 @@ def test_flash_attn_kvcache(
         v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
     else:
         k, v = None, None
-    k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
-    v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
+    if paged_kv_block_size is None:
+        k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
+        v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
+        block_table = None
+    else:
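+        # Paged KV cache: allocate 3x more blocks than needed to cover seqlen_k per batch element.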
+        num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3
+        k_cache_paged = torch.randn(
+            num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
+        )
+        v_cache_paged = torch.randn(
+            num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
+        )
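+        # Give each batch element a random, non-contiguous set of block indices into the paged cache.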
+        block_table = rearrange(
+            torch.randperm(num_blocks, dtype=torch.int32, device=device),
+            "(b nblocks) -> b nblocks",
+            b=batch_size,
+        )
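+        # Gather the assigned blocks back into contiguous (batch, seqlen_k, ...) views for the reference computation.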
+        k_cache = rearrange(
+            k_cache_paged[block_table.flatten()],
+            "(b nblocks) block_size ... -> b (nblocks block_size) ...",
+            b=batch_size,
+        )[:, :seqlen_k]
+        v_cache = rearrange(
+            v_cache_paged[block_table.flatten()],
+            "(b nblocks) block_size ... -> b (nblocks block_size) ...",
+            b=batch_size,
+        )[:, :seqlen_k]
     cache_seqlens = torch.randint(
-        0,
+        0 if new_kv else 1,
         # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
         (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
         if new_kv
@@ -1903,7 +1937,15 @@ def test_flash_attn_kvcache(
         alibi_slopes, attn_bias = None, None
     # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
     if rotary_dim > 0:
-        angle = torch.rand(seqlen_k, rotary_dim // 2, device=device) * 2 * math.pi
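+        # With a paged cache, size the rotary cos/sin tables to the full paged capacity (num_blocks * block_size positions) rather than seqlen_k.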
+        angle = (
+            torch.rand(
+                seqlen_k if paged_kv_block_size is None else num_blocks * paged_kv_block_size,
+                rotary_dim // 2,
+                device=device,
+            )
+            * 2
+            * math.pi
+        )
         cos = torch.cos(angle).to(dtype=dtype)
         sin = torch.sin(angle).to(dtype=dtype)
         if causal or local:
@@ -1942,14 +1984,15 @@ def test_flash_attn_kvcache(
     v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
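+    # When paging is enabled, pass the raw paged caches plus the block table; otherwise pass the contiguous caches.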
     out = flash_attn_with_kvcache(
         q,
-        k_cache,
-        v_cache,
+        k_cache if paged_kv_block_size is None else k_cache_paged,
+        v_cache if paged_kv_block_size is None else v_cache_paged,
         k,
         v,
-        cos,
-        sin,
-        cache_seqlens,
-        cache_batch_idx,
+        rotary_cos=cos,
+        rotary_sin=sin,
+        cache_seqlens=cache_seqlens,
+        cache_batch_idx=cache_batch_idx,
+        block_table=block_table,
         causal=causal,
         window_size=window_size,
         rotary_interleaved=rotary_interleaved,
@@ -2000,8 +2043,20 @@ def test_flash_attn_kvcache(
     # Check that FlashAttention's numerical error is at most twice the numerical error
     # of a Pytorch implementation.
    if new_kv:
-        k_cache_select = k_cache if not has_batch_idx else k_cache[cache_batch_idx]
-        v_cache_select = v_cache if not has_batch_idx else v_cache[cache_batch_idx]
+        if paged_kv_block_size is None:
+            k_cache_select = k_cache if not has_batch_idx else k_cache[cache_batch_idx]
+            v_cache_select = v_cache if not has_batch_idx else v_cache[cache_batch_idx]
+        else:
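+            # Re-gather the updated paged cache into contiguous form before comparing it with the reference cache.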
+            k_cache_select = rearrange(
+                k_cache_paged[block_table.flatten()],
+                "(b nblocks) block_size ... -> b (nblocks block_size) ...",
+                b=batch_size,
+            )[:, :seqlen_k]
+            v_cache_select = rearrange(
+                v_cache_paged[block_table.flatten()],
+                "(b nblocks) block_size ... -> b (nblocks block_size) ...",
+                b=batch_size,
+            )[:, :seqlen_k]
         assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
         assert torch.equal(v_cache_select, v_cache_ref)
     mult = 3 if not alibi else 5
@@ -2280,8 +2335,6 @@ def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, loc
     assert torch.equal(dq, dq0)
 
 
-
-
 @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("local", [False, True])