from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from enum import Enum, auto
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
                    Tuple, Type, TypeVar)

import torch

if TYPE_CHECKING:
    from aphrodite.task_handler.model_runner_base import (
        ModelRunnerInputBuilderBase)


class AttentionType(Enum):
    DECODER = auto()  # Decoder attention between previous layer Q/K/V
    ENCODER = auto()  # Encoder attention between previous layer Q/K/V
    ENCODER_DECODER = auto()  # Attention between dec. Q and enc. K/V


class AttentionBackend(ABC):
    """Abstract class for attention backends."""

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        raise NotImplementedError

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
        return cls.get_metadata_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
        raise NotImplementedError

    @classmethod
    def make_metadata_builder(cls, *args,
                              **kwargs) -> "AttentionMetadataBuilder":
        return cls.get_builder_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        raise NotImplementedError
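    # Illustrative note (an assumption, not mandated by this interface): a
    # paged-KV backend would typically stack the key and value caches along a
    # leading dimension, e.g. returning a shape such as
    # (2, num_blocks, block_size, num_kv_heads, head_size), so that
    # kv_cache[0] holds keys and kv_cache[1] holds values.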
    @staticmethod
    @abstractmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        raise NotImplementedError


@dataclass
class AttentionMetadata:
    """Attention metadata for prefill and decode batched together."""
    # Total number of prefill requests.
    num_prefills: int
    # Number of prefill tokens.
    num_prefill_tokens: int
    # Number of decode tokens. Note that it is equivalent to the number of
    # decode requests.
    num_decode_tokens: int
    # (num_tokens,). The indices of the token slots that the input tokens
    # will be stored into. E.g., if `slot_mapping` is [35, 2, 17] and the
    # block size is 16, the three tokens are stored at offset 3 in block 2,
    # offset 2 in block 0, and offset 1 in block 1, respectively.
    slot_mapping: torch.Tensor
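    # Illustrative arithmetic for the example above (not used by this class):
    # a flat slot index decomposes into a block and an in-block offset as
    #   block_number = slot // block_size   # 35 // 16 == 2
    #   block_offset = slot % block_size    # 35 % 16 == 3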
    @property
    @abstractmethod
    def prefill_metadata(self) -> Optional["AttentionMetadata"]:
        """Return the attention metadata that's required to run prefill
        attention."""
        pass

    @property
    @abstractmethod
    def decode_metadata(self) -> Optional["AttentionMetadata"]:
        """Return the attention metadata that's required to run decode
        attention."""
        pass

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        """Similar to dataclasses.asdict, but avoids deepcopying."""
        if skip_fields is None:
            skip_fields = set()
        # Note that if we add dataclasses as fields, they will need
        # similar handling.
        return {
            field.name: getattr(self, field.name)
            for field in fields(self) if field.name not in skip_fields
        }


T = TypeVar("T", bound=AttentionMetadata)


class AttentionMetadataBuilder(ABC, Generic[T]):
    """Abstract class for attention metadata builders."""

    @abstractmethod
    def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
        raise NotImplementedError

    @abstractmethod
    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int) -> T:
        """Build attention metadata with on-device tensors."""
        raise NotImplementedError


class AttentionImpl(ABC, Generic[T]):

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
        kv_cache_dtype: str = "auto",
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        raise NotImplementedError
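
# Illustrative sketch only (hypothetical names, not part of this module): a
# concrete backend implements the hooks above by returning its own subclasses
# of AttentionImpl, AttentionMetadata, and AttentionMetadataBuilder, roughly:
#
#     class MyBackend(AttentionBackend):
#
#         @staticmethod
#         def get_name() -> str:
#             return "my-backend"
#
#         @staticmethod
#         def get_impl_cls() -> Type["AttentionImpl"]:
#             return MyAttentionImpl
#
#         @staticmethod
#         def get_metadata_cls() -> Type["AttentionMetadata"]:
#             return MyAttentionMetadata
#
#         @staticmethod
#         def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
#             return MyAttentionMetadataBuilder
#
#         @staticmethod
#         def get_kv_cache_shape(num_blocks, block_size, num_kv_heads,
#                                head_size):
#             # One plausible layout: key and value caches stacked on dim 0.
#             return (2, num_blocks, block_size, num_kv_heads, head_size)
#
# swap_blocks and copy_blocks would then move or copy KV-cache blocks between
# the given tensors according to the provided index mappings.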