from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple, Type

import torch

from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                   AttentionImpl,
                                                   AttentionMetadata,
                                                   AttentionMetadataBuilder)
from aphrodite.attention.backends.utils import CommonAttentionState

if TYPE_CHECKING:
    from aphrodite.worker.model_runner import ModelInputForGPUBuilder

# Placeholder attention backend for models like Mamba and embedding models that
# lack attention.


class PlaceholderAttentionBackend(AttentionBackend):
    """Placeholder backend for when no attention is needed."""

    @staticmethod
    def get_name() -> str:
        """Return the display name of this backend."""
        return "No attention"

    @staticmethod
    def get_impl_cls() -> Type["PlaceholderAttentionImpl"]:
        """Return the (non-functional) attention implementation class."""
        return PlaceholderAttentionImpl

    @staticmethod
    def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]:
        """Return the metadata builder class paired with this backend."""
        return PlaceholderAttentionMetadataBuilder

    @staticmethod
    def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]:
        """Return the metadata class paired with this backend."""
        return PlaceholderAttentionMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class used by this backend."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return a fixed five-dimensional unit shape.

        All arguments are ignored; the result is always (1, 1, 1, 1, 1).
        """
        return (1,) * 5

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Do nothing; the arguments are ignored."""
        return None

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Do nothing; the arguments are ignored."""
        return None


@dataclass
class PlaceholderAttentionMetadata(AttentionMetadata):
    """Attention metadata for prefill and decode batched together."""

    # (batch_size,). Per-sequence length, i.e. computed tokens + new tokens.
    # None when this metadata describes decoding.
    seq_lens: Optional[List[int]]
    # The same sequence lengths, stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # Largest query length in the batch. None for decoding.
    max_query_len: Optional[int]
    # Largest sequence length among the prefill requests. 0 if the batch
    # contains decoding requests only.
    max_prefill_seq_len: int
    # Largest sequence length among the decode requests. 0 if the batch
    # contains prefill requests only.
    max_decode_seq_len: int
    # (batch_size + 1,). Cumulative subquery lengths of the sequences in the
    # batch, used to index into subquery. E.g., subquery lengths [4, 6] give
    # [0, 4, 10].
    query_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). Cumulative sequence lengths of the sequences in the
    # batch, used to index into sequence. E.g., sequence lengths [4, 6] give
    # [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]
    # (batch_size,). Context lengths, i.e. the number of tokens computed so
    # far for each sequence.
    context_lens_tensor: Optional[torch.Tensor]

    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence (seq id -> list of physical blocks).
    # E.g., [0, 1, 2] means tokens are stored in the 0th, 1st, and 2nd blocks
    # of the kv cache. Each block can hold up to block_size tokens.
    # The 2nd dimension is padded up to max_blocks_per_seq when the batch is
    # cuda-graph captured.
    block_tables: Optional[torch.Tensor]

    # Whether cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # Lazily built views returned by the two properties below.
    _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None
    _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
        """Return metadata restricted to the prefill part of the batch.

        Returns None when the batch has no prefills. The result is built on
        first access and cached on the instance afterwards.
        """
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is None:
            assert self.seq_lens is not None
            assert self.seq_lens_tensor is not None
            assert self.query_start_loc is not None
            assert self.context_lens_tensor is not None
            assert self.seq_start_loc is not None

            # Placeholders
            empty_slot_mapping = torch.empty(0)
            empty_block_tables = torch.empty(0)

            # The first `prefill_count` entries are taken as the prefills;
            # the cumulative-offset tensors need one extra element.
            prefill_count = self.num_prefills
            offsets_end = prefill_count + 1

            self._cached_prefill_metadata = PlaceholderAttentionMetadata(
                num_prefills=prefill_count,
                num_prefill_tokens=self.num_prefill_tokens,
                num_decode_tokens=0,
                slot_mapping=empty_slot_mapping,
                seq_lens=self.seq_lens[:prefill_count],
                seq_lens_tensor=self.seq_lens_tensor[:prefill_count],
                max_query_len=self.max_query_len,
                max_prefill_seq_len=self.max_prefill_seq_len,
                max_decode_seq_len=0,
                query_start_loc=self.query_start_loc[:offsets_end],
                seq_start_loc=self.seq_start_loc[:offsets_end],
                context_lens_tensor=self.context_lens_tensor[:prefill_count],
                block_tables=empty_block_tables,
                use_cuda_graph=False,
            )

        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
        """Return metadata restricted to the decode part of the batch.

        Returns None when the batch has no decode tokens. The result is built
        on first access and cached on the instance afterwards.
        """
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is None:
            assert self.seq_lens_tensor is not None

            # Placeholders
            empty_slot_mapping = torch.empty(0)
            empty_block_tables = torch.empty(0)

            # Everything after the prefill entries is taken as decode.
            decode_seq_lens = self.seq_lens_tensor[self.num_prefills:]

            self._cached_decode_metadata = PlaceholderAttentionMetadata(
                num_prefills=0,
                num_prefill_tokens=0,
                num_decode_tokens=self.num_decode_tokens,
                slot_mapping=empty_slot_mapping,
                seq_lens=None,
                seq_lens_tensor=decode_seq_lens,
                max_query_len=None,
                max_prefill_seq_len=0,
                max_decode_seq_len=self.max_decode_seq_len,
                query_start_loc=None,
                seq_start_loc=None,
                context_lens_tensor=None,
                block_tables=empty_block_tables,
                use_cuda_graph=self.use_cuda_graph,
            )

        return self._cached_decode_metadata


class PlaceholderAttentionMetadataBuilder(
        AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
    """Accumulates per-sequence statistics and builds the batch metadata."""

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        """Start with empty accumulators bound to ``input_builder``."""
        self.input_builder = input_builder
        self.runner = input_builder.runner

        # Running totals, updated by `_add_seq_group`.
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

        # Per-sequence lengths, appended to by `_add_seq_group`.
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.curr_seq_lens: List[int] = []

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool):
        """Fold one sequence group into the builder's accumulators.

        Each sequence's context length is always recorded. Prompt sequences
        bump the prefill counters and record the original sequence length;
        other sequences must have a query length of 1 and bump the decode
        counter and record the current sequence length.
        ``chunked_prefill_enabled`` is accepted but not used here.
        """
        is_prompt = inter_data.is_prompt

        # Several zipped fields are not read in the loop body; they are kept
        # in the zip so iteration still stops at the shortest of all seven.
        per_sequence = zip(
            inter_data.seq_ids,
            list(map(len, inter_data.input_tokens)),
            inter_data.orig_seq_lens,
            inter_data.seq_lens,
            inter_data.query_lens,
            inter_data.context_lens,
            inter_data.curr_sliding_window_blocks,
        )
        for (_seq_id, n_tokens, orig_len, current_len, q_len, ctx_len,
             _sliding_window_block) in per_sequence:
            self.context_lens.append(ctx_len)

            if not is_prompt:
                assert q_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        orig_len, ctx_len, q_len))
                self.num_decode_tokens += q_len
                self.curr_seq_lens.append(current_len)
                continue

            self.num_prefills += 1
            self.num_prefill_tokens += n_tokens
            self.prefill_seq_lens.append(orig_len)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.

        Raises:
            ValueError: If the model's hf_config defines a non-None
                ``attn_logit_softcapping``.
        """
        chunked = self.input_builder.chunked_prefill_enabled
        for group in self.input_builder.inter_data_list:
            self._add_seq_group(group, chunked)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        hf_config = self.runner.model_config.hf_config
        if getattr(hf_config, "attn_logit_softcapping", None) is not None:
            raise ValueError(
                "Please use Flashinfer backend for models with logits_soft_cap"
                " (i.e., Gemma-2). Otherwise, the output might be wrong."
                " Set Flashinfer backend by "
                "export APHRODITE_ATTENTION_BACKEND=FLASHINFER.")

        max_query_len = max(query_lens)
        assert max_query_len > 0, ("query_lens: {}".format(query_lens))
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)

        # With a captured cuda graph the (maybe padded) batch size is reported
        # as the decode token count instead of the accumulated total.
        num_decode_tokens = (batch_size if use_captured_graph else
                             self.num_decode_tokens)

        context_lens_tensor = torch.tensor(self.context_lens,
                                           dtype=torch.int,
                                           device=device)
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=device)
        query_lens_tensor = torch.tensor(query_lens,
                                         dtype=torch.long,
                                         device=device)

        # Start offsets: element 0 stays zero, the rest receive the running
        # sums of the corresponding lengths.
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=device)
        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=device)
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])

        # Placeholders
        empty_slot_mapping = torch.empty(0)
        empty_block_tables = torch.empty(0)

        return PlaceholderAttentionMetadata(
            num_prefills=self.num_prefills,
            slot_mapping=empty_slot_mapping,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=empty_block_tables,
            use_cuda_graph=use_captured_graph,
        )


class PlaceholderAttentionImpl(AttentionImpl):
    """Attention implementation stub; `forward` is not implemented."""

    def __init__(self, *args, **kwargs) -> None:
        """Accept and ignore all constructor arguments."""
        return None

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Always raise NotImplementedError."""
        raise NotImplementedError