"""Tests comparing attention kernels against PyTorch reference code.

``test_paged_attention`` checks ``ops.paged_attention_v1`` / ``v2`` against a
per-sequence reference that gathers keys and values from the paged KV cache.
``test_multi_query_kv_attention`` checks xformers'
``memory_efficient_attention_forward`` against a dense causal reference.
"""
import random
from typing import List, Optional, Tuple

import pytest
import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

from aphrodite import _custom_ops as ops
from aphrodite.common.utils import get_max_shared_memory_bytes, is_hip
from tests.kernels.utils import opcheck

from .allclose_default import get_default_atol, get_default_rtol

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
# There may not be enough gpu memory due to large NUM_BLOCKS.
# Reduce NUM_BLOCKS when it happens.
NUM_BLOCKS = 4321  # Arbitrary values for testing
PARTITION_SIZE = 512
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = [torch.half, torch.bfloat16, torch.float
          ] if not is_hip() else [torch.half, torch.bfloat16]
NUM_GEN_SEQS = [7]  # Arbitrary values for testing
NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing

# FlashAttention forward only supports head dimension at most 128
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256
              ] if not is_hip() else [64, 80, 96, 112, 128]

BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8"]
SEEDS = [0]
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Compute scaled dot-product attention with an optional additive mask.

    Args:
        query: Tensor indexed as (query position, head, head dim).
        key: Tensor indexed as (key position, head, head dim).
        value: Tensor indexed as (key position, head, head dim).
        scale: Factor applied to the query-key scores.
        attn_mask: Optional bias added to the (head, query, key) scores
            before the softmax.

    Returns:
        Attention output indexed as (query position, head, head dim).
    """
    # Scores and softmax are evaluated in float32; the probabilities are cast
    # to the value dtype before the weighted sum.
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    if attn_mask is not None:
        attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
    return out


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> None:
    """Reference single-query attention over a paged KV cache.

    For every sequence, the keys and values of each token are gathered from
    the cache through that sequence's block table, then attended to by the
    sequence's single query. The result is written into ``output[i]``.

    Args:
        output: Destination tensor; row ``i`` receives sequence ``i``'s
            result, viewed as (num_query_heads, head_size).
        query: One query per sequence; ``query.shape[0]`` is the number of
            sequences and ``query.shape[1]`` the number of query heads.
        num_queries_per_kv: Number of query heads sharing each KV head.
        key_cache: Paged key cache, indexed as
            ``[block, :, :, offset, :]`` per token.
        value_cache: Paged value cache whose dims 1..3 are read as
            (num_kv_heads, head_size, block_size).
        block_tables: Per-sequence table of physical block numbers.
        seq_lens: Number of cached tokens per sequence.
        scale: Factor applied to the query-key scores.
        alibi_slopes: Optional per-head ALiBi slopes; ``None`` disables the
            bias.
    """
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    num_seqs = query.shape[0]

    # Move the index data to Python lists once instead of reading tensor
    # elements inside the loops.
    block_tables_lst = block_tables.cpu().tolist()
    seq_lens_lst = seq_lens.cpu().tolist()
    for i in range(num_seqs):
        q = query[i].unsqueeze(0)
        block_table = block_tables_lst[i]
        seq_len = int(seq_lens_lst[i])

        keys_lst: List[torch.Tensor] = []
        values_lst: List[torch.Tensor] = []
        for j in range(seq_len):
            # Logical token j lives in block j // block_size at offset
            # j % block_size.
            block_number = int(block_table[j // block_size])
            block_offset = j % block_size

            k = key_cache[block_number, :, :, block_offset, :]
            k = k.reshape(num_kv_heads, head_size)
            keys_lst.append(k)

            v = value_cache[block_number, :, :, block_offset]
            values_lst.append(v)
        keys = torch.stack(keys_lst, dim=0)
        values = torch.stack(values_lst, dim=0)
        if num_queries_per_kv > 1:
            # Handle MQA and GQA
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)

        alibi_bias = None
        if alibi_slopes is not None:
            # Create the ALiBi bias used in the paged attention kernel.
            # The bias is 0 for the last position and increasingly negative
            # (for positive slopes) towards earlier positions.
            position_ids = torch.arange(seq_len).int()
            alibi_bias = (position_ids - seq_len + 1).float()
            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
                1, 1, -1)

        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
        out = out.view(num_query_heads, head_size)
        output[i].copy_(out, non_blocking=True)


@pytest.mark.parametrize("version", ["v1", "v2"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_paged_attention(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    seed: int,
    device: str,
) -> None:
    """Check the paged attention kernels against the PyTorch reference.

    Runs ``ops.paged_attention_v1`` or ``ops.paged_attention_v2`` (selected by
    ``version``) on random queries and a random paged KV cache, then compares
    the result with ``ref_single_query_cached_kv_attention``.

    Args:
        kv_cache_factory: Fixture that builds the key/value caches.
        version: ``"v1"`` or ``"v2"``.
        num_seqs: Number of sequences (one query each).
        num_heads: ``(num_query_heads, num_kv_heads)``.
        head_size: Per-head dimension.
        use_alibi: Whether to pass random ALiBi slopes to the kernel.
        block_size: Number of tokens per KV cache block.
        dtype: Dtype of the query and output.
        kv_cache_dtype: KV cache storage format (``"auto"`` or ``"fp8"``).
        seed: Seed for ``random`` and torch RNGs.
        device: Device string used as the torch default device.

    Raises:
        AssertionError: If ``version`` is unknown or the outputs differ
            beyond the tolerance.
    """
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()

    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    # Tensors created below without an explicit device land on `device`.
    torch.set_default_device(device)
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)

    seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    # Always exercise the longest supported sequence length.
    seq_lens[-1] = MAX_SEQ_LEN
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int)

    # Create the block tables.
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: List[List[int]] = [[
        random.randint(0, NUM_BLOCKS - 1)
        for _ in range(max_num_blocks_per_seq)
    ] for _ in range(num_seqs)]
    block_tables = torch.tensor(block_tables_lst, dtype=torch.int)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                num_kv_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Using default kv_scale
    k_scale = v_scale = 1.0

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v1":
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

        # The op check is only run for the first head size.
        opcheck(torch.ops._C.paged_attention_v1,
                (output, query, key_cache, value_cache, num_kv_heads, scale,
                 block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
                 kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
                cond=(head_size == HEAD_SIZES[0]))

    elif version == "v2":
        num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
        assert PARTITION_SIZE % block_size == 0
        # `output` has the same shape as `query`, i.e.
        # (num_seqs, num_query_heads, head_size), so the scratch buffers are
        # sized from those values directly instead of rebinding the
        # `num_heads` tuple parameter.
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
        )
        max_logits = torch.empty_like(exp_sums)
        ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

        # The op check is only run for the first head size.
        opcheck(torch.ops._C.paged_attention_v2,
                (output, exp_sums, max_logits, tmp_output, query, key_cache,
                 value_cache, num_kv_heads, scale, block_tables, seq_lens,
                 block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
                 k_scale, v_scale, 0, 0, 0, 64, 0),
                cond=(head_size == HEAD_SIZES[0]))

    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    if kv_cache_dtype == "fp8":
        # Convert cache data back to dtype.
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
                           block_size, x)
        dequantized_key_cache = torch.empty(size=key_cache_shape,
                                            dtype=dtype,
                                            device=device)
        ops.convert_fp8(dequantized_key_cache, key_cache)
        key_cache = dequantized_key_cache

        value_cache_shape = value_cache.shape
        dequantized_value_cache = torch.empty(size=value_cache_shape,
                                              dtype=dtype,
                                              device=device)
        ops.convert_fp8(dequantized_value_cache, value_cache)
        value_cache = dequantized_value_cache

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        seq_lens,
        scale,
        alibi_slopes,
    )

    # NOTE(woosuk): Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    atol = get_default_atol(output) if is_hip() else 1e-3
    rtol = get_default_rtol(output) if is_hip() else 1e-5

    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
    # so we use a relaxed tolerance for the test.
    # The fp8 values act as a floor so that a platform-specific tolerance
    # chosen above is never tightened by this adjustment.
    if kv_cache_dtype == "fp8":
        atol, rtol = max(atol, 1e-2), max(rtol, 1e-5)
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)


def ref_multi_query_kv_attention(
    cu_seq_lens: List[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Reference causal attention over several sequences packed together.

    Args:
        cu_seq_lens: Cumulative sequence lengths; sequence ``i`` occupies
            rows ``cu_seq_lens[i]:cu_seq_lens[i + 1]`` of the inputs.
        query: Packed queries, sliced along dim 0 per sequence.
        key: Packed keys, sliced along dim 0 per sequence.
        value: Packed values, sliced along dim 0 per sequence.
        scale: Factor applied to the query-key scores.
        dtype: Dtype used to build the causal mask.

    Returns:
        The per-sequence attention outputs concatenated along dim 0.
    """
    num_seqs = len(cu_seq_lens) - 1
    ref_outputs: List[torch.Tensor] = []
    for i in range(num_seqs):
        start_idx = cu_seq_lens[i]
        end_idx = cu_seq_lens[i + 1]
        seq_len = end_idx - start_idx

        # Create attention mask.
        # Entries strictly above the diagonal (future positions) get the most
        # negative finite value of `dtype`; all other entries stay 0.
        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
                               diagonal=1)
        attn_mask = attn_mask * torch.finfo(dtype).min
        attn_mask = attn_mask.to(dtype=dtype)

        ref_output = ref_masked_attention(
            query[start_idx:end_idx],
            key[start_idx:end_idx],
            value[start_idx:end_idx],
            scale,
            attn_mask=attn_mask,
        )
        ref_outputs.append(ref_output)

    return torch.cat(ref_outputs, dim=0)


# TODO(woosuk): Add tests for USE_ALIBI=True.
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_multi_query_kv_attention(
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Check xformers block-diagonal causal attention against the reference.

    Args:
        num_seqs: Number of packed sequences.
        num_heads: ``(num_query_heads, num_kv_heads)``.
        head_size: Per-head dimension.
        dtype: Dtype of the packed QKV tensor.
        seed: Seed for ``random`` and torch RNGs.
        device: Device string used as the torch default device.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
    # As the xformers library is already tested with its own tests, we can use
    # a smaller MAX_SEQ_LEN here.
    max_len = min(MAX_SEQ_LEN, 4096)
    seq_lens = random.sample(range(1, max_len), num_seqs)
    num_tokens = sum(seq_lens)

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    # Q, K and V are carved out of one random tensor along the head dim.
    qkv = torch.empty(num_tokens,
                      num_query_heads + 2 * num_kv_heads,
                      head_size,
                      dtype=dtype)
    qkv.uniform_(-scale, scale)
    query, key, value = qkv.split(
        [num_query_heads, num_kv_heads, num_kv_heads], dim=1)

    num_queries_per_kv = num_query_heads // num_kv_heads
    if num_queries_per_kv > 1:
        # Handle MQA and GQA
        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
    output = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
    )
    output = output.squeeze(0)

    cu_seq_lens = [0]
    for seq_len in seq_lens:
        cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
    ref_output = ref_multi_query_kv_attention(
        cu_seq_lens,
        query,
        key,
        value,
        scale,
        dtype,
    )
    atol = get_default_atol(output) if is_hip() else 1e-3
    rtol = get_default_rtol(output) if is_hip() else 1e-5
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)