123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334 |
- import random
- from typing import List, Optional, Tuple
- import pytest
- import torch
- from xformers import ops as xops
- from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
- from aphrodite._C import ops
- from aphrodite.common.utils import get_max_shared_memory_bytes
# Width of one torch.float element, in bytes.
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
NUM_BLOCKS = 40000  # Arbitrary values for testing
# Used by the v2 test path to size the per-partition scratch tensors.
PARTITION_SIZE = 512

# Parametrization axes for the tests below.
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_GEN_SEQS = [7]  # Arbitrary values for testing
NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
# Each entry is unpacked by the tests as (num_query_heads, num_kv_heads).
NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
SEEDS = [0]
# Device indices: just 0 when exactly one CUDA device is reported,
# otherwise 0 and 1.
DEVICES = list(range(1 if torch.cuda.device_count() == 1 else 2))
def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Reference scaled dot-product attention with an optional additive mask.

    The einsum signatures index ``query`` as (query, head, dim) and
    ``key``/``value`` as (key, head, dim). Logits and the softmax are
    evaluated in float32; the probabilities are cast to ``value.dtype``
    before they weight the values.

    Args:
        query: Query tensor indexed as (query, head, dim).
        key: Key tensor indexed as (key, head, dim).
        value: Value tensor indexed as (key, head, dim).
        scale: Factor multiplied into the raw query-key dot products.
        attn_mask: Optional bias added to the (head, query, key) logits
            before the softmax; it must broadcast against them.

    Returns:
        The attention output indexed as (query, head, dim).
    """
    # Dot products are taken in the input dtype, then promoted to float32.
    logits = torch.einsum("qhd,khd->hqk", query, key).float()
    logits = scale * logits
    if attn_mask is not None:
        logits = logits + attn_mask.float()
    probs = torch.softmax(logits, dim=-1).to(value.dtype)
    return torch.einsum("hqk,khd->qhd", probs, value)
def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> None:
    """Reference single-query attention over a block-paged KV cache.

    For every sequence, the keys and values of each context position are
    gathered one token at a time through that sequence's block table,
    attention is computed with ``ref_masked_attention``, and the result is
    copied into ``output`` in place.

    Args:
        output: Destination tensor; row ``i`` receives the result for
            sequence ``i`` as (num_query_heads, head_size).
        query: One query per sequence, read as
            (num_seqs, num_query_heads, head_size).
        num_queries_per_kv: How many query heads share each KV head; values
            above 1 cause keys/values to be repeated along the head axis.
        key_cache: Indexed as [block, :, :, block_offset, :]; each gathered
            slice is reshaped to (num_kv_heads, head_size).
        value_cache: Indexed as [block, kv_head, head_size, block_offset].
        block_tables: Per-sequence block numbers, indexed by
            ``position // block_size``.
        context_lens: Number of cached positions per sequence.
        scale: Scaling factor forwarded to ``ref_masked_attention``.
        alibi_slopes: Optional per-head slopes; when given, a bias of
            ``slope * (position - context_len + 1)`` is added to the logits.
    """
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    num_seqs = query.shape[0]

    # Plain Python lists make the per-token lookups below cheap.
    tables = block_tables.cpu().tolist()
    lengths = context_lens.cpu().tolist()

    for seq_idx in range(num_seqs):
        seq_query = query[seq_idx].unsqueeze(0)
        seq_blocks = tables[seq_idx]
        seq_len = int(lengths[seq_idx])

        gathered_keys = []
        gathered_values = []
        for pos in range(seq_len):
            logical_block, offset = divmod(pos, block_size)
            physical_block = int(seq_blocks[logical_block])

            key_slice = key_cache[physical_block, :, :, offset, :]
            gathered_keys.append(key_slice.reshape(num_kv_heads, head_size))
            gathered_values.append(value_cache[physical_block, :, :, offset])

        seq_keys = torch.stack(gathered_keys, dim=0)
        seq_values = torch.stack(gathered_values, dim=0)
        if num_queries_per_kv > 1:
            # Handle MQA and GQA
            seq_keys = torch.repeat_interleave(seq_keys,
                                               num_queries_per_kv,
                                               dim=1)
            seq_values = torch.repeat_interleave(seq_values,
                                                 num_queries_per_kv,
                                                 dim=1)

        if alibi_slopes is None:
            bias = None
        else:
            # Create the ALiBi bias used in the paged attention kernel.
            positions = torch.arange(seq_len, device=query.device).int()
            distances = (positions - seq_len + 1).float()
            bias = alibi_slopes.view(-1, 1, 1) * distances.view(1, 1, -1)

        attn_out = ref_masked_attention(seq_query, seq_keys, seq_values,
                                        scale, bias)
        attn_out = attn_out.view(num_query_heads, head_size)
        output[seq_idx].copy_(attn_out, non_blocking=True)
@pytest.mark.parametrize("version", ["v1", "v2"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
def test_paged_attention(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
    device: int,
) -> None:
    """Compare the paged attention kernels with the Python reference.

    Random queries, context lengths, block tables and KV caches are built
    on ``cuda:{device}``; the selected kernel (``v1`` or ``v2``) is run and
    its output is compared to ``ref_single_query_cached_kv_attention``
    under a relaxed tolerance.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    gpu_id = f"cuda:{device}"

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(
        num_seqs,
        num_query_heads,
        head_size,
        dtype=dtype,
        device=gpu_id,
    )
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = (torch.randn(
        num_query_heads, dtype=torch.float, device=gpu_id)
                    if use_alibi else None)

    # The last sequence always uses the maximum context length.
    seq_lengths = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    seq_lengths[-1] = MAX_SEQ_LEN
    max_context_len = max(seq_lengths)
    context_lens = torch.tensor(seq_lengths, dtype=torch.int, device=gpu_id)

    # Create the block tables.
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    table_rows = [[
        random.randint(0, NUM_BLOCKS - 1)
        for _ in range(max_num_blocks_per_seq)
    ] for _ in range(num_seqs)]
    block_tables = torch.tensor(table_rows, dtype=torch.int, device=gpu_id)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                num_kv_heads, head_size,
                                                dtype, seed, gpu_id)
    key_cache = key_caches[0]
    value_cache = value_caches[0]

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    # Trailing positional arguments shared by both kernel entry points.
    shared_args = (
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        context_lens,
        block_size,
        max_context_len,
        alibi_slopes,
    )
    if version == "v1":
        ops.paged_attention_v1(output, *shared_args)
    elif version == "v2":
        num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
                          PARTITION_SIZE)
        assert PARTITION_SIZE % block_size == 0

        out_seqs, out_heads, out_head_size = output.shape
        tmp_output = torch.empty(
            size=(out_seqs, out_heads, num_partitions, out_head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(out_seqs, out_heads, num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)
        ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            *shared_args,
        )
    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        context_lens,
        scale,
        alibi_slopes,
    )

    # NOTE: Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
def ref_multi_query_kv_attention(
    cu_seq_lens: List[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Reference causal attention over sequences packed along dim 0.

    Consecutive entries of ``cu_seq_lens`` delimit each sequence's rows in
    ``query``/``key``/``value``. Every sequence is attended independently
    with an upper-triangular mask, and the per-sequence results are
    concatenated back in the same order.

    Args:
        cu_seq_lens: Cumulative sequence boundaries; sequence ``i`` spans
            rows ``cu_seq_lens[i]:cu_seq_lens[i + 1]``.
        query: Packed queries, sliced along dim 0.
        key: Packed keys, sliced along dim 0.
        value: Packed values, sliced along dim 0.
        scale: Scaling factor forwarded to ``ref_masked_attention``.
        dtype: Dtype used to build the mask and to pick its fill value.

    Returns:
        The per-sequence attention outputs concatenated along dim 0.
    """
    per_seq_outputs = []
    for start, stop in zip(cu_seq_lens[:-1], cu_seq_lens[1:]):
        length = stop - start

        # Create attention mask.
        # Entries strictly above the diagonal hold the most negative
        # finite value of ``dtype``.
        causal_mask = torch.triu(
            torch.ones(length, length, dtype=dtype),
            diagonal=1,
        )
        causal_mask = causal_mask * torch.finfo(dtype).min
        causal_mask = causal_mask.to(dtype=dtype, device=query.device)

        per_seq_outputs.append(
            ref_masked_attention(
                query[start:stop],
                key[start:stop],
                value[start:stop],
                scale,
                attn_mask=causal_mask,
            ))
    return torch.cat(per_seq_outputs, dim=0)
# TODO: Add tests for USE_ALIBI=True.
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_multi_query_kv_attention(
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: int,
) -> None:
    """Compare xformers block-diagonal causal attention with the reference.

    A packed QKV tensor with randomly sampled sequence lengths is created
    on ``cuda:{device}``, run through
    ``xops.memory_efficient_attention_forward`` with a
    ``BlockDiagonalCausalMask``, and compared against
    ``ref_multi_query_kv_attention``.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    gpu_id = f"cuda:{device}"

    # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
    # As the xformers library is already tested with its own tests, we can use
    # a smaller MAX_SEQ_LEN here.
    max_len = min(MAX_SEQ_LEN, 4096)
    seq_lens = random.sample(range(1, max_len), num_seqs)
    num_tokens = sum(seq_lens)

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    packed_qkv = torch.empty(
        num_tokens,
        num_query_heads + 2 * num_kv_heads,
        head_size,
        dtype=dtype,
        device=gpu_id,
    )
    packed_qkv.uniform_(-scale, scale)
    head_splits = [num_query_heads, num_kv_heads, num_kv_heads]
    query, key, value = packed_qkv.split(head_splits, dim=1)

    num_queries_per_kv = num_query_heads // num_kv_heads
    if num_queries_per_kv > 1:
        # Handle MQA and GQA
        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)

    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
    # A leading dimension of size 1 is added for the call and removed after.
    output = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
    ).squeeze(0)

    # Running prefix sums of the sequence lengths, starting at 0.
    cu_seq_lens = [sum(seq_lens[:end]) for end in range(len(seq_lens) + 1)]
    ref_output = ref_multi_query_kv_attention(
        cu_seq_lens,
        query,
        key,
        value,
        scale,
        dtype,
    )
    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
|