123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264 |
- from typing import List, Optional, Tuple
- import pytest
- import torch
- import aphrodite.attention.backends.flash_attn # noqa: F401
# Parametrization grids shared by the tests in this module.

# Each entry is a (num_query_heads, num_kv_heads) pair.
NUM_HEADS: List[Tuple[int, int]] = [
    (4, 4),
    (8, 2),
    (16, 2),
]
HEAD_SIZES: List[int] = [128, 256]
BLOCK_SIZES: List[int] = [16, 32]
DTYPES: List[torch.dtype] = [torch.float16, torch.bfloat16]
# Number of blocks in the KV cache. Two sizes are used on purpose:
#   * one large enough to test overflow in the index calculation;
#   * one small enough to also run the op schema check.
NUM_BLOCKS: List[int] = [32768, 2048]
def ref_paged_attn(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    query_lens: List[int],
    kv_lens: List[int],
    block_tables: torch.Tensor,
    scale: float,
    sliding_window: Optional[int] = None,
    soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Compute causal attention over a paged KV cache, one sequence at a time.

    This is the unfused reference the kernel outputs are compared against.

    Args:
        query: Packed query tokens of all sequences, laid out as
            (total_query_tokens, num_query_heads, head_size).
        key_cache: Paged keys, laid out as
            (num_blocks, block_size, num_kv_heads, head_size).
        value_cache: Paged values, indexed the same way as ``key_cache``.
        query_lens: Number of query tokens of each sequence.
        kv_lens: Number of key/value tokens of each sequence.
        block_tables: Per-sequence block indices into the caches; row ``i``
            lists the blocks of sequence ``i`` in order.
        scale: Factor the queries are multiplied by before the dot product.
        sliding_window: If given, additionally mask keys that fall outside
            a window of this many positions.
        soft_cap: If given, logits are squashed with
            ``soft_cap * tanh(logits / soft_cap)`` before masking.

    Returns:
        The attention outputs of all sequences concatenated along dim 0,
        i.e. (total_query_tokens, num_query_heads, head_size).
    """
    num_seqs = len(query_lens)
    block_tables = block_tables.cpu().numpy()
    _, block_size, num_kv_heads, head_size = key_cache.shape

    outputs: List[torch.Tensor] = []
    start_idx = 0
    for i in range(num_seqs):
        query_len = query_lens[i]
        kv_len = kv_lens[i]

        # Scale out of place: the slice is a view of ``query``, so an
        # in-place multiply would silently modify the caller's tensor.
        q = query[start_idx:start_idx + query_len] * scale

        # Gather this sequence's blocks and flatten them into a contiguous
        # token axis; the last block may be partially filled, hence the trim.
        num_kv_blocks = (kv_len + block_size - 1) // block_size
        block_indices = block_tables[i, :num_kv_blocks]

        k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
        k = k[:kv_len]
        v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
        v = v[:kv_len]

        # Fewer KV heads than query heads: repeat each KV head so that every
        # query head has a matching key/value head.
        if q.shape[1] != k.shape[1]:
            k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
            v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)

        attn = torch.einsum("qhd,khd->hqk", q, k).float()

        # Build the masks on the same device as the logits rather than on
        # whatever the process-wide default device happens to be.
        empty_mask = torch.ones(query_len, kv_len, device=attn.device)
        # Causal mask with the query tokens aligned to the END of the KV
        # sequence: True marks positions that must not be attended to.
        mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
        if sliding_window is not None:
            sliding_window_mask = torch.triu(
                empty_mask,
                diagonal=kv_len - (query_len + sliding_window) + 1,
            ).bool().logical_not()
            mask |= sliding_window_mask
        if soft_cap is not None:
            attn = soft_cap * torch.tanh(attn / soft_cap)
        attn.masked_fill_(mask, float("-inf"))
        attn = torch.softmax(attn, dim=-1).to(v.dtype)
        out = torch.einsum("hqk,khd->qhd", attn, v)

        outputs.append(out)
        start_idx += query_len

    return torch.cat(outputs, dim=0)
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@torch.inference_mode()
def test_flash_attn_with_paged_kv(
    kv_lens: List[int],
    num_heads: Tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    block_size: int,
    soft_cap: Optional[float],
    num_blocks: int,
) -> None:
    """Check ``flash_attn_with_kvcache`` against ``ref_paged_attn``.

    Every sequence contributes exactly one query token and attends to
    ``kv_lens[i]`` cached tokens through a randomly filled block table.
    The op is also run through ``torch.library.opcheck``.
    """
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(0)
    num_seqs = len(kv_lens)
    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0
    max_kv_len = max(kv_lens)
    scale = head_size**-0.5

    query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
    key_cache = torch.randn(num_blocks,
                            block_size,
                            num_kv_heads,
                            head_size,
                            dtype=dtype)
    value_cache = torch.randn_like(key_cache)
    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 num_blocks,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)

    # Built once and shared by the real call and opcheck so the two
    # invocations cannot drift apart.
    op_kwargs = dict(
        decode_query=query.unsqueeze(1),
        key_cache=key_cache,
        value_cache=value_cache,
        softmax_scale=scale,
        causal=True,
        block_table=block_tables,
        cache_seqlens=kv_lens_tensor,
        softcap=soft_cap if soft_cap is not None else 0,
    )

    output = torch.ops.aphrodite.flash_attn_with_kvcache(
        **op_kwargs).squeeze(1)

    # The schema check is only run for the small cache size.
    # NOTE(review): presumably because it is too expensive on the large
    # cache -- confirm.
    test_utils = ["test_faketensor"]
    if num_blocks <= 2048:
        test_utils.append("test_schema")

    torch.library.opcheck(torch.ops.aphrodite.flash_attn_with_kvcache,
                          args=tuple(),
                          kwargs=op_kwargs,
                          test_utils=test_utils)

    ref_output = ref_paged_attn(
        query=query,
        key_cache=key_cache,
        value_cache=value_cache,
        query_lens=[1] * num_seqs,
        kv_lens=kv_lens,
        block_tables=block_tables,
        scale=scale,
        soft_cap=soft_cap,
    )
    # assert_close raises with its own mismatch report. The f-string that
    # used to trail this call was a discarded tuple element, not a message.
    torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2)
@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("sliding_window", [None])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@torch.inference_mode()
def test_varlen_with_paged_kv(
    seq_lens: List[Tuple[int, int]],
    num_heads: Tuple[int, int],
    head_size: int,
    sliding_window: Optional[int],
    dtype: torch.dtype,
    block_size: int,
    soft_cap: Optional[float],
    num_blocks: int,
) -> None:
    """Check ``flash_attn_varlen_func`` against ``ref_paged_attn``.

    ``seq_lens`` holds one ``(query_len, kv_len)`` pair per sequence. The
    queries of all sequences are packed along dim 0 and the keys/values are
    read from a randomly filled block table. The op is also run through
    ``torch.library.opcheck``.
    """
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(0)
    num_seqs = len(seq_lens)
    query_lens = [x[0] for x in seq_lens]
    kv_lens = [x[1] for x in seq_lens]
    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0
    max_query_len = max(query_lens)
    max_kv_len = max(kv_lens)
    # (-1, -1) is what the op receives when no sliding window is requested.
    if sliding_window is not None:
        window_size = (sliding_window, sliding_window)
    else:
        window_size = (-1, -1)
    scale = head_size**-0.5

    query = torch.randn(sum(query_lens),
                        num_query_heads,
                        head_size,
                        dtype=dtype)
    key_cache = torch.randn(num_blocks,
                            block_size,
                            num_kv_heads,
                            head_size,
                            dtype=dtype)
    value_cache = torch.randn_like(key_cache)
    # Cumulative lengths with a leading zero: entry i is the offset of
    # sequence i and the last entry is the total token count.
    cu_query_lens = torch.tensor([0] + query_lens,
                                 dtype=torch.int32).cumsum(dim=0,
                                                           dtype=torch.int32)
    cu_kv_lens = torch.tensor([0] + kv_lens,
                              dtype=torch.int32).cumsum(dim=0,
                                                        dtype=torch.int32)

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 num_blocks,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)

    # Built once and shared by the real call and opcheck so the two
    # invocations cannot drift apart.
    op_kwargs = dict(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_query_lens,
        cu_seqlens_k=cu_kv_lens,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len,
        softmax_scale=scale,
        causal=True,
        window_size=window_size,
        block_table=block_tables,
        softcap=soft_cap if soft_cap is not None else 0,
    )

    output = torch.ops.aphrodite.flash_attn_varlen_func(**op_kwargs)

    # The schema check is only run for the small cache size.
    # NOTE(review): presumably because it is too expensive on the large
    # cache -- confirm.
    test_utils = ["test_faketensor"]
    if num_blocks <= 2048:
        test_utils.append("test_schema")

    torch.library.opcheck(torch.ops.aphrodite.flash_attn_varlen_func,
                          args=tuple(),
                          kwargs=op_kwargs,
                          test_utils=test_utils)

    ref_output = ref_paged_attn(
        query=query,
        key_cache=key_cache,
        value_cache=value_cache,
        query_lens=query_lens,
        kv_lens=kv_lens,
        block_tables=block_tables,
        scale=scale,
        sliding_window=sliding_window,
        soft_cap=soft_cap,
    )
    # assert_close raises with its own mismatch report. The f-string that
    # used to trail this call was a discarded tuple element, not a message.
    torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2)
|