abstract.py

from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from enum import Enum, auto
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
                    Tuple, Type, TypeVar)

import torch

if TYPE_CHECKING:
    from aphrodite.task_handler.model_runner_base import (
        ModelRunnerInputBuilderBase)


class AttentionType(Enum):
    DECODER = auto()  # Decoder attention between previous layer Q/K/V
    ENCODER = auto()  # Encoder attention between previous layer Q/K/V
    ENCODER_DECODER = auto()  # Attention between dec. Q and enc. K/V


class AttentionBackend(ABC):
    """Abstract class for attention backends."""

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        raise NotImplementedError

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
        return cls.get_metadata_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
        raise NotImplementedError

    @classmethod
    def make_metadata_builder(cls, *args,
                              **kwargs) -> "AttentionMetadataBuilder":
        return cls.get_builder_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        raise NotImplementedError

    def advance_step(self, num_seqs: int, num_queries: int):
        raise NotImplementedError
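

# Illustration only, not part of the abstract interface: a sketch of the kind
# of KV-cache layout a paged backend might report from `get_kv_cache_shape`.
# The helper name, the leading 2 (key blocks and value blocks held in one
# tensor), and the dimension ordering are assumptions of this example, not
# something the interface mandates.
def _example_kv_cache_shape(num_blocks: int, block_size: int,
                            num_kv_heads: int,
                            head_size: int) -> Tuple[int, ...]:
    # (key/value, block index, token slot within block, KV head, head dim)
    return (2, num_blocks, block_size, num_kv_heads, head_size)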


@dataclass
class AttentionMetadata:
    """Attention metadata for prefill and decode batched together."""
    # Total number of prefill requests.
    num_prefills: int
    # Number of prefill tokens.
    num_prefill_tokens: int
    # Number of decode tokens. Note that it is equivalent to the number of
    # decode requests.
    num_decode_tokens: int
    # (num_tokens,). The indices of the token slots that input tokens will be
    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
    # in block 0, and 1st slot in block 1, respectively.
    slot_mapping: torch.Tensor

    @property
    @abstractmethod
    def prefill_metadata(self) -> Optional["AttentionMetadata"]:
        """Return the attention metadata that's required to run prefill
        attention."""
        pass

    @property
    @abstractmethod
    def decode_metadata(self) -> Optional["AttentionMetadata"]:
        """Return the attention metadata that's required to run decode
        attention."""
        pass

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        """Similar to dataclasses.asdict, but avoids deepcopying."""
        if skip_fields is None:
            skip_fields = set()
        # Note that if we add dataclasses as fields, they will need
        # similar handling.
        return {
            field.name: getattr(self, field.name)
            for field in fields(self) if field.name not in skip_fields
        }
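

# Illustration only: a worked example of the `slot_mapping` indexing described
# above and of `asdict_zerocopy`. `_ExampleMetadata` and
# `_example_metadata_usage` are hypothetical names invented for this sketch;
# real backends define their own metadata dataclasses with additional fields.
@dataclass
class _ExampleMetadata(AttentionMetadata):

    @property
    def prefill_metadata(self) -> Optional["AttentionMetadata"]:
        return None

    @property
    def decode_metadata(self) -> Optional["AttentionMetadata"]:
        return None


def _example_metadata_usage() -> None:
    meta = _ExampleMetadata(
        num_prefills=1,
        num_prefill_tokens=3,
        num_decode_tokens=0,
        # With a block size of 16: slot 35 -> block 2, offset 3; slot 2 ->
        # block 0, offset 2; slot 17 -> block 1, offset 1
        # (block = slot // 16, offset = slot % 16).
        slot_mapping=torch.tensor([35, 2, 17]),
    )
    # Returns the dataclass fields as a dict without deep-copying the tensors.
    d = meta.asdict_zerocopy(skip_fields={"slot_mapping"})
    assert d == {
        "num_prefills": 1,
        "num_prefill_tokens": 3,
        "num_decode_tokens": 0,
    }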


T = TypeVar("T", bound=AttentionMetadata)


class AttentionMetadataBuilder(ABC, Generic[T]):
    """Abstract class for attention metadata builders."""

    @abstractmethod
    def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
        raise NotImplementedError

    @abstractmethod
    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int) -> T:
        """Build attention metadata with on-device tensors."""
        raise NotImplementedError
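

# Illustration only: a hypothetical builder for the `_ExampleMetadata` sketch
# above. Real builders accumulate per-sequence state from the model runner's
# input builder and construct on-device tensors; the bookkeeping here is
# deliberately simplified and is not the actual Aphrodite logic.
class _ExampleMetadataBuilder(AttentionMetadataBuilder[_ExampleMetadata]):

    def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
        self.input_builder = input_builder
        self.slot_mapping: List[int] = []

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int) -> _ExampleMetadata:
        # Treat every sequence as a prefill; a real builder also handles
        # decode tokens and CUDA-graph padding.
        return _ExampleMetadata(
            num_prefills=len(seq_lens),
            num_prefill_tokens=sum(query_lens),
            num_decode_tokens=0,
            slot_mapping=torch.tensor(self.slot_mapping, dtype=torch.long),
        )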


class AttentionImpl(ABC, Generic[T]):

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
        kv_cache_dtype: str = "auto",
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        raise NotImplementedError
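

# Illustration only: a sketch of how a model runner might drive these
# interfaces end to end. `backend` can be any concrete `AttentionBackend`
# subclass; the function name and the concrete arguments (8 heads, head size
# 128) are invented for this example, and the real runner code is
# considerably more involved.
def _example_backend_usage(backend: Type[AttentionBackend],
                           query: torch.Tensor,
                           key: torch.Tensor,
                           value: torch.Tensor,
                           kv_cache: torch.Tensor,
                           attn_metadata: AttentionMetadata) -> torch.Tensor:
    # 1. Look up the backend-specific implementation class and instantiate it.
    impl_cls = backend.get_impl_cls()
    impl = impl_cls(num_heads=8, head_size=128, scale=128**-0.5)
    # 2. Run attention. The implementation reads and updates `kv_cache` using
    #    the layout reported by `backend.get_kv_cache_shape(...)` and the
    #    token positions in `attn_metadata.slot_mapping`.
    return impl.forward(query, key, value, kv_cache, attn_metadata)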