# abstract.py
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from enum import Enum, auto
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
                    Tuple, Type, TypeVar)

import torch

if TYPE_CHECKING:
    from aphrodite.task_handler.model_runner_base import (
        ModelRunnerInputBuilderBase)


class AttentionType(Enum):
    DECODER = auto()  # Decoder attention between previous layer Q/K/V
    ENCODER = auto()  # Encoder attention between previous layer Q/K/V
    ENCODER_DECODER = auto()  # Attention between dec. Q and enc. K/V


class AttentionBackend(ABC):
    """Abstract class for attention backends."""

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        raise NotImplementedError

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
        return cls.get_metadata_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
        raise NotImplementedError

    @classmethod
    def make_metadata_builder(cls, *args,
                              **kwargs) -> "AttentionMetadataBuilder":
        return cls.get_builder_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        raise NotImplementedError
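

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the Aphrodite API: pure-PyTorch reference
# semantics for the `swap_blocks`/`copy_blocks` contracts above. The helper
# names are hypothetical, and the sketch assumes a cache layout of
# (2, num_blocks, block_size, num_kv_heads, head_size) with key/value planes
# stacked on the leading dimension and (src_block, dst_block) index pairs in
# the mapping tensors; real backends use fused device kernels and their own
# backend-specific layouts.
def _reference_swap_blocks(src_kv_cache: torch.Tensor,
                           dst_kv_cache: torch.Tensor,
                           src_to_dst: torch.Tensor) -> None:
    # Move whole blocks between two caches (e.g. CPU <-> GPU swapping).
    for src_block, dst_block in src_to_dst.tolist():
        dst_kv_cache[:, dst_block].copy_(src_kv_cache[:, src_block])


def _reference_copy_blocks(kv_caches: List[torch.Tensor],
                           src_to_dists: torch.Tensor) -> None:
    # Duplicate blocks within every layer's cache (e.g. copy-on-write when a
    # shared block is about to diverge).
    for src_block, dst_block in src_to_dists.tolist():
        for kv_cache in kv_caches:
            kv_cache[:, dst_block].copy_(kv_cache[:, src_block])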


@dataclass
class AttentionMetadata:
    """Attention metadata for prefill and decode batched together."""
    # Total number of prefill requests.
    num_prefills: int
    # Number of prefill tokens.
    num_prefill_tokens: int
    # Number of decode tokens. Note that it is equivalent to the number of
    # decode requests.
    num_decode_tokens: int
    # (num_tokens,). The indices of the token slots that input tokens will be
    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
    # in block 0, and 1st slot in block 1, respectively.
    slot_mapping: torch.Tensor

    @property
    @abstractmethod
    def prefill_metadata(self) -> Optional["AttentionMetadata"]:
        """Return the attention metadata that's required to run prefill
        attention."""
        pass

    @property
    @abstractmethod
    def decode_metadata(self) -> Optional["AttentionMetadata"]:
        """Return the attention metadata that's required to run decode
        attention."""
        pass

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        """Similar to dataclasses.asdict, but avoids deepcopying."""
        if skip_fields is None:
            skip_fields = set()
        # Note that if we add dataclasses as fields, they will need
        # similar handling.
        return {
            field.name: getattr(self, field.name)
            for field in fields(self) if field.name not in skip_fields
        }
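

# Illustrative usage sketch (hypothetical metadata subclass and values):
# `asdict_zerocopy` builds a shallow dict of the dataclass fields, so tensor
# values alias the objects stored on the instance rather than being
# deep-copied, and skipped fields are simply omitted:
#
#     meta = SomeBackendMetadata(num_prefills=1,
#                                num_prefill_tokens=3,
#                                num_decode_tokens=0,
#                                slot_mapping=torch.tensor([35, 2, 17]))
#     fields_dict = meta.asdict_zerocopy(skip_fields={"slot_mapping"})
#     assert "slot_mapping" not in fields_dict
#     assert fields_dict["num_prefill_tokens"] == 3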


T = TypeVar("T", bound=AttentionMetadata)


class AttentionMetadataBuilder(ABC, Generic[T]):
    """Abstract class for attention metadata builders."""

    @abstractmethod
    def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
        raise NotImplementedError

    @abstractmethod
    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int) -> T:
        """Build attention metadata with on-device tensors."""
        raise NotImplementedError


class AttentionImpl(ABC, Generic[T]):

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
        kv_cache_dtype: str = "auto",
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        raise NotImplementedError
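

# ---------------------------------------------------------------------------
# Illustrative end-to-end sketch, not part of Aphrodite: a minimal concrete
# backend showing how the abstractions above plug together. Every `_Sketch*`
# name is hypothetical, the builder is deliberately naive, and `forward` runs
# plain scaled dot-product attention without touching the paged KV cache.
@dataclass
class _SketchAttentionMetadata(AttentionMetadata):

    @property
    def prefill_metadata(self) -> Optional["AttentionMetadata"]:
        # A real backend slices out the prefill portion of the batch; the
        # sketch just returns itself whenever prefill requests are present.
        return self if self.num_prefills > 0 else None

    @property
    def decode_metadata(self) -> Optional["AttentionMetadata"]:
        return self if self.num_decode_tokens > 0 else None


class _SketchAttentionMetadataBuilder(
        AttentionMetadataBuilder[_SketchAttentionMetadata]):

    def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
        self.input_builder = input_builder

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int,
              batch_size: int) -> _SketchAttentionMetadata:
        # Pretend every sequence is a single-token decode request and assign
        # slots sequentially; a real builder computes the slot mapping from
        # the block tables tracked by the input builder.
        num_tokens = len(seq_lens)
        return _SketchAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=num_tokens,
            slot_mapping=torch.arange(num_tokens, dtype=torch.long))


class _SketchAttentionImpl(AttentionImpl[_SketchAttentionMetadata]):

    def __init__(self, num_heads: int, head_size: int, scale: float,
                 num_kv_heads: Optional[int] = None,
                 alibi_slopes: Optional[List[float]] = None,
                 sliding_window: Optional[int] = None,
                 kv_cache_dtype: str = "auto",
                 blocksparse_params: Optional[Dict[str, Any]] = None,
                 logits_soft_cap: Optional[float] = None) -> None:
        # Only the softmax scale is needed for the naive reference path.
        self.scale = scale

    def forward(self, query: torch.Tensor, key: torch.Tensor,
                value: torch.Tensor, kv_cache: torch.Tensor,
                attn_metadata: _SketchAttentionMetadata,
                k_scale: float = 1.0, v_scale: float = 1.0,
                attn_type: AttentionType = AttentionType.DECODER
                ) -> torch.Tensor:
        # Plain attention over the current batch, ignoring the paged cache
        # and the prefill/decode split.
        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
        return torch.matmul(torch.softmax(scores, dim=-1), value)


class _SketchAttentionBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "sketch"

    @staticmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        return _SketchAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return _SketchAttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
        return _SketchAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(num_blocks: int, block_size: int,
                           num_kv_heads: int,
                           head_size: int) -> Tuple[int, ...]:
        # Assumed paged layout with stacked key/value planes.
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor,
                    src_to_dst: torch.Tensor) -> None:
        _reference_swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(kv_caches: List[torch.Tensor],
                    src_to_dists: torch.Tensor) -> None:
        _reference_copy_blocks(kv_caches, src_to_dists)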