from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple, Type

import torch

from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                   AttentionImpl,
                                                   AttentionMetadata,
                                                   AttentionMetadataBuilder)
from aphrodite.attention.backends.utils import CommonAttentionState

if TYPE_CHECKING:
    from aphrodite.worker.model_runner import ModelInputForGPUBuilder

# Placeholder attention backend for models like Mamba and embedding models that
# lack attention.
class PlaceholderAttentionBackend(AttentionBackend):
    """Placeholder backend for when no attention is needed."""

    @staticmethod
    def get_name() -> str:
        return "No attention"

    @staticmethod
    def get_impl_cls() -> Type["PlaceholderAttentionImpl"]:
        return PlaceholderAttentionImpl

    @staticmethod
    def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]:
        return PlaceholderAttentionMetadataBuilder

    @staticmethod
    def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]:
        return PlaceholderAttentionMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return (1, 1, 1, 1, 1)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        return

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        return
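
# Illustrative note (comments only): the class methods above wire the backend to
# its placeholder implementation, builder, and metadata classes, while the
# cache-related methods return dummy values or do nothing, e.g.:
#
#   PlaceholderAttentionBackend.get_impl_cls()                    # PlaceholderAttentionImpl
#   PlaceholderAttentionBackend.get_kv_cache_shape(1, 16, 8, 64)  # (1, 1, 1, 1, 1)
#   PlaceholderAttentionBackend.swap_blocks(src, dst, mapping)    # no-op
#   PlaceholderAttentionBackend.copy_blocks(kv_caches, mapping)   # no-op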

@dataclass
class PlaceholderAttentionMetadata(AttentionMetadata):
    """Attention metadata for prefill and decode batched together."""
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int]
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]
    # (batch_size,). A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical blocks)
    # E.g., [0, 1, 2] means tokens are stored in the 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # The 2nd dimension is padded up to max_blocks_per_seq if it is cuda-graph
    # captured.
    block_tables: Optional[torch.Tensor]

    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None
    _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None
        assert self.query_start_loc is not None
        assert self.context_lens_tensor is not None
        assert self.seq_start_loc is not None

        # Placeholders
        slot_mapping = torch.empty(0)
        block_tables = torch.empty(0)

        self._cached_prefill_metadata = PlaceholderAttentionMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            seq_lens=self.seq_lens[:self.num_prefills],
            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=self.query_start_loc[:self.num_prefills + 1],
            seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
            context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
            block_tables=block_tables,
            use_cuda_graph=False,
        )
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata
        assert self.seq_lens_tensor is not None

        # Placeholders
        slot_mapping = torch.empty(0)
        block_tables = torch.empty(0)

        self._cached_decode_metadata = PlaceholderAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=slot_mapping,
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=block_tables,
            use_cuda_graph=self.use_cuda_graph,
        )
        return self._cached_decode_metadata
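
# Illustrative sketch (comments only): how the two properties above split a
# mixed batch. Assume two prefills of length 4 and 6 plus one decode with
# sequence length 9; the field names follow the constructor calls above and the
# concrete numbers are made up for illustration.
#
#   meta = PlaceholderAttentionMetadata(
#       num_prefills=2, num_prefill_tokens=10, num_decode_tokens=1,
#       slot_mapping=torch.empty(0),
#       seq_lens=[4, 6, 9], seq_lens_tensor=torch.tensor([4, 6, 9]),
#       max_query_len=6, max_prefill_seq_len=6, max_decode_seq_len=9,
#       query_start_loc=torch.tensor([0, 4, 10, 11]),
#       seq_start_loc=torch.tensor([0, 4, 10, 19]),
#       context_lens_tensor=torch.tensor([0, 0, 8]),
#       block_tables=torch.empty(0), use_cuda_graph=False)
#
#   meta.prefill_metadata.seq_lens        # [4, 6]      (first num_prefills entries)
#   meta.decode_metadata.seq_lens_tensor  # tensor([9]) (remaining entries)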

class PlaceholderAttentionMetadataBuilder(
        AttentionMetadataBuilder[PlaceholderAttentionMetadata]):

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.curr_seq_lens: List[int] = []
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

        self.input_builder = input_builder
        self.runner = input_builder.runner

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool):
        """Add a sequence group to the metadata. Specifically, update/append
        the context lengths and the prefill/decode sequence and token counts.
        """
        is_prompt = inter_data.is_prompt

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)

            if is_prompt:
                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.
        """
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        logits_soft_cap = getattr(self.runner.model_config.hf_config,
                                  "attn_logit_softcapping", None)
        if logits_soft_cap is not None:
            raise ValueError(
                "Please use the FlashInfer backend for models with "
                "logits_soft_cap (e.g., Gemma-2); otherwise, the output might "
                "be wrong. Select it with "
                "export APHRODITE_ATTENTION_BACKEND=FLASHINFER.")

        max_query_len = max(query_lens)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens

        if use_captured_graph:
            num_decode_tokens = batch_size

        assert max_query_len > 0, ("query_lens: {}".format(query_lens))

        context_lens_tensor = torch.tensor(self.context_lens,
                                           dtype=torch.int,
                                           device=device)
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=device)
        query_lens_tensor = torch.tensor(query_lens,
                                         dtype=torch.long,
                                         device=device)
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=device)
        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])

        # Placeholders
        slot_mapping = torch.empty(0)
        block_tables = torch.empty(0)

        return PlaceholderAttentionMetadata(
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )
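
# Illustrative sketch (comments only): the start-location tensors produced by
# ``PlaceholderAttentionMetadataBuilder.build`` are exclusive cumulative sums of
# the per-sequence lengths, written into a zero-initialized tensor that is one
# element longer. With made-up query lengths [4, 6]:
#
#   query_lens_tensor = torch.tensor([4, 6], dtype=torch.long)
#   query_start_loc = torch.zeros(3, dtype=torch.int32)
#   torch.cumsum(query_lens_tensor, dim=0, dtype=query_start_loc.dtype,
#                out=query_start_loc[1:])
#   # query_start_loc is now tensor([0, 4, 10], dtype=torch.int32)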

class PlaceholderAttentionImpl(AttentionImpl):
    """No-op attention implementation; models that use this backend are not
    expected to call ``forward``."""

    def __init__(self, *args, **kwargs) -> None:
        return

    def forward(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError
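
if __name__ == "__main__":
    # Minimal smoke check, illustrative only and not part of the backend API:
    # the placeholder backend exposes a dummy KV-cache shape and no-op block
    # operations, so these calls run without a GPU or a real cache.
    print(PlaceholderAttentionBackend.get_name())
    print(PlaceholderAttentionBackend.get_kv_cache_shape(num_blocks=1,
                                                         block_size=16,
                                                         num_kv_heads=8,
                                                         head_size=64))
    PlaceholderAttentionBackend.copy_blocks([], torch.empty(0))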