123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171 |
- from typing import Any, Optional, Set, Type
- import torch
- from aphrodite.adapter_commons.utils import (add_adapter_worker,
- apply_adapters_worker,
- list_adapters_worker,
- set_active_adapters_worker)
- from aphrodite.adapter_commons.worker_manager import AbstractWorkerManager
- from aphrodite.common.config import PromptAdapterConfig
- from aphrodite.prompt_adapter.models import (LRUCachePromptAdapterModelManager,
- PromptAdapterModel,
- PromptAdapterModelManager,
- create_prompt_adapter_manager)
- from aphrodite.prompt_adapter.request import PromptAdapterRequest
class WorkerPromptAdapterManager(AbstractWorkerManager):
    """Worker-side manager for prompt_adapter models.

    On every request the requested prompt_adapters are loaded (unless
    they are already loaded) and every other prompt_adapter is unloaded.
    """

    # Manager class handed to ``create_prompt_adapter_manager``.
    _manager_cls: Type[PromptAdapterModelManager] = PromptAdapterModelManager

    def __init__(
        self,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        device: torch.device,
        prompt_adapter_config: PromptAdapterConfig,
        prompt_adapter_model_cls: Type[PromptAdapterModel] = PromptAdapterModel
    ):
        """Store the sizing limits, config and model class for later use.

        ``_adapter_manager`` is only declared here; it is assigned by
        ``create_prompt_adapter_manager``.
        """
        self._adapter_manager: PromptAdapterModelManager
        self.prompt_adapter_config = prompt_adapter_config
        self._prompt_adapter_model_cls = prompt_adapter_model_cls
        self.max_num_batched_tokens = max_num_batched_tokens
        self.max_num_seqs = max_num_seqs
        super().__init__(device)

    @property
    def is_enabled(self) -> bool:
        """Always ``True`` for this manager."""
        return True

    def create_prompt_adapter_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        """Build the adapter manager for ``model`` and keep a reference to it.

        Returns the ``model`` attribute of the newly created manager.
        """
        manager = create_prompt_adapter_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            prompt_adapter_config=self.prompt_adapter_config,
            prompt_adapter_manager_cls=self._manager_cls,
        )
        self._adapter_manager = manager
        return manager.model

    def _load_adapter(
        self, prompt_adapter_request: PromptAdapterRequest
    ) -> PromptAdapterModel:
        """Load the requested prompt_adapter from its local checkpoint.

        Raises:
            RuntimeError: if anything goes wrong while loading; the
                original exception is chained as the cause.
        """
        try:
            return self._prompt_adapter_model_cls.from_local_checkpoint(
                prompt_adapter_request.prompt_adapter_local_path,
                prompt_adapter_id=prompt_adapter_request.prompt_adapter_id,
                num_virtual_tokens=(
                    prompt_adapter_request.prompt_adapter_num_virtual_tokens),
                config=self.prompt_adapter_config,
                device=str(self.device),
            )
        except Exception as err:
            raise RuntimeError(
                f"Loading prompt_adapter "
                f"{prompt_adapter_request.prompt_adapter_local_path}"
                f" failed") from err

    def add_dummy_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        """No-op that ignores the request and reports success."""
        return True

    def pin_adapter(self, adapter_id: int) -> bool:
        """Forward the pin request to the underlying adapter manager."""
        manager = self._adapter_manager
        return manager.pin_adapter(adapter_id)

    def set_active_adapters(self, requests: Set[Any],
                            mapping: Optional[Any]) -> None:
        """Delegate to ``set_active_adapters_worker``.

        Passes this object's ``_apply_adapters`` and the manager's
        ``set_adapter_mapping`` as the callbacks.
        """
        set_mapping = self._adapter_manager.set_adapter_mapping
        set_active_adapters_worker(requests, mapping, self._apply_adapters,
                                   set_mapping)

    def add_adapter(self, adapter_request: Any) -> bool:
        """Delegate to ``add_adapter_worker`` with this manager's callbacks."""
        manager = self._adapter_manager
        return add_adapter_worker(
            adapter_request,
            self.list_adapters,
            self._load_adapter,
            manager.add_adapter,
            manager.activate_adapter,
        )

    def _apply_adapters(self, adapter_requests: Set[Any]) -> None:
        """Delegate to ``apply_adapters_worker``.

        Supplies the manager's ``adapter_slots`` together with this
        object's list/remove/add callbacks.
        """
        slots = self._adapter_manager.adapter_slots
        apply_adapters_worker(adapter_requests, self.list_adapters, slots,
                              self.remove_adapter, self.add_adapter)

    def remove_adapter(self, adapter_id: int) -> bool:
        """Forward the removal of ``adapter_id`` to the adapter manager."""
        manager = self._adapter_manager
        return manager.remove_adapter(adapter_id)

    def remove_all_adapters(self):
        """Ask the adapter manager to remove every adapter."""
        manager = self._adapter_manager
        manager.remove_all_adapters()

    def list_adapters(self) -> Set[int]:
        """Return the result of ``list_adapters_worker`` for this manager."""
        lister = self._adapter_manager.list_adapters
        return list_adapters_worker(lister)
class LRUCacheWorkerPromptAdapterManager(WorkerPromptAdapterManager):
    """Worker-side prompt_adapter manager backed by an LRU cache.

    On every request the requested prompt_adapters are loaded (unless
    they are already loaded), and the least recently used prompt_adapters
    are unloaded if the cache is above capacity.
    """

    # Manager class handed to ``create_prompt_adapter_manager``.
    _prompt_adapter_manager_cls: Type[
        LRUCachePromptAdapterModelManager] = LRUCachePromptAdapterModelManager

    def create_prompt_adapter_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        """Build the LRU adapter manager for ``model`` and keep a reference.

        Returns the ``model`` attribute of the newly created manager.
        """
        manager = create_prompt_adapter_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            prompt_adapter_config=self.prompt_adapter_config,
            prompt_adapter_manager_cls=self._prompt_adapter_manager_cls)
        self._adapter_manager: LRUCachePromptAdapterModelManager = manager
        return manager.model

    def _apply_adapters(
            self, prompt_adapter_requests: Set[PromptAdapterRequest]) -> None:
        """Add every requested prompt_adapter after a slot-count check.

        Falsy entries in ``prompt_adapter_requests`` are skipped, and
        requests sharing a ``prompt_adapter_id`` collapse to one entry.

        Raises:
            RuntimeError: if more distinct prompt_adapters are requested
                than the manager's ``prompt_adapter_slots``.
        """
        requests_by_id = {}
        for request in prompt_adapter_requests:
            if request:
                requests_by_id[request.prompt_adapter_id] = request

        requested = len(requests_by_id)
        slots = self._adapter_manager.prompt_adapter_slots
        if requested > slots:
            raise RuntimeError(
                f"Number of requested prompt_adapters "
                f"({requested}) is greater "
                "than the number of GPU prompt_adapter slots "
                f"({slots}).")

        for request in requests_by_id.values():
            self.add_adapter(request)

    def add_adapter(self,
                    prompt_adapter_request: PromptAdapterRequest) -> bool:
        """Load (or touch) the requested prompt_adapter, then activate it.

        Returns the manager's ``add_adapter`` result for a newly loaded
        adapter, or whether ``get_adapter`` found it when it was already
        listed.
        """
        manager = self._adapter_manager
        adapter_id = prompt_adapter_request.prompt_adapter_id

        if adapter_id in self.list_adapters():
            # Already loaded: the lookup just touches the entry so its
            # position in the caches is refreshed.
            loaded = manager.get_adapter(adapter_id) is not None
        else:
            # Evict before loading the new prompt_adapter to save memory.
            if len(manager) + 1 > manager.capacity:
                manager.remove_oldest_adapter()
            new_adapter = self._load_adapter(prompt_adapter_request)
            loaded = manager.add_adapter(new_adapter)

        manager.activate_adapter(adapter_id)
        return loaded
|