123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081 |
- from abc import ABC, abstractmethod
- from typing import Dict, List
- from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
- from aphrodite.lora.request import LoRARequest
class WorkerBase(ABC):
    """Worker interface that allows Aphrodite to cleanly separate
    implementations for different hardware.

    Every method below is abstract, including the LoRA hooks
    (``add_lora`` / ``remove_lora`` / ``list_loras``). A concrete worker
    must therefore provide all of them, even if it only rejects LoRA calls.
    """

    @abstractmethod
    def init_device(self) -> None:
        """Initialize device state, such as loading the model or other on-device
        memory allocations.
        """
        raise NotImplementedError

    @abstractmethod
    def determine_num_available_blocks(self) -> tuple[int, int]:
        """Determine the number of available blocks for the GPU KV cache and
        swappable CPU KV cache.

        The implementation may run profiling or other heuristics to determine
        the size of caches.

        Returns:
            A tuple ``(num_gpu_blocks, num_cpu_blocks)``.

            ``num_gpu_blocks`` are blocks that are "active" on the device and
            can be appended to.

            ``num_cpu_blocks`` refers to "swapped" blocks in CPU memory that
            cannot be appended to.
        """
        raise NotImplementedError

    @abstractmethod
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache with the given size in blocks.

        Args:
            num_gpu_blocks: Number of device-resident (appendable) blocks.
            num_cpu_blocks: Number of swapped-out CPU blocks.

        NOTE(review): presumably called with the pair returned by
        ``determine_num_available_blocks``; confirm against the caller.
        """
        raise NotImplementedError

    @abstractmethod
    def execute_model(self,
                      seq_group_metadata_list: List[SequenceGroupMetadata],
                      blocks_to_swap_in: Dict[int, int],
                      blocks_to_swap_out: Dict[int, int],
                      blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
        """Executes one model step on the given sequences.

        Args:
            seq_group_metadata_list: Metadata for each sequence group
                scheduled in this step.
            blocks_to_swap_in: Block-number mapping for blocks to bring back
                from CPU to device. Presumably source -> destination; verify
                against the scheduler.
            blocks_to_swap_out: Block-number mapping for blocks to move from
                device to CPU. Same assumed orientation as above.
            blocks_to_copy: Maps one block number to a list of block numbers.
                Presumably a source block and its copy destinations; confirm.

        Returns:
            The ``SamplerOutput`` for this step.
        """
        raise NotImplementedError

    @abstractmethod
    def get_cache_block_size_bytes(self) -> int:
        """Return the size of a single cache block, in bytes. Used in
        speculative decoding.
        """
        # The previous declaration omitted ``self``. That contradicted every
        # instance-method implementation and the ``worker.method()`` call form.
        raise NotImplementedError

    @abstractmethod
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load the LoRA adapter described by ``lora_request`` into the worker.

        Returns:
            A bool status. Presumably ``True`` on success; confirm with the
            concrete implementations.
        """
        raise NotImplementedError

    @abstractmethod
    def remove_lora(self, lora_id: int) -> bool:
        """Remove the LoRA adapter identified by ``lora_id``.

        Returns:
            A bool status. Presumably ``True`` if an adapter was removed;
            confirm with the concrete implementations.
        """
        raise NotImplementedError

    @abstractmethod
    def list_loras(self) -> List[int]:
        """Return the integer ids of LoRA adapters known to this worker.

        NOTE(review): assumed to be the same ids accepted by
        ``remove_lora``; verify.
        """
        raise NotImplementedError
class LoraNotSupportedWorkerBase(WorkerBase):
    """Partial ``WorkerBase`` for workers that cannot serve LoRA adapters.

    Each LoRA hook raises ``ValueError`` naming the concrete worker type.
    All other abstract ``WorkerBase`` methods are still left to subclasses.
    """

    def _lora_unsupported(self) -> ValueError:
        """Build the exception raised by every LoRA hook of this worker."""
        return ValueError(f"{type(self)} does not support LoRA")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Reject adapter loading: LoRA is unsupported on this worker."""
        raise self._lora_unsupported()

    def remove_lora(self, lora_id: int) -> bool:
        """Reject adapter removal: LoRA is unsupported on this worker."""
        raise self._lora_unsupported()

    def list_loras(self) -> List[int]:
        """Reject adapter listing: LoRA is unsupported on this worker."""
        raise self._lora_unsupported()
|