123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146 |
- import importlib
- import os
- from abc import ABC, abstractmethod
- from typing import Dict, List, Set, Tuple
- from loguru import logger
- from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
- from aphrodite.common.utils import (enable_trace_function_call_for_thread,
- update_environment_variables)
- from aphrodite.lora.request import LoRARequest
class WorkerBase(ABC):
    """Abstract worker interface that lets Aphrodite keep
    hardware-specific implementations cleanly separated.
    """

    @abstractmethod
    def init_device(self) -> None:
        """Set up device state, such as loading the model or performing
        other on-device memory allocations.
        """
        raise NotImplementedError

    @abstractmethod
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine how many KV-cache blocks are available on the GPU
        and how many are swappable in CPU memory.

        Implementations may run profiling or other heuristics to size
        the caches.

        Returns a ``Tuple[num_gpu_blocks, num_cpu_blocks]`` where
        ``num_gpu_blocks`` counts blocks that are "active" on the device
        and can be appended to, while ``num_cpu_blocks`` counts "swapped"
        blocks in CPU memory that cannot be appended to.
        """
        raise NotImplementedError

    @abstractmethod
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Allocate the KV cache with the given number of blocks."""
        raise NotImplementedError

    @abstractmethod
    def execute_model(
            self, seq_group_metadata_list: List[SequenceGroupMetadata],
            blocks_to_swap_in: Dict[int, int],
            blocks_to_swap_out: Dict[int, int],
            blocks_to_copy: Dict[int, List[int]]) -> List[SamplerOutput]:
        """Execute at least one model step on the given sequences, unless
        no sequences are provided.
        """
        raise NotImplementedError

    @abstractmethod
    def get_cache_block_size_bytes(self) -> int:
        """Return the size in bytes of a single cache block.

        Used in speculative decoding.
        """
        raise NotImplementedError

    @abstractmethod
    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise NotImplementedError

    @abstractmethod
    def remove_lora(self, lora_id: int) -> bool:
        raise NotImplementedError

    @abstractmethod
    def list_loras(self) -> Set[int]:
        raise NotImplementedError
class LoraNotSupportedWorkerBase(WorkerBase):
    """Partial WorkerBase implementation for backends without LoRA
    support: every LoRA-related method raises ``ValueError``.
    """

    def _lora_error(self) -> ValueError:
        # Build the error in one place so all three methods stay consistent.
        return ValueError(f"{type(self)} does not support LoRA")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise self._lora_error()

    def remove_lora(self, lora_id: int) -> bool:
        raise self._lora_error()

    def list_loras(self) -> Set[int]:
        raise self._lora_error()
class WorkerWrapperBase:
    """Lazily-initializing wrapper around a worker class.

    The wrapper is instantiated first and only records the worker's
    module and class name.  Environment variables may then be adjusted
    via ``update_environment_variables``; the real worker is constructed
    only when ``init_worker`` is called.
    """

    def __init__(self,
                 worker_module_name=None,
                 worker_class_name=None,
                 trust_remote_code: bool = False) -> None:
        self.worker_module_name = worker_module_name
        self.worker_class_name = worker_class_name
        # Created lazily by init_worker().
        self.worker = None
        if trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from aphrodite.common.utils import init_cached_hf_modules
            init_cached_hf_modules()

    @staticmethod
    def update_environment_variables(envs: Dict[str, str]) -> None:
        """Apply ``envs`` to the process environment."""
        key = 'CUDA_VISIBLE_DEVICES'
        if key in envs and key in os.environ:
            # Overwriting CUDA_VISIBLE_DEVICES is desired behavior here, so
            # delete the old value first to suppress the warning emitted by
            # the module-level `update_environment_variables` helper.
            del os.environ[key]
        update_environment_variables(envs)

    def init_worker(self, *args, **kwargs):
        """Actually construct the worker class, enabling function tracing
        for this thread if required.

        All arguments are forwarded to the worker class constructor.
        """
        enable_trace_function_call_for_thread()
        mod = importlib.import_module(self.worker_module_name)
        worker_class = getattr(mod, self.worker_class_name)
        self.worker = worker_class(*args, **kwargs)

    def execute_method(self, method, *args, **kwargs):
        """Invoke ``method`` by name on the worker (or on the wrapper
        itself while the worker is not yet initialized) and return its
        result.
        """
        try:
            target = self if self.worker is None else self.worker
            executor = getattr(target, method)
            return executor(*args, **kwargs)
        except Exception:
            # if the driver worker also execute methods,
            # exceptions in the rest worker may cause deadlock in rpc like ray
            # print the error and inform the user to solve the error
            msg = (f"Error executing method {method}. "
                   "This might cause deadlock in distributed execution.")
            logger.exception(msg)
            # Bare raise preserves the original exception and traceback.
            raise
|