from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch

from aphrodite import _custom_ops as ops
from aphrodite.triton_utils import HAS_TRITON

if HAS_TRITON:
    from aphrodite.attention.ops.prefix_prefill import context_attention_fwd

# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512


@dataclass
class PagedAttentionMetadata:
    """Metadata for PagedAttention."""
    # (batch_size,). The length of each sequence (all tokens seen so far).
    seq_lens_tensor: Optional[torch.Tensor]
    # Maximum sequence length in the batch. 0 if it is a prefill-only batch.
    max_decode_seq_len: int
    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence (seq id -> list of physical blocks).
    # E.g., [0, 1, 2] means the tokens are stored in the 0th, 1st, and 2nd
    # blocks in the KV cache. Each block can hold up to block_size tokens.
    # The 2nd dimension is padded up to max_blocks_per_seq if the batch is
    # captured with CUDA graphs.
    block_tables: Optional[torch.Tensor]


class PagedAttention:
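    """Static helpers for the paged KV-cache layout and attention kernels.

    Groups the cache-management and attention entry points used by the
    attention backend implementations.
    """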

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [64, 80, 96, 112, 120, 128, 192, 256]

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
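        """Return the shape of the combined KV-cache tensor.

        The first dimension selects the key cache (index 0) or the value
        cache (index 1); each block stores block_size tokens for all
        num_kv_heads heads of size head_size.
        """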
        return (2, num_blocks, block_size * num_kv_heads * head_size)

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
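        """Split the flat KV cache into key and value views.

        The key cache is viewed as
        (num_blocks, num_kv_heads, head_size // x, block_size, x), where x is
        the number of elements that fit in 16 bytes; the value cache is viewed
        as (num_blocks, num_kv_heads, head_size, block_size).
        """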
        # x is the number of cache elements that fit in 16 bytes, so the last
        # dimension of the key cache can be read with one vectorized access.
        x = 16 // kv_cache.element_size()
        num_blocks = kv_cache.shape[1]

        key_cache = kv_cache[0]
        key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
                                   -1, x)
        value_cache = kv_cache[1]
        value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
        return key_cache, value_cache

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
    ) -> None:
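        """Write new key/value tensors into the paged KV cache.

        slot_mapping gives, for each token, the flat slot index
        (block_number * block_size + block_offset) it is written to.
        """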
        ops.reshape_and_cache(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping.flatten(),
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

    @staticmethod
    def forward_decode(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        seq_lens: torch.Tensor,
        max_seq_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        k_scale: float,
        v_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> torch.Tensor:
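        """Run single-token (decode) paged attention over the KV cache.

        Dispatches to the paged_attention_v1 kernel when the sequence fits in
        one partition or there is already enough parallel work, and to the
        partitioned paged_attention_v2 kernel otherwise. Returns an output
        tensor with the same shape as query.
        """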
        if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
            # use blocksparse paged attention
            block_size = value_cache.size(-1)
            assert (blocksparse_block_size > 0 and
                    blocksparse_block_size % block_size == 0), \
                (f"{blocksparse_block_size=} needs to be a multiple of "
                 f"{block_size=} used in block_tables.")

        output = torch.empty_like(query)
        block_size = value_cache.shape[3]
        num_seqs, num_heads, head_size = query.shape
        max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
                              _PARTITION_SIZE)
        # NOTE: We use a simple heuristic to decide whether to use
        # PagedAttention V1 or V2. If the number of partitions is 1, we use
        # V1 to avoid the overhead of reduction. Also, if the number of
        # sequences or heads is large, we use V1 since there is enough work
        # to parallelize.
        # TODO: Tune this heuristic.
        # For context len > 8192, use the V2 kernel to avoid shared memory
        # shortage.
        use_v1 = (max_seq_len <= 8192
                  and (max_num_partitions == 1 or num_seqs * num_heads > 512))
        if use_v1:
            # Run PagedAttention V1.
            ops.paged_attention_v1(
                output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
                tp_rank,
                blocksparse_local_blocks,
                blocksparse_vert_stride,
                blocksparse_block_size,
                blocksparse_head_sliding_step,
            )
        else:
            # Run PagedAttention V2.
            assert _PARTITION_SIZE % block_size == 0
            # Per-partition intermediates: V2 computes attention for each
            # partition separately, then reduces across partitions using
            # exp_sums and max_logits.
            tmp_output = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions, head_size),
                dtype=output.dtype,
                device=output.device,
            )
            exp_sums = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions),
                dtype=torch.float32,
                device=output.device,
            )
            max_logits = torch.empty_like(exp_sums)
            ops.paged_attention_v2(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
                tp_rank,
                blocksparse_local_blocks,
                blocksparse_vert_stride,
                blocksparse_block_size,
                blocksparse_head_sliding_step,
            )
        return output

    @staticmethod
    def forward_prefix(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache_dtype: str,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        query_start_loc: torch.Tensor,
        seq_lens_tensor: torch.Tensor,
        context_lens: torch.Tensor,
        max_query_len: int,
        alibi_slopes: Optional[torch.Tensor],
        sliding_window: Optional[int],
        k_scale: float,
        v_scale: float,
    ) -> torch.Tensor:
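        """Run prefill attention with prefix caching via the Triton kernel.

        Attends over both the new tokens (key/value) and the cached context
        (key_cache/value_cache) and returns an output shaped like query.
        """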
        output = torch.empty_like(query)
        context_attention_fwd(
            query,
            key,
            value,
            output,
            kv_cache_dtype,
            key_cache,
            value_cache,
            block_tables,
            # query_start_loc is (batch_size + 1,)
            query_start_loc[:-1],
            seq_lens_tensor,
            context_lens,
            max_query_len,
            k_scale,
            v_scale,
            alibi_slopes,
            sliding_window,
        )
        return output

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
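        """Swap cache blocks between two KV caches (e.g. CPU <-> GPU).

        src_to_dst maps source block numbers to destination block numbers.
        """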
        src_key_cache = src_kv_cache[0]
        dst_key_cache = dst_kv_cache[0]
        ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)

        src_value_cache = src_kv_cache[1]
        dst_value_cache = dst_kv_cache[1]
        ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
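        """Copy blocks within the KV cache of every layer.

        The same source -> destination block mapping is applied to each
        layer's key and value caches.
        """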
        key_caches = [kv_cache[0] for kv_cache in kv_caches]
        value_caches = [kv_cache[1] for kv_cache in kv_caches]
        ops.copy_blocks(key_caches, value_caches, src_to_dists)