# abstract.py

from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar

import torch


class AttentionBackend(ABC):
    """Abstract class for attention backends."""

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def make_metadata(*args, **kwargs) -> "AttentionMetadata":
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        raise NotImplementedError
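

# Illustrative sketch, not part of the original file: a minimal concrete
# backend showing how the static methods above fit together. The cache
# layout chosen here, with key and value stacked on dim 0 of a single
# tensor of shape (2, num_blocks, block_size, num_kv_heads, head_size),
# is an assumption for this sketch, not a layout real backends must use.
class _NaiveAttentionBackend(AttentionBackend):

    @staticmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        # A real backend returns its concrete AttentionImpl subclass.
        return _NaiveAttentionImpl  # sketched at the bottom of this file

    @staticmethod
    def make_metadata(*args, **kwargs) -> "AttentionMetadata":
        return AttentionMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        # Copy whole blocks (both K and V planes) across caches, e.g.
        # when swapping blocks between CPU and GPU tensors.
        for src, dst in src_to_dst.items():
            dst_kv_cache[:, dst].copy_(src_kv_cache[:, src])

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        # Duplicate a source block into one or more destination blocks
        # within each layer's cache (e.g. for copy-on-write forking).
        for kv_cache in kv_caches:
            for src, dsts in src_to_dists.items():
                for dst in dsts:
                    kv_cache[:, dst].copy_(kv_cache[:, src])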


@dataclass
class AttentionMetadataPerStage:
    """Attention metadata for a specific stage, i.e., prefill or decode."""

    def asdict_zerocopy(self) -> Dict[str, Any]:
        """Similar to dataclasses.asdict, but avoids deepcopying."""
        # Note that if we add dataclasses as fields, they will need
        # similar handling.
        return {
            field.name: getattr(self, field.name)
            for field in fields(self)
        }
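

# Illustrative sketch, not part of the original file: why asdict_zerocopy
# matters. Unlike dataclasses.asdict, the returned dict aliases the
# original tensors instead of deep-copying them, so rebuilding the
# metadata elsewhere via make_metadata(**d) stays cheap.
# `_ExampleStageMetadata` is a hypothetical subclass for the demo.
@dataclass
class _ExampleStageMetadata(AttentionMetadataPerStage):
    seq_lens: torch.Tensor


def _asdict_zerocopy_demo() -> None:
    meta = _ExampleStageMetadata(seq_lens=torch.tensor([3, 5]))
    d = meta.asdict_zerocopy()
    assert d["seq_lens"] is meta.seq_lens  # same tensor object, no copy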


T = TypeVar("T", bound=AttentionMetadataPerStage)


@dataclass
class AttentionMetadata(Generic[T]):
    """Attention metadata for prefill and decode batched together."""
    # Total number of prefill requests.
    num_prefills: int
    # Number of prefill tokens.
    num_prefill_tokens: int
    # Number of decode tokens. Note that this equals the number of decode
    # requests, since each decode request contributes exactly one token.
    num_decode_tokens: int
    # The attention metadata for prefill requests in a batch.
    # None if there are no prefill requests in the batch.
    prefill_metadata: Optional[T]
    # The attention metadata for decode requests in a batch.
    # None if there are no decode requests in the batch.
    decode_metadata: Optional[T]
    # (num_tokens,). The indices of the token slots that the input tokens
    # will be stored into. A slot index decomposes as
    # slot = block_idx * block_size + offset. E.g., if `slot_mapping` is
    # [35, 2, 17] and the block size is 16, the three tokens are stored at
    # offset 3 in block 2, offset 2 in block 0, and offset 1 in block 1,
    # respectively (see the helper sketched after this class).
    slot_mapping: torch.Tensor
    # The kv cache's data type.
    kv_cache_dtype: str

    def __post_init__(self):
        if self.num_prefill_tokens > 0:
            assert self.num_prefills > 0
            assert self.prefill_metadata is not None
        if self.num_decode_tokens > 0:
            assert self.decode_metadata is not None
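

# Illustrative sketch, not part of the original file: decomposing slot
# indices into (block, offset) pairs, matching the `slot_mapping` example
# above. For slot_mapping = [35, 2, 17] and block_size = 16 this returns
# blocks [2, 0, 1] and offsets [3, 2, 1].
def _slots_to_block_offsets(
    slot_mapping: torch.Tensor,
    block_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    return slot_mapping // block_size, slot_mapping % block_size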


class AttentionImpl(ABC):

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        kv_scale: float,
    ) -> torch.Tensor:
        raise NotImplementedError
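

# Illustrative sketch, not part of the original file: a minimal
# AttentionImpl that ignores the KV cache and metadata and runs plain
# scaled dot-product attention over the whole batch as one sequence.
# It assumes num_kv_heads == num_heads; a real implementation reads and
# writes `kv_cache` according to `attn_metadata.slot_mapping` and handles
# prefill and decode tokens separately.
class _NaiveAttentionImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        assert num_kv_heads is None or num_kv_heads == num_heads
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        kv_scale: float,
    ) -> torch.Tensor:
        # query/key/value: (num_tokens, num_heads * head_size).
        q = query.view(-1, self.num_heads, self.head_size).transpose(0, 1)
        k = key.view(-1, self.num_heads, self.head_size).transpose(0, 1)
        v = value.view(-1, self.num_heads, self.head_size).transpose(0, 1)
        # (num_heads, num_tokens, num_tokens) attention scores.
        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        out = torch.matmul(torch.softmax(scores, dim=-1), v)
        return out.transpose(0, 1).reshape(query.shape)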