import random
from typing import List, Optional, Tuple

import pytest
import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

from aphrodite._C import ops
from aphrodite.common.utils import get_max_shared_memory_bytes

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
NUM_BLOCKS = 40000  # Arbitrary values for testing
PARTITION_SIZE = 512

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_GEN_SEQS = [7]  # Arbitrary values for testing
NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
SEEDS = [0]
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]


def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    if attn_mask is not None:
        attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
    return out


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> None:
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    num_seqs = query.shape[0]

    block_tables = block_tables.cpu().tolist()
    context_lens = context_lens.cpu().tolist()
    for i in range(num_seqs):
        q = query[i].unsqueeze(0)
        block_table = block_tables[i]
        context_len = int(context_lens[i])

        keys = []
        values = []
        for j in range(context_len):
            block_number = int(block_table[j // block_size])
            block_offset = j % block_size

            k = key_cache[block_number, :, :, block_offset, :]
            k = k.reshape(num_kv_heads, head_size)
            keys.append(k)

            v = value_cache[block_number, :, :, block_offset]
            values.append(v)
        keys = torch.stack(keys, dim=0)
        values = torch.stack(values, dim=0)
        if num_queries_per_kv > 1:
            # Handle MQA and GQA
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values,
                                             num_queries_per_kv,
                                             dim=1)

        alibi_bias = None
        if alibi_slopes is not None:
            # Create the ALiBi bias used in the paged attention kernel.
            position_ids = torch.arange(context_len,
                                        device=query.device).int()
            alibi_bias = (position_ids - context_len + 1).float()
            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
                1, 1, -1)

        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
        out = out.view(num_query_heads, head_size)
        output[i].copy_(out, non_blocking=True)


@pytest.mark.parametrize("version", ["v1", "v2"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
def test_paged_attention(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
    device: int,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    gpu_id = f"cuda:{device}"
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs,
                        num_query_heads,
                        head_size,
                        dtype=dtype,
                        device=gpu_id)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads,
                                   dtype=torch.float,
                                   device=gpu_id)

    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    context_lens[-1] = MAX_SEQ_LEN
    max_context_len = max(context_lens)
    context_lens = torch.tensor(context_lens, dtype=torch.int, device=gpu_id)

    # Create the block tables.
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int, device=gpu_id)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                num_kv_heads, head_size,
                                                dtype, seed, gpu_id)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v1":
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )
    elif version == "v2":
        num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
                          PARTITION_SIZE)
        assert PARTITION_SIZE % block_size == 0
        num_seqs, num_heads, head_size = output.shape
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)
        ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )
    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        context_lens,
        scale,
        alibi_slopes,
    )

    # NOTE: Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)


def ref_multi_query_kv_attention(
    cu_seq_lens: List[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    dtype: torch.dtype,
) -> torch.Tensor:
    num_seqs = len(cu_seq_lens) - 1
    ref_outputs = []
    for i in range(num_seqs):
        start_idx = cu_seq_lens[i]
        end_idx = cu_seq_lens[i + 1]
        seq_len = end_idx - start_idx

        # Create attention mask.
        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
                               diagonal=1)
        attn_mask = attn_mask * torch.finfo(dtype).min
        attn_mask = attn_mask.to(dtype=dtype, device=query.device)

        ref_output = ref_masked_attention(
            query[start_idx:end_idx],
            key[start_idx:end_idx],
            value[start_idx:end_idx],
            scale,
            attn_mask=attn_mask,
        )
        ref_outputs.append(ref_output)
    ref_output = torch.cat(ref_outputs, dim=0)
    return ref_output


# TODO: Add tests for USE_ALIBI=True.
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_multi_query_kv_attention(
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: int,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    gpu_id = f"cuda:{device}"
    # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
    # As the xformers library is already tested with its own tests, we can use
    # a smaller MAX_SEQ_LEN here.
    max_len = min(MAX_SEQ_LEN, 4096)
    seq_lens = random.sample(range(1, max_len), num_seqs)
    num_tokens = sum(seq_lens)

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    qkv = torch.empty(num_tokens,
                      num_query_heads + 2 * num_kv_heads,
                      head_size,
                      dtype=dtype,
                      device=gpu_id)
    qkv.uniform_(-scale, scale)
    query, key, value = qkv.split(
        [num_query_heads, num_kv_heads, num_kv_heads], dim=1)

    num_queries_per_kv = num_query_heads // num_kv_heads
    if num_queries_per_kv > 1:
        # Handle MQA and GQA
        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
    output = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
    )
    output = output.squeeze(0)

    cu_seq_lens = [0]
    for seq_len in seq_lens:
        cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
    ref_output = ref_multi_query_kv_attention(
        cu_seq_lens,
        query,
        key,
        value,
        scale,
        dtype,
    )
    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
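

# NOTE: The check below is an illustrative addition rather than part of the
# original kernel test suite. It is a minimal sanity check of the
# ref_masked_attention helper used by both reference paths above,
# cross-checking it against torch.nn.functional.scaled_dot_product_attention
# (assumes PyTorch >= 2.0) on a small random CPU case. The shapes, seed, and
# test name are arbitrary choices, not values taken from the kernels under
# test.
@torch.inference_mode()
def test_ref_masked_attention_matches_sdpa() -> None:
    torch.random.manual_seed(0)
    seq_len, num_heads, head_size = 8, 4, 16
    scale = float(1.0 / (head_size**0.5))
    query = torch.randn(seq_len, num_heads, head_size)
    key = torch.randn(seq_len, num_heads, head_size)
    value = torch.randn(seq_len, num_heads, head_size)

    ref_output = ref_masked_attention(query, key, value, scale)

    # SDPA expects (num_heads, seq_len, head_size); its default scale is
    # 1/sqrt(head_size), which matches the scale passed above.
    sdpa_output = torch.nn.functional.scaled_dot_product_attention(
        query.permute(1, 0, 2),
        key.permute(1, 0, 2),
        value.permute(1, 0, 2),
    ).permute(1, 0, 2)
    assert torch.allclose(ref_output, sdpa_output, atol=1e-3, rtol=1e-5)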