worker_base.py

from abc import ABC, abstractmethod
from typing import Dict, List, Tuple

from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
from aphrodite.lora.request import LoRARequest


class WorkerBase(ABC):
    """Worker interface that allows Aphrodite to cleanly separate
    implementations for different hardware.
    """

    @abstractmethod
    def init_device(self) -> None:
        """Initialize device state, such as loading the model or other
        on-device memory allocations.
        """
        raise NotImplementedError

    @abstractmethod
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available blocks for the GPU KV cache and
        the swappable CPU KV cache.

        The implementation may run profiling or other heuristics to determine
        the size of the caches.

        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
        are blocks that are "active" on the device and can be appended to,
        while num_cpu_blocks refers to "swapped" blocks in CPU memory that
        cannot be appended to.
        """
        raise NotImplementedError
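
    # Illustrative arithmetic (assumed numbers, not from this file): if
    # profiling leaves 8 GiB of device memory for the KV cache and a single
    # block occupies 2 MiB (see get_cache_block_size_bytes), then
    # num_gpu_blocks would come out to 8 GiB // 2 MiB = 4096.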

    @abstractmethod
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache with the given size in blocks."""
        raise NotImplementedError

    @abstractmethod
    def execute_model(
            self, seq_group_metadata_list: List[SequenceGroupMetadata],
            blocks_to_swap_in: Dict[int, int],
            blocks_to_swap_out: Dict[int, int],
            blocks_to_copy: Dict[int, List[int]]) -> List[SamplerOutput]:
        """Executes at least one model step on the given sequences, unless no
        sequences are provided.
        """
        raise NotImplementedError
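
    # Note on the block mappings (hedged; inferred from how the cache engine
    # typically consumes them, not stated in this file): blocks_to_swap_in
    # usually maps CPU block numbers to GPU block numbers, blocks_to_swap_out
    # maps GPU block numbers back to CPU, and blocks_to_copy maps a source
    # GPU block to the destination blocks that receive a copy (e.g. for
    # copy-on-write during beam search).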

    @abstractmethod
    def get_cache_block_size_bytes(self) -> int:
        """Return the size of a single cache block, in bytes. Used in
        speculative decoding.
        """
        raise NotImplementedError

    @abstractmethod
    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise NotImplementedError

    @abstractmethod
    def remove_lora(self, lora_id: int) -> bool:
        raise NotImplementedError

    @abstractmethod
    def list_loras(self) -> List[int]:
        raise NotImplementedError
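

# A sketch of the expected call order for a concrete worker, following the
# docstrings above (illustrative; `worker` and `scheduled_seqs` are assumed
# to come from the engine and the scheduler and are not defined here):
#
#     worker.init_device()
#     num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()
#     worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
#     outputs = worker.execute_model(scheduled_seqs,
#                                    blocks_to_swap_in={},
#                                    blocks_to_swap_out={},
#                                    blocks_to_copy={})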


class LoraNotSupportedWorkerBase(WorkerBase):
    """Partial implementation of WorkerBase that raises exceptions when LoRA
    methods are invoked.
    """

    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise ValueError(f"{type(self)} does not support LoRA")

    def remove_lora(self, lora_id: int) -> bool:
        raise ValueError(f"{type(self)} does not support LoRA")

    def list_loras(self) -> List[int]:
        raise ValueError(f"{type(self)} does not support LoRA")
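

# --- Illustrative example (hypothetical; not part of the interface) ---
# A minimal sketch of a concrete worker built on LoraNotSupportedWorkerBase.
# Every value below is a placeholder: a real worker would load a model in
# init_device, profile memory in determine_num_available_blocks, and run a
# forward pass in execute_model.
class _ExampleNullWorker(LoraNotSupportedWorkerBase):
    """Toy worker that satisfies the interface without touching hardware."""

    def init_device(self) -> None:
        # A real implementation would select the device and load weights.
        pass

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        # Placeholder counts; a real worker derives these from profiling.
        return 1024, 256

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks

    def execute_model(
            self, seq_group_metadata_list: List[SequenceGroupMetadata],
            blocks_to_swap_in: Dict[int, int],
            blocks_to_swap_out: Dict[int, int],
            blocks_to_copy: Dict[int, List[int]]) -> List[SamplerOutput]:
        # No-op: a real worker would apply the swap/copy operations, run the
        # model, and return one SamplerOutput per executed step.
        return []

    def get_cache_block_size_bytes(self) -> int:
        # Placeholder: real workers compute this from the block size, layer
        # count, head count, head size, and KV dtype width.
        return 2 * 1024 * 1024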