import logging
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

import torch
from torch import nn

from aphrodite.adapter_commons.models import (AdapterLRUCache, AdapterModel,
                                              AdapterModelManager)
from aphrodite.adapter_commons.utils import (add_adapter, deactivate_adapter,
                                             get_adapter, list_adapters,
                                             remove_adapter,
                                             set_adapter_mapping)
from aphrodite.common.config import PromptAdapterConfig
from aphrodite.prompt_adapter.layers import (
    VocabParallelEmbeddingWithPromptAdapter)  # yapf: disable
from aphrodite.prompt_adapter.layers import PromptAdapterMapping
from aphrodite.prompt_adapter.utils import load_peft_weights

logger = logging.getLogger(__name__)

_GLOBAL_PROMPT_ADAPTER_ID = 0


def get_prompt_adapter_id():
    """Return a new, process-wide unique prompt adapter id."""
    global _GLOBAL_PROMPT_ADAPTER_ID
    _GLOBAL_PROMPT_ADAPTER_ID += 1
    return _GLOBAL_PROMPT_ADAPTER_ID


def convert_to_embedding_indices(indices):
    """Convert per-token adapter indices into (adapter_index, position) pairs.

    Rows with index -1 (no prompt adapter) produce no output and reset the
    running position counter.
    """
    embedding_indices = []
    count = 0
    for value in indices:
        if value == -1:
            count = 0
        else:
            embedding_indices.append([value, count])
            count += 1
    return torch.tensor(embedding_indices)
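

# Worked example (illustrative comment only):
#
#     convert_to_embedding_indices([-1, 0, 0, -1, 1, 1])
#     # -> tensor([[0, 0], [0, 1], [1, 0], [1, 1]])
#
# Rows equal to -1 mark tokens without a prompt adapter; they produce no
# output row and reset the position counter used for the second column.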


def convert_mapping(
    mapping: PromptAdapterMapping,
    prompt_adapter_index_to_id: List[Optional[int]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Converts PromptAdapterMapping to index tensors.

    Args:
        mapping: PromptAdapterMapping mapping rows in a batch to
            PromptAdapter ids.
        prompt_adapter_index_to_id: List mapping PromptAdapter slot indices
            to PromptAdapter ids.

    Returns:
        pa_indices: Tensor of shape [batch_size] mapping batch rows to
            PromptAdapter indices (-1 for rows without an adapter).
        pa_embedding_mapping: Tensor of (adapter_index, position) pairs
            produced by convert_to_embedding_indices.
    """
    id_to_index = {
        id_: idx
        for idx, id_ in enumerate(prompt_adapter_index_to_id)
        if id_ is not None
    }
    pa_indices = [
        id_to_index.get(id_, -1) if id_ > 0 else -1
        for id_ in mapping.index_mapping
    ]
    pa_embedding_mapping = convert_to_embedding_indices(pa_indices)
    pa_indices = torch.tensor(pa_indices)
    return pa_indices, pa_embedding_mapping
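

# Worked example (illustrative comment only). With slot assignments
# prompt_adapter_index_to_id = [None, 102, None, 101] and a mapping whose
# index_mapping is [102, 102, 0, 101] (0 marks tokens without an adapter),
# convert_mapping returns:
#
#     pa_indices           -> tensor([ 1,  1, -1,  3])
#     pa_embedding_mapping -> tensor([[1, 0], [1, 1], [3, 0]])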


class PromptAdapterModel(AdapterModel):

    def __init__(self,
                 prompt_adapter_id=None,
                 num_virtual_tokens=None,
                 prompt_embedding=None) -> None:
        self.id = prompt_adapter_id
        self.prompt_embedding = prompt_embedding
        self.num_virtual_tokens = num_virtual_tokens

    @classmethod
    def from_local_checkpoint(
        cls,
        adapter_model_path: str,
        prompt_adapter_id: int,
        num_virtual_tokens: int,
        config: PromptAdapterConfig,
        device: str = "cuda",
    ) -> "PromptAdapterModel":
        if num_virtual_tokens > config.max_prompt_adapter_token:
            raise ValueError(
                f'num_virtual_tokens ({num_virtual_tokens}) should be <= '
                f'max_prompt_adapter_token ({config.max_prompt_adapter_token})')
        adapters_weights = load_peft_weights(adapter_model_path, device)
        prompt_embedding = adapters_weights["prompt_embeddings"].to(
            config.prompt_adapter_dtype)
        return cls(prompt_adapter_id, num_virtual_tokens, prompt_embedding)
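

# Illustrative loading sketch (the checkpoint path and token count are
# hypothetical; `prompt_adapter_config` is an existing PromptAdapterConfig).
# The directory is expected to hold PEFT prompt-tuning weights whose state
# dict contains a "prompt_embeddings" tensor:
#
#     adapter = PromptAdapterModel.from_local_checkpoint(
#         "/path/to/prompt-tuning-checkpoint",
#         prompt_adapter_id=get_prompt_adapter_id(),
#         num_virtual_tokens=8,
#         config=prompt_adapter_config,
#         device="cuda",
#     )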


class PromptAdapterModelManager(AdapterModelManager):
    """A manager that manages multiple Prompt Adapter models."""

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        prompt_adapter_config: PromptAdapterConfig,
    ):
        """Create a PromptAdapterModelManager and adapter modules for a model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences the model can run
                in a single batch.
            max_num_batched_tokens: the maximum number of tokens the model
                can run in a single batch.
            prompt_adapter_config: the PromptAdapter config.
        """
        self.model: nn.Module = model
        # Needed by the prompt_adapter_slots property used below.
        self.prompt_adapter_config = prompt_adapter_config
        # Dicts instead of Sets for compatibility with LRUCache: registered
        # adapters are known to the manager; active adapters occupy a slot
        # and can be used in the forward pass.
        self._registered_adapters: Dict[int, PromptAdapterModel] = {}
        self._active_adapters: Dict[int, None] = {}
        self.prompt_adapter_index_to_id: List[
            Optional[int]] = [None] * self.prompt_adapter_slots
        self.max_num_seqs = max_num_seqs
        # Round max_num_batched_tokens up to a multiple of 8.
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        self.model.prompt_adapter_manager = self
        self.adapter_type = 'PromptAdapter'
        self.base_indices = torch.tensor([-1])
        self.base_embedding_indices = torch.tensor([])
        self.modules: Dict[str, nn.Module] = {}
        self._create_prompt_adapter_modules()
        self._last_mapping: Optional[PromptAdapterMapping] = None

    @property
    def prompt_adapter_slots(self) -> int:
        return self.prompt_adapter_config.max_prompt_adapters

    @property
    def adapter_slots(self) -> int:
        return self.prompt_adapter_slots

    @property
    def capacity(self) -> int:
        return self.prompt_adapter_config.max_cpu_prompt_adapters

    def activate_adapter(
        self,
        prompt_adapter_id: int,
    ) -> bool:
        """Move a PromptAdapter into a GPU buffer
        to be used in the forward pass."""
        if prompt_adapter_id in self._active_adapters:
            return False
        first_free_slot = next(
            ((i, pa_id) for i, pa_id in enumerate(
                self.prompt_adapter_index_to_id) if pa_id is None),
            None)
        if first_free_slot is None:
            raise ValueError("No free prompt_adapter slots")
        index, _ = first_free_slot
        self._active_adapters[prompt_adapter_id] = None
        prompt_adapter_model = self._registered_adapters[prompt_adapter_id]
        logger.debug(f"Activating prompt_adapter. int id: "
                     f"{prompt_adapter_model.id}, "
                     f"slot index: {index}")
        self.prompt_adapter_index_to_id[index] = prompt_adapter_model.id
        for _, v in self.modules.items():
            v.set_prompt_adapter(index, prompt_adapter_model.prompt_embedding)
        return True

    def _deactivate_adapter(self, prompt_adapter_id: int):
        try:
            index = self.prompt_adapter_index_to_id.index(prompt_adapter_id)
            self.prompt_adapter_index_to_id[index] = None
            for _, v in self.modules.items():
                v.reset_prompt_adapter(index)
        except ValueError:
            pass

    def _add_adapter(self, prompt_adapter: PromptAdapterModel):
        self._registered_adapters[prompt_adapter.id] = prompt_adapter

    def _set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None:
        base_indices, base_embedding_indices = convert_mapping(
            mapping, self.prompt_adapter_index_to_id)
        for _, v in self.modules.items():
            v.set_mapping(base_indices, base_embedding_indices)

    def _create_prompt_adapter_modules(self):
        # Wrap the first VocabParallelEmbedding found in the model with a
        # prompt-adapter-aware version; only one embedding module is
        # replaced, hence the break.
        for module_name, module in self.model.named_modules(
                remove_duplicate=False):
            if "VocabParallel" in module.__class__.__name__:
                new_module = VocabParallelEmbeddingWithPromptAdapter(module)
                new_module.create_prompt_adapter_weights(
                    self.prompt_adapter_config)
                replaced_module = self.replace_submodule(
                    self.model, module_name, new_module)
                self.register_module(module.__class__.__name__,
                                     replaced_module)
                replaced_module.set_mapping(self.base_indices,
                                            self.base_embedding_indices)
                break

    def replace_submodule(self, model: nn.Module, module_name: str,
                          new_module: nn.Module) -> nn.Module:
        """Replace a submodule in a model with a new module."""
        parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
        target_name = module_name.split(".")[-1]
        setattr(parent, target_name, new_module)
        return new_module

    def register_module(self, module_name: str, module: nn.Module):
        self.modules[module_name] = module

    def pin_adapter(self, prompt_adapter_id: int) -> bool:
        """Pin a PromptAdapterModel in the manager cache."""
        raise NotImplementedError(
            "Pinning is not supported in PromptAdapterModelManager. "
            "Use LRUCachePromptAdapterModelManager for pinning."
        )  # type: ignore

    def remove_all_adapters(self):
        """Remove all PromptAdapterModels from the manager."""
        self._registered_adapters.clear()
        self.prompt_adapter_index_to_id = [None] * self.prompt_adapter_slots
        self._active_adapters.clear()

    def deactivate_adapter(self, adapter_id: int) -> bool:
        return deactivate_adapter(adapter_id, self._active_adapters,
                                  self._deactivate_adapter)

    def add_adapter(self, adapter: PromptAdapterModel) -> bool:
        return add_adapter(adapter, self._registered_adapters, self.capacity,
                           self._add_adapter)

    def set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None:
        self._last_mapping = set_adapter_mapping(mapping, self._last_mapping,
                                                 self._set_adapter_mapping)

    def remove_adapter(self, adapter_id: int) -> bool:
        return remove_adapter(adapter_id, self._registered_adapters,
                              self.deactivate_adapter)

    def list_adapters(self) -> Dict[int, Any]:
        return list_adapters(self._registered_adapters)

    def get_adapter(self, adapter_id: int) -> Optional[Any]:
        return get_adapter(adapter_id, self._registered_adapters)
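

# Illustrative per-batch flow (a sketch; `manager`, `adapter`, and
# `batch_mapping` are assumed to exist, with `batch_mapping.index_mapping`
# holding one adapter id per token and 0 for tokens without an adapter):
#
#     manager.activate_adapter(adapter.id)        # load into a free slot
#     manager.set_adapter_mapping(batch_mapping)  # build index tensors
#     # ... run the model forward pass ...
#     manager.deactivate_adapter(adapter.id)      # release the slot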


class PromptAdapterLRUCache(AdapterLRUCache[PromptAdapterModel]):

    def __init__(self, capacity: int,
                 deactivate_prompt_adapter_fn: Callable[[int], bool]):
        super().__init__(capacity, deactivate_prompt_adapter_fn)


class LRUCachePromptAdapterModelManager(PromptAdapterModelManager):
    """A model manager that manages multiple prompt_adapters with LRU cache."""

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        prompt_adapter_config: PromptAdapterConfig,
    ):
        self.prompt_adapter_config = prompt_adapter_config
        super().__init__(model, max_num_seqs, max_num_batched_tokens,
                         prompt_adapter_config)
        self._registered_adapters = PromptAdapterLRUCache(
            self.capacity, self.deactivate_adapter)
        self._active_adapters = PromptAdapterLRUCache(
            self.prompt_adapter_slots, self._deactivate_adapter)

    def list_adapters(self) -> Dict[int, PromptAdapterModel]:
        """List all registered PromptAdapterModels."""
        return dict(self._registered_adapters.cache)

    def add_adapter(self, prompt_adapter: PromptAdapterModel) -> bool:
        """Add a PromptAdapterModel to the manager."""
        if prompt_adapter.id not in self._registered_adapters:
            self._add_adapter(prompt_adapter)
            was_added = True
        else:
            # We always touch to update the LRU cache order
            self._registered_adapters.touch(prompt_adapter.id)
            was_added = False
        return was_added

    def activate_adapter(
        self,
        prompt_adapter_id: int,
    ) -> bool:
        if prompt_adapter_id not in self._active_adapters and len(
                self._active_adapters) >= self.prompt_adapter_slots:
            self._active_adapters.remove_oldest()
        result = super().activate_adapter(prompt_adapter_id)
        # We always touch to update the LRU cache order
        self._active_adapters.touch(prompt_adapter_id)
        return result

    def remove_oldest_adapter(self) -> bool:
        if len(self._registered_adapters) > 0:
            self._registered_adapters.remove_oldest()
            return True
        return False

    def pin_adapter(self, prompt_adapter_id: int) -> bool:
        """Pin a PromptAdapterModel in the manager cache."""
        self._pin_prompt_adapter_in_cpu_cache(prompt_adapter_id)
        self._pin_prompt_adapter_in_gpu_cache(prompt_adapter_id)
        return True

    def _pin_prompt_adapter_in_cpu_cache(self, prompt_adapter_id: int):
        try:
            self._registered_adapters.pin(prompt_adapter_id)
        except ValueError as err:
            raise ValueError(
                "Pinning failed. "
                f"Prompt Adapter {prompt_adapter_id} is not registered."
            ) from err

    def _pin_prompt_adapter_in_gpu_cache(self, prompt_adapter_id: int):
        if prompt_adapter_id not in self._active_adapters:
            # move adapter to gpu if not already active
            self.activate_adapter(prompt_adapter_id)
        self._active_adapters.pin(prompt_adapter_id)


def create_prompt_adapter_manager(
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        prompt_adapter_config: PromptAdapterConfig,
        prompt_adapter_manager_cls: Type[
            PromptAdapterModelManager] = PromptAdapterModelManager,
        **kwargs) -> PromptAdapterModelManager:
    """Create a PromptAdapterModelManager for a given model."""
    prompt_adapter_manager = prompt_adapter_manager_cls(
        model=model,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        prompt_adapter_config=prompt_adapter_config,
        **kwargs)
    return prompt_adapter_manager
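

# Illustrative end-to-end sketch (not executed here; `model`,
# `prompt_adapter_config`, and `adapter` are assumed to exist, and the batch
# size limits are placeholder values):
#
#     manager = create_prompt_adapter_manager(
#         model=model,
#         max_num_seqs=256,
#         max_num_batched_tokens=2048,
#         prompt_adapter_config=prompt_adapter_config,
#         prompt_adapter_manager_cls=LRUCachePromptAdapterModelManager,
#     )
#     manager.add_adapter(adapter)          # adapter: a PromptAdapterModel
#     manager.activate_adapter(adapter.id)  # evicts the LRU adapter if full
#     manager.pin_adapter(adapter.id)       # LRU manager only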