from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from enum import Enum, auto
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
                    Tuple, Type, TypeVar)

import torch

if TYPE_CHECKING:
    from aphrodite.common.sequence import SequenceGroupMetadata
    from aphrodite.task_handler.model_runner_base import \
        ModelRunnerInputBuilderBase


class AttentionType(Enum):
    DECODER = auto()  # Decoder attention between previous layer Q/K/V
    ENCODER = auto()  # Encoder attention between previous layer Q/K/V
    ENCODER_DECODER = auto()  # Attention between dec. Q and enc. K/V


class AttentionBackend(ABC):
    """Abstract class for attention backends."""

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        raise NotImplementedError

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
        return cls.get_metadata_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
        raise NotImplementedError

    @classmethod
    def make_metadata_builder(cls, *args,
                              **kwargs) -> "AttentionMetadataBuilder":
        return cls.get_builder_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        raise NotImplementedError
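

# A minimal, hypothetical sketch of how a paged backend might satisfy
# ``get_kv_cache_shape``: keys at index 0 and values at index 1, stored one
# fixed-size block (page) at a time. The exact layout is an assumption for
# illustration; each real backend picks its own shape, and the rest of the
# engine treats the returned shape as opaque.
def _example_paged_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> Tuple[int, ...]:
    # One slot per token: block_size tokens per block, each holding one
    # (num_kv_heads, head_size) vector for K and one for V.
    return (2, num_blocks, block_size, num_kv_heads, head_size)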


@dataclass
class AttentionMetadata:
    """Attention metadata for prefill and decode batched together."""
    # Total number of prefill requests.
    num_prefills: int
    # Number of prefill tokens.
    num_prefill_tokens: int
    # Number of decode tokens. Note that it is equivalent to the number of
    # decode requests, since each decode request processes exactly one token
    # per step.
    num_decode_tokens: int
    # (num_tokens,). The indices of the token slots that input tokens will be
    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
    # is 16, the three tokens are stored in slot 3 of block 2 (35 = 2 * 16
    # + 3), slot 2 of block 0, and slot 1 of block 1, respectively.
    slot_mapping: torch.Tensor

    @property
    @abstractmethod
    def prefill_metadata(self) -> Optional["AttentionMetadata"]:
        """Return the attention metadata that's required to run prefill
        attention."""
        pass

    @property
    @abstractmethod
    def decode_metadata(self) -> Optional["AttentionMetadata"]:
        """Return the attention metadata that's required to run decode
        attention."""
        pass

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        """Similar to dataclasses.asdict, but avoids deepcopying."""
        if skip_fields is None:
            skip_fields = set()
        # Note that if we add dataclasses as fields, they will need
        # similar handling.
        return {
            field.name: getattr(self, field.name)
            for field in fields(self) if field.name not in skip_fields
        }
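

# A hypothetical helper, for illustration only: how a ``slot_mapping`` entry
# decomposes into a (block, offset) pair under a given block size. Backends
# typically do this arithmetic inline (or inside a kernel) rather than call
# a helper like this.
def _example_decode_slot(slot: int, block_size: int) -> Tuple[int, int]:
    # E.g. slot 35 with block_size 16 -> (block 2, offset 3): 35 = 2 * 16 + 3.
    return slot // block_size, slot % block_size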


T = TypeVar("T", bound=AttentionMetadata)


class AttentionMetadataBuilder(ABC, Generic[T]):
    """Abstract class for attention metadata builders."""

    @abstractmethod
    def __init__(self, input_builder) -> None:
        raise NotImplementedError

    @abstractmethod
    def add_seq_group(self, seq_group_metadata: "SequenceGroupMetadata",
                      token_lens: List[int], seq_lens: List[int],
                      curr_seq_lens: List[int], query_lens: List[int],
                      context_lens: List[int],
                      curr_sliding_window_blocks: List[int],
                      prefix_cache_hit: bool, chunked_prefill_enabled: bool):
        """Add a sequence group to the metadata and update the
        corresponding fields (in Python objects).
        """
        raise NotImplementedError

    @abstractmethod
    def build(self, runner: "ModelRunnerInputBuilderBase", seq_lens: List[int],
              query_lens: List[int], cuda_graph_pad_size: int,
              batch_size: int) -> T:
        """Build attention metadata with on-device tensors."""
        raise NotImplementedError
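

# A hypothetical, minimal builder subclass, sketched only to show the
# intended lifecycle: construct once per batch, call ``add_seq_group()`` for
# each sequence group, then call ``build()`` once. It merely accumulates
# token counts; real builders also gather slot mappings, block tables, and
# on-device length tensors. The name and the prefill/decode classification
# below (query length > 1 means prefill) are illustrative assumptions, not
# part of the interface contract.
class _ToyMetadataBuilder(AttentionMetadataBuilder[AttentionMetadata]):

    def __init__(self, input_builder) -> None:
        self.input_builder = input_builder
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

    def add_seq_group(self, seq_group_metadata: "SequenceGroupMetadata",
                      token_lens: List[int], seq_lens: List[int],
                      curr_seq_lens: List[int], query_lens: List[int],
                      context_lens: List[int],
                      curr_sliding_window_blocks: List[int],
                      prefix_cache_hit: bool, chunked_prefill_enabled: bool):
        # Toy classification: multi-token queries count as prefill,
        # single-token queries as decode.
        for query_len in query_lens:
            if query_len > 1:
                self.num_prefills += 1
                self.num_prefill_tokens += query_len
            else:
                self.num_decode_tokens += 1

    def build(self, runner: "ModelRunnerInputBuilderBase", seq_lens: List[int],
              query_lens: List[int], cuda_graph_pad_size: int,
              batch_size: int) -> AttentionMetadata:
        num_tokens = self.num_prefill_tokens + self.num_decode_tokens
        # AttentionMetadata is a plain dataclass (not an ABC), so it can be
        # constructed directly for this sketch.
        return AttentionMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=self.num_decode_tokens,
            # Dummy slot mapping; a real builder fills in actual cache slots.
            slot_mapping=torch.zeros(num_tokens, dtype=torch.long),
        )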


class AttentionImpl(ABC, Generic[T]):
    """Abstract class for per-layer attention implementations."""

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
        kv_cache_dtype: str = "auto",
        blocksparse_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        raise NotImplementedError
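

# ---------------------------------------------------------------------------
# Hypothetical end-to-end sketch, illustrative only: a toy metadata class and
# backend showing how ``make_metadata()`` routes through
# ``get_metadata_cls()``. The names below are assumptions; real backends also
# implement get_impl_cls, get_builder_cls, get_kv_cache_shape, swap_blocks,
# and copy_blocks.
# ---------------------------------------------------------------------------
@dataclass
class _ToyAttentionMetadata(AttentionMetadata):

    @property
    def prefill_metadata(self) -> Optional["AttentionMetadata"]:
        # In this toy, the batch doubles as its own prefill view.
        return self if self.num_prefills > 0 else None

    @property
    def decode_metadata(self) -> Optional["AttentionMetadata"]:
        return self if self.num_decode_tokens > 0 else None


class _ToyBackend(AttentionBackend):
    # Only the classmethod path exercised below is implemented; the class is
    # never instantiated, so the remaining abstract methods can stay abstract.

    @staticmethod
    def get_name() -> str:
        return "toy"

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return _ToyAttentionMetadata


if __name__ == "__main__":
    meta = _ToyBackend.make_metadata(
        num_prefills=1,
        num_prefill_tokens=3,
        num_decode_tokens=1,
        slot_mapping=torch.tensor([35, 2, 17, 40]),
    )
    assert meta.prefill_metadata is meta and meta.decode_metadata is meta
    print(meta.asdict_zerocopy(skip_fields={"slot_mapping"}))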