1
0

abstract.py 6.7 KB


  1. from abc import ABC, abstractmethod
  2. from contextlib import contextmanager
  3. from dataclasses import dataclass, fields
  4. from enum import Enum, auto
  5. from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
  6. Tuple, Type, TypeVar)
  7. import torch
  8. if TYPE_CHECKING:
  9. from aphrodite.worker.model_runner_base import (
  10. ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase)
  class AttentionType(Enum):
      """Selects which tensors an attention layer attends over.

      Passed to ``AttentionImpl.forward`` through its ``attn_type`` argument.
      """

      # Decoder self-attention: Q/K/V all come from the previous decoder layer.
      DECODER = auto()
      # Encoder self-attention: Q/K/V all come from the previous encoder layer.
      ENCODER = auto()
      # Cross-attention: Q comes from the decoder, K/V from the encoder.
      ENCODER_DECODER = auto()
  class AttentionBackend(ABC):
      """Abstract class for attention backends.

      A backend ties together the classes that make up one attention
      implementation (impl, metadata, metadata builder, state) and exposes
      the KV-cache layout and block swap/copy primitives for it.
      """

      @staticmethod
      @abstractmethod
      def get_name() -> str:
          """Return the backend's name."""
          raise NotImplementedError

      @staticmethod
      @abstractmethod
      def get_impl_cls() -> Type["AttentionImpl"]:
          """Return the ``AttentionImpl`` subclass used by this backend."""
          raise NotImplementedError

      @staticmethod
      @abstractmethod
      def get_metadata_cls() -> Type["AttentionMetadata"]:
          """Return the ``AttentionMetadata`` subclass used by this backend."""
          raise NotImplementedError

      @staticmethod
      def get_state_cls() -> Type["AttentionState"]:
          """Return the ``AttentionState`` subclass used by this backend.

          Deliberately not marked abstract, so subclasses may omit it; the
          default raises ``NotImplementedError`` when called.
          """
          raise NotImplementedError

      @classmethod
      def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
          """Instantiate ``get_metadata_cls()`` with the given arguments."""
          metadata_cls = cls.get_metadata_cls()
          return metadata_cls(*args, **kwargs)

      @staticmethod
      @abstractmethod
      def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
          """Return the ``AttentionMetadataBuilder`` subclass for this backend."""
          raise NotImplementedError

      @classmethod
      def make_metadata_builder(cls, *args,
                                **kwargs) -> "AttentionMetadataBuilder":
          """Instantiate ``get_builder_cls()`` with the given arguments."""
          builder_cls = cls.get_builder_cls()
          return builder_cls(*args, **kwargs)

      @staticmethod
      @abstractmethod
      def get_kv_cache_shape(
          num_blocks: int,
          block_size: int,
          num_kv_heads: int,
          head_size: int,
      ) -> Tuple[int, ...]:
          """Return the shape of this backend's KV-cache tensor.

          The layout (axis order, whether K and V share one tensor) is
          backend-specific.
          """
          raise NotImplementedError

      @staticmethod
      @abstractmethod
      def swap_blocks(
          src_kv_cache: torch.Tensor,
          dst_kv_cache: torch.Tensor,
          src_to_dst: torch.Tensor,
      ) -> None:
          """Move cache blocks from ``src_kv_cache`` into ``dst_kv_cache``.

          ``src_to_dst`` describes which source block goes to which
          destination block (presumably pairs of block indices -- confirm
          against the concrete backends).
          """
          raise NotImplementedError

      @staticmethod
      @abstractmethod
      def copy_blocks(
          kv_caches: List[torch.Tensor],
          src_to_dists: torch.Tensor,
      ) -> None:
          """Copy cache blocks within each tensor of ``kv_caches``.

          ``src_to_dists`` maps source blocks to destination blocks
          (presumably pairs of block indices -- confirm against backends).
          """
          raise NotImplementedError

      def advance_step(self, model_input: "ModelRunnerInputBase",
                       sampled_token_ids: Optional[torch.Tensor],
                       block_size: int, num_seqs: int, num_queries: int) -> None:
          """Advance ``model_input`` by one step for this backend.

          NOTE(review): purpose inferred from the name (in-place update for
          the next decoding step) -- confirm. The base implementation raises
          ``NotImplementedError``.
          """
          raise NotImplementedError
  @dataclass
  class AttentionMetadata:
      """Attention metadata for prefill and decode batched together.

      Concrete backends subclass this and implement the two properties
      below. NOTE: this class does not derive from ``ABC``, so the
      ``@abstractmethod`` markers are not enforced at instantiation time.
      """

      # Total number of prefill requests in the batch.
      num_prefills: int
      # Total number of prefill tokens.
      num_prefill_tokens: int
      # Number of decode tokens. Each decode request contributes exactly one
      # token, so this also equals the number of decode requests.
      num_decode_tokens: int
      # (num_tokens,). Flat KV-cache slot index for each input token: token i
      # is stored in block ``slot_mapping[i] // block_size`` at offset
      # ``slot_mapping[i] % block_size``. E.g. with block size 16,
      # [35, 2, 17] maps to (block 2, offset 3), (block 0, offset 2) and
      # (block 1, offset 1) -- offsets are 0-based.
      slot_mapping: torch.Tensor

      @property
      @abstractmethod
      def prefill_metadata(self) -> Optional["AttentionMetadata"]:
          """Metadata needed to run attention for the prefill requests."""
          return None

      @property
      @abstractmethod
      def decode_metadata(self) -> Optional["AttentionMetadata"]:
          """Metadata needed to run attention for the decode requests."""
          return None

      def asdict_zerocopy(self,
                          skip_fields: Optional[Set[str]] = None
                          ) -> Dict[str, Any]:
          """Shallow counterpart of ``dataclasses.asdict``.

          Field values (tensors included) are returned by reference rather
          than deep-copied. Fields named in ``skip_fields`` are omitted.
          Nested dataclass fields are not converted to dicts; if such fields
          are ever added they will need similar handling.
          """
          excluded = set() if skip_fields is None else skip_fields
          out: Dict[str, Any] = {}
          for fld in fields(self):
              if fld.name in excluded:
                  continue
              out[fld.name] = getattr(self, fld.name)
          return out
  # Type variable for the concrete attention-metadata type used by the
  # generic classes below; bound to AttentionMetadata, so any subclass of it
  # is accepted.
  T = TypeVar("T", bound=AttentionMetadata)
  class AttentionState(ABC, Generic[T]):
      """Backend-specific objects reused for the lifetime of a model runner.

      Most hooks support CUDA graph capture and replay; ``begin_forward``
      prepares the state before a forward pass. ``T`` is the backend's
      attention-metadata type.
      """

      @abstractmethod
      def __init__(self, runner: "ModelRunnerBase"):
          """Create the state for ``runner``."""

      @abstractmethod
      @contextmanager
      def graph_capture(self, max_batch_size: int):
          """Context manager entered while a CUDA graph is being captured.

          ``max_batch_size`` is presumably the largest batch size that will
          be captured -- confirm against callers.
          """
          yield

      @abstractmethod
      def graph_clone(self, batch_size: int) -> "AttentionState[T]":
          """Return a copy of this state to store with CUDA graph metadata."""

      @abstractmethod
      def graph_capture_get_metadata_for_batch(self, batch_size: int) -> T:
          """Return attention metadata for capturing a graph of ``batch_size``."""

      @abstractmethod
      def get_graph_input_buffers(self, attn_metadata: T) -> Dict[str, Any]:
          """Return the attention-specific input buffers for graph capture."""

      @abstractmethod
      def prepare_graph_input_buffers(self, input_buffers: Dict[str, Any],
                                      attn_metadata: T) -> None:
          """Update ``input_buffers`` in place for CUDA graph replay."""

      @abstractmethod
      def begin_forward(self, model_input: "ModelRunnerInputBase") -> None:
          """Prepare the state for a forward pass over ``model_input``."""
  class AttentionMetadataBuilder(ABC, Generic[T]):
      """Abstract class for attention metadata builders.

      A builder is created from a model-runner input builder and produces
      a backend-specific metadata object of type ``T``.
      """

      @abstractmethod
      def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
          """Bind the builder to ``input_builder``."""
          raise NotImplementedError

      @abstractmethod
      def build(self, seq_lens: List[int], query_lens: List[int],
                cuda_graph_pad_size: int, batch_size: int) -> T:
          """Build attention metadata with on-device tensors.

          Args:
              seq_lens: Per-sequence lengths (presumably full sequence
                  lengths including context -- confirm).
              query_lens: Per-sequence query lengths for this step.
              cuda_graph_pad_size: Padding amount used when the batch will be
                  run under a CUDA graph (semantics backend-specific).
              batch_size: Size of the batch being built.

          Returns:
              The backend's metadata object.
          """
          raise NotImplementedError
  class AttentionImpl(ABC, Generic[T]):
      """Abstract attention implementation.

      Subclasses take the head configuration at construction time and
      compute attention in ``forward`` using metadata of type ``T``.
      """

      @abstractmethod
      def __init__(
          self,
          num_heads: int,
          head_size: int,
          scale: float,
          num_kv_heads: Optional[int] = None,
          alibi_slopes: Optional[List[float]] = None,
          sliding_window: Optional[int] = None,
          kv_cache_dtype: str = "auto",
          blocksparse_params: Optional[Dict[str, Any]] = None,
          logits_soft_cap: Optional[float] = None,
      ) -> None:
          """Configure the implementation.

          Args:
              num_heads: Number of query heads.
              head_size: Size of each attention head.
              scale: Scale applied to attention scores.
              num_kv_heads: Number of key/value heads; ``None`` presumably
                  means "same as ``num_heads``" -- confirm in subclasses.
              alibi_slopes: Optional per-head ALiBi slopes.
              sliding_window: Optional sliding-window size.
              kv_cache_dtype: KV-cache dtype name; defaults to ``"auto"``.
              blocksparse_params: Optional block-sparse configuration.
              logits_soft_cap: Optional soft cap applied to attention logits.
          """
          raise NotImplementedError

      @abstractmethod
      def forward(
          self,
          query: torch.Tensor,
          key: torch.Tensor,
          value: torch.Tensor,
          kv_cache: torch.Tensor,
          attn_metadata: T,
          k_scale: float = 1.0,
          v_scale: float = 1.0,
          attn_type: AttentionType = AttentionType.DECODER,
      ) -> torch.Tensor:
          """Compute attention for ``query`` against ``key``/``value``.

          Args:
              query, key, value: Input projections (shapes backend-specific).
              kv_cache: KV-cache tensor, laid out as described by the
                  backend's ``get_kv_cache_shape``.
              attn_metadata: Backend metadata for the current batch.
              k_scale, v_scale: Scales for K and V (presumably used with
                  quantized KV caches -- confirm); default 1.0.
              attn_type: Which attention pattern to run; defaults to
                  ``AttentionType.DECODER``.

          Returns:
              The attention output tensor.
          """
          raise NotImplementedError