@@ -1542,8 +1542,12 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
         (1023, 1024),
     ],
 )
+# TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged
+@pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
 # @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
-def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
+def test_flash_attn_varlen_causal(
+    seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype
+):
     if (
         max(seqlen_q, seqlen_k) >= 2048
         and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
@@ -1559,8 +1563,19 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
     nheads = 9
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
-    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
-    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+
+    if paged_kv_block_size is None:
+        k = torch.randn(
+            batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
+        )
+        v = torch.randn(
+            batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
+        )
+        block_table = None
+    else:
+        k, v, block_table, k_cache_paged, v_cache_paged, num_blocks = _generate_block_kvcache(
+            seqlen_k, paged_kv_block_size, batch_size, nheads, d, device, dtype
+        )
     query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
     key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
     (
@@ -1580,8 +1595,8 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
     ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
     out_unpad = flash_attn_varlen_func(
         q_unpad,
-        k_unpad,
-        v_unpad,
+        k_unpad if paged_kv_block_size is None else k_cache_paged,
+        v_unpad if paged_kv_block_size is None else v_cache_paged,
         cu_seqlens_q,
         cu_seqlens_k,
         max_seqlen_q,
@@ -1589,6 +1604,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
         0.0,
         causal=causal,
         window_size=window_size,
+        block_table=block_table,
     )
     out = output_pad_fn(out_unpad)
     out_ref, attn_ref = attention_ref(
@@ -1625,7 +1641,8 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp

     g = torch.randn_like(out)
     do_o = (g.float() * out.float()).sum(-1)
-    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
+    test_backward = (d <= MAX_HEADDIM_SM8x or d > 224 or is_sm80 or is_sm90) and block_table is None
+    if test_backward:
         (
             dq_unpad,
             dk_unpad,
@@ -1661,7 +1678,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
     # of a Pytorch implementation.
     assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5

-    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
+    if test_backward:
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
         assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
         assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5
@@ -1888,29 +1905,16 @@ def test_flash_attn_kvcache(
         v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
         block_table = None
     else:
-        num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3
-        k_cache_paged = torch.randn(
-            num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
-        )
-        v_cache_paged = torch.randn(
-            num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
-        )
-        block_table = rearrange(
-            torch.randperm(num_blocks, dtype=torch.int32, device=device),
-            "(b nblocks) -> b nblocks",
-            b=batch_size,
+        (
+            k_cache,
+            v_cache,
+            block_table,
+            k_cache_paged,
+            v_cache_paged,
+            num_blocks,
+        ) = _generate_block_kvcache(
+            seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype
         )
-        k_cache = rearrange(
-            # pytorch 1.12 doesn't have indexing with int32
-            k_cache_paged[block_table.to(dtype=torch.long).flatten()],
-            "(b nblocks) block_size ... -> b (nblocks block_size) ...",
-            b=batch_size,
-        )[:, :seqlen_k]
-        v_cache = rearrange(
-            v_cache_paged[block_table.to(dtype=torch.long).flatten()],
-            "(b nblocks) block_size ... -> b (nblocks block_size) ...",
-            b=batch_size,
-        )[:, :seqlen_k]
     cache_seqlens = torch.randint(
         0 if new_kv else 1,
         # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
@@ -2073,6 +2077,33 @@ def test_flash_attn_kvcache(
     assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5


+def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype):
+    num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3
+    k_cache_paged = torch.randn(
+        num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
+    )
+    v_cache_paged = torch.randn(
+        num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
+    )
+    block_table = rearrange(
+        torch.randperm(num_blocks, dtype=torch.int32, device=device),
+        "(b nblocks) -> b nblocks",
+        b=batch_size,
+    )
+    k_cache = rearrange(
+        # pytorch 1.12 doesn't have indexing with int32
+        k_cache_paged[block_table.to(dtype=torch.long).flatten()],
+        "(b nblocks) block_size ... -> b (nblocks block_size) ...",
+        b=batch_size,
+    )[:, :seqlen_k]
+    v_cache = rearrange(
+        v_cache_paged[block_table.to(dtype=torch.long).flatten()],
+        "(b nblocks) block_size ... -> b (nblocks block_size) ...",
+        b=batch_size,
+    )[:, :seqlen_k]
+    return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks
+
+
 # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 @pytest.mark.parametrize("dtype", [torch.float16])
 @pytest.mark.parametrize("causal", [False, True])
|