interfaces.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
from abc import ABC, abstractmethod, abstractproperty
from typing import Dict, FrozenSet, List, Optional, Protocol

from aphrodite.common.utils import Device
  4. class Block(ABC):
  5. @abstractmethod
  6. def append_token_ids(self, token_ids: List[int]) -> None:
  7. pass
  8. @abstractproperty
  9. def block_id(self) -> Optional[int]:
  10. pass
  11. @abstractproperty
  12. def token_ids(self) -> List[int]:
  13. pass
  14. @abstractproperty
  15. def num_empty_slots(self) -> int:
  16. pass
  17. @abstractproperty
  18. def is_full(self) -> bool:
  19. pass
  20. @abstractproperty
  21. def prev_block(self) -> Optional["Block"]:
  22. pass
  23. class Factory(Protocol):
  24. @abstractmethod
  25. def __call__(
  26. self,
  27. prev_block: Optional["Block"],
  28. token_ids: List[int],
  29. block_size: int,
  30. allocator: "BlockAllocator",
  31. block_id: Optional[int] = None,
  32. ) -> "Block":
  33. pass
  34. class BlockAllocator(ABC):
  35. @abstractmethod
  36. def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
  37. pass
  38. @abstractmethod
  39. def allocate_immutable(self, prev_block: Optional[Block],
  40. token_ids: List[int]) -> Block:
  41. pass
  42. @abstractmethod
  43. def free(self, block: Block) -> None:
  44. pass
  45. @abstractmethod
  46. def fork(self, last_block: Block) -> List[Block]:
  47. pass
  48. @abstractmethod
  49. def get_num_free_blocks(self) -> int:
  50. pass
  51. @abstractproperty
  52. def all_block_ids(self) -> frozenset[int]:
  53. pass
  54. @abstractmethod
  55. def clear_copy_on_writes(self) -> Dict[int, List[int]]:
  56. pass
  57. @abstractmethod
  58. def mark_blocks_as_computed(self) -> None:
  59. pass
  60. @abstractmethod
  61. def get_common_computed_block_ids(
  62. self, seq_block_ids: List[List[int]]) -> List[int]:
  63. pass
  64. class NoFreeBlocksError(ValueError):
  65. pass
  66. class DeviceAwareBlockAllocator(BlockAllocator):
  67. @abstractmethod
  68. def allocate_mutable(self, prev_block: Optional[Block],
  69. device: Device) -> Block:
  70. pass
  71. @abstractmethod
  72. def allocate_immutable(self, prev_block: Optional[Block],
  73. token_ids: List[int], device: Device) -> Block:
  74. pass
  75. @abstractmethod
  76. def get_num_free_blocks(self, device: Device) -> int:
  77. pass