1
0

worker_base.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. from abc import ABC, abstractmethod
  2. from typing import Dict, List
  3. from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
  4. from aphrodite.lora.request import LoRARequest
  5. class WorkerBase(ABC):
  6. """Worker interface that allows Aphrodite to cleanly separate
  7. implementations for different hardware.
  8. """
  9. @abstractmethod
  10. def init_device(self) -> None:
  11. """Initialize device state, such as loading the model or other on-device
  12. memory allocations.
  13. """
  14. raise NotImplementedError
  15. @abstractmethod
  16. def determine_num_available_blocks(self) -> tuple[int, int]:
  17. """Determine the number of available blocks for the GPU KV cache and
  18. swappable CPU KV cache.
  19. The implementation may run profiling or other heuristics to determine
  20. the size of caches.
  21. Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
  22. are blocks that are "active" on the device and can be appended to.
  23. num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
  24. appended to.
  25. """
  26. raise NotImplementedError
  27. @abstractmethod
  28. def initialize_cache(self, num_gpu_blocks: int,
  29. num_cpu_blocks: int) -> None:
  30. """Initialize the KV cache with the given size in blocks.
  31. """
  32. raise NotImplementedError
  33. @abstractmethod
  34. def execute_model(self,
  35. seq_group_metadata_list: List[SequenceGroupMetadata],
  36. blocks_to_swap_in: Dict[int, int],
  37. blocks_to_swap_out: Dict[int, int],
  38. blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
  39. """Executes one model step on the given sequences."""
  40. raise NotImplementedError
  41. @abstractmethod
  42. def get_cache_block_size_bytes() -> int:
  43. """Return the size of a single cache block, in bytes. Used in
  44. speculative decoding.
  45. """
  46. raise NotImplementedError
  47. @abstractmethod
  48. def add_lora(self, lora_request: LoRARequest) -> bool:
  49. raise NotImplementedError
  50. @abstractmethod
  51. def remove_lora(self, lora_id: int) -> bool:
  52. raise NotImplementedError
  53. @abstractmethod
  54. def list_loras(self) -> List[int]:
  55. raise NotImplementedError
  56. class LoraNotSupportedWorkerBase(WorkerBase):
  57. """Partial implementation of WorkerBase that raises exceptions when LoRA
  58. methods are invoked.
  59. """
  60. def add_lora(self, lora_request: LoRARequest) -> bool:
  61. raise ValueError(f"{type(self)} does not support LoRA")
  62. def remove_lora(self, lora_id: int) -> bool:
  63. raise ValueError(f"{type(self)} does not support LoRA")
  64. def list_loras(self) -> List[int]:
  65. raise ValueError(f"{type(self)} does not support LoRA")