models.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
  1. import logging
  2. import math
  3. from typing import Any, Callable, Dict, List, Optional, Type
  4. import torch
  5. from torch import nn
  6. from aphrodite.adapter_commons.models import (AdapterLRUCache, AdapterModel,
  7. AdapterModelManager)
  8. from aphrodite.adapter_commons.utils import (add_adapter, deactivate_adapter,
  9. get_adapter, list_adapters,
  10. remove_adapter,
  11. set_adapter_mapping)
  12. from aphrodite.common.config import PromptAdapterConfig
  13. from aphrodite.prompt_adapter.layers import \
  14. VocabParallelEmbeddingWithPromptAdapter # yapf: disable
  15. from aphrodite.prompt_adapter.layers import PromptAdapterMapping
  logger = logging.getLogger(__name__)

  # Module-level counter behind get_prompt_adapter_id(); because it starts at
  # 0 and is incremented before being returned, the first id handed out is 1.
  _GLOBAL_PROMPT_ADAPTER_ID = 0


  def get_prompt_adapter_id():
      """Return the next integer from the module-level prompt adapter counter.

      Ids start at 1 and grow by one on every call.
      """
      global _GLOBAL_PROMPT_ADAPTER_ID
      next_id = _GLOBAL_PROMPT_ADAPTER_ID + 1
      _GLOBAL_PROMPT_ADAPTER_ID = next_id
      return next_id
  def convert_to_embedding_indices(indices):
      """Pair each non-sentinel slot index with its running position.

      Walks ``indices`` in order. Every value other than ``-1`` produces a row
      ``[value, position]``. ``position`` counts the non-``-1`` values seen
      since the most recent ``-1`` (or since the start). A ``-1`` produces no
      row and restarts the position counter at 0.

      Args:
          indices: Iterable of prompt adapter slot indices, where ``-1`` marks
              an entry without an adapter.

      Returns:
          ``torch.tensor`` built from the collected ``[value, position]`` rows.
          If no row was collected (empty input, or only ``-1`` values), this
          is an empty 1-D tensor rather than a ``[0, 2]`` one.
      """
      rows = []
      position = 0
      for slot in indices:
          if slot == -1:
              # A sentinel starts a fresh run of positions.
              position = 0
              continue
          rows.append([slot, position])
          position += 1
      return torch.tensor(rows)
  def convert_mapping(
      mapping: PromptAdapterMapping,
      prompt_adapter_index_to_id: List[Optional[int]],
  ) -> "tuple[torch.Tensor, torch.Tensor]":
      """Convert a PromptAdapterMapping into index tensors.

      Args:
          mapping: PromptAdapterMapping mapping rows in a batch to
              PromptAdapter ids via ``mapping.index_mapping``. Ids ``<= 0``
              are treated as "no prompt adapter".
          prompt_adapter_index_to_id: List mapping each PromptAdapter slot
              index to the id of the adapter occupying it (``None`` for a free
              slot).

      Returns:
          A ``(pa_indices, pa_embedding_mapping)`` tuple:

          * ``pa_indices``: 1-D tensor with one entry per element of
            ``mapping.index_mapping``. Each entry is the slot index of that
            row's adapter, or ``-1`` when the id is ``<= 0`` or not resident
            in any slot.
          * ``pa_embedding_mapping``: ``convert_to_embedding_indices`` applied
            to those slot indices, i.e. one ``[slot, position]`` row per entry
            that is not ``-1``.
      """
      # Invert the slot table once so each per-row lookup is a dict hit.
      id_to_index = {
          id_: idx
          for idx, id_ in enumerate(prompt_adapter_index_to_id)
          if id_ is not None
      }
      # Ids <= 0 map straight to -1; ids not present in any slot also get -1.
      slot_indices = [
          id_to_index.get(id_, -1) if id_ > 0 else -1
          for id_ in mapping.index_mapping
      ]
      pa_embedding_mapping = convert_to_embedding_indices(slot_indices)
      pa_indices = torch.tensor(slot_indices)
      return pa_indices, pa_embedding_mapping
  class PromptAdapterModel(AdapterModel):
      """One prompt adapter: its id, virtual-token count and prompt embedding."""

      def __init__(self,
                   prompt_adapter_id=None,
                   num_virtual_tokens=None,
                   prompt_embedding=None) -> None:
          self.id = prompt_adapter_id
          self.num_virtual_tokens = num_virtual_tokens
          self.prompt_embedding = prompt_embedding

      @classmethod
      def from_local_checkpoint(
          cls,
          adapter_model_path: str,
          prompt_adapter_id: int,
          num_virtual_tokens: int,
          config: PromptAdapterConfig,
          device: str = "cuda",
      ) -> "PromptAdapterModel":
          """Build a PromptAdapterModel from PEFT weights stored on disk.

          Args:
              adapter_model_path: Path handed to ``peft``'s
                  ``load_peft_weights``.
              prompt_adapter_id: Id stored on the returned model.
              num_virtual_tokens: Number of virtual tokens the adapter uses.
                  Must not exceed ``config.max_prompt_adapter_token``.
              config: Supplies the token limit and the dtype the embedding is
                  converted to (``config.prompt_adapter_dtype``).
              device: Device passed to ``load_peft_weights``.

          Returns:
              A new instance holding the ``"prompt_embeddings"`` entry of the
              loaded weights.

          Raises:
              ValueError: if ``num_virtual_tokens`` is larger than
                  ``config.max_prompt_adapter_token``.
          """
          # Function-scope import: peft is only imported when this runs.
          from peft.utils import load_peft_weights

          token_limit = config.max_prompt_adapter_token
          if num_virtual_tokens > token_limit:
              raise ValueError(
                  f'num_virtual_tokens ({num_virtual_tokens}) should be <= '
                  f'max_prompt_adapter_token({token_limit})')

          weights = load_peft_weights(adapter_model_path, device)
          embedding = weights["prompt_embeddings"]
          embedding = embedding.to(config.prompt_adapter_dtype)
          return cls(prompt_adapter_id, num_virtual_tokens, embedding)
  class PromptAdapterModelManager(AdapterModelManager):
      """A manager that manages multiple Prompt Adapter models."""

      def __init__(
          self,
          model: nn.Module,
          max_num_seqs: int,
          max_num_batched_tokens: int,
          prompt_adapter_config: PromptAdapterConfig,
      ):
          """Create a prompt adapter manager and patch ``model`` to use it.

          Replaces the first ``VocabParallel*`` submodule of ``model`` with a
          ``VocabParallelEmbeddingWithPromptAdapter`` wrapper. Also stores
          ``self`` as ``model.prompt_adapter_manager``.

          Args:
              model: the model to be adapted.
              max_num_seqs: the maximum number of sequences model can run in a
                  single batch.
              max_num_batched_tokens: the maximum number of tokens model can
                  run in a single batch (stored rounded up to a multiple of 8).
              prompt_adapter_config: the PromptAdapter config.
          """
          # Must be assigned first: ``prompt_adapter_slots`` (used below)
          # reads it, so assigning it later raised AttributeError.
          self.prompt_adapter_config = prompt_adapter_config
          self.model: nn.Module = model
          # Registered adapters keyed by id. The LRU subclass replaces this
          # with a PromptAdapterLRUCache after this initializer returns.
          self._registered_adapters: Dict[int, PromptAdapterModel] = {}
          # Ids of the adapters currently resident in a slot.
          # Dict instead of a Set for compatibility with LRUCache.
          self._active_adapters: Dict[int, None] = {}
          # Slot index -> id of the adapter occupying it (None = free slot).
          self.prompt_adapter_index_to_id: List[
              Optional[int]] = [None] * self.prompt_adapter_slots
          self.max_num_seqs = max_num_seqs
          self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
          self.model.prompt_adapter_manager = self
          self.adapter_type = 'PromptAdapter'
          # Initial mapping installed into the wrapped module: -1 is the
          # "no adapter slot" value (same convention as convert_mapping).
          self.base_indices = torch.tensor([-1])
          self.base_embedding_indices = torch.tensor([])
          self.modules: Dict[str, nn.Module] = {}
          self._create_prompt_adapter_modules()
          self._last_mapping: Optional[PromptAdapterMapping] = None

      @property
      def prompt_adapter_slots(self) -> int:
          """Number of slots (``max_prompt_adapters`` from the config)."""
          return self.prompt_adapter_config.max_prompt_adapters

      @property
      def adapter_slots(self) -> int:
          """Alias of ``prompt_adapter_slots`` for the generic adapter API."""
          return self.prompt_adapter_slots

      @property
      def capacity(self) -> int:
          """Max registered adapters (``max_cpu_prompt_adapters``)."""
          return self.prompt_adapter_config.max_cpu_prompt_adapters

      def activate_adapter(
          self,
          prompt_adapter_id: int,
      ) -> bool:
          """Move a registered PromptAdapter into a free slot for the forward pass.

          Returns:
              False if the adapter is already active, True once it has been
              placed in a slot and pushed to every wrapped module.

          Raises:
              ValueError: if every slot is occupied.
          """
          if prompt_adapter_id in self._active_adapters:
              return False
          index = next(
              (slot for slot, occupant in enumerate(
                  self.prompt_adapter_index_to_id) if occupant is None),
              None)
          if index is None:
              raise ValueError("No free prompt_adapter slots")
          # Look the adapter up before marking it active, so an unregistered
          # id cannot leave a stale entry in ``_active_adapters``.
          prompt_adapter_model = self._registered_adapters[prompt_adapter_id]
          self._active_adapters[prompt_adapter_id] = None
          logger.debug("Activating prompt_adapter. int id: %s, slot index: %s",
                       prompt_adapter_model.id, index)
          self.prompt_adapter_index_to_id[index] = prompt_adapter_model.id
          for module in self.modules.values():
              module.set_prompt_adapter(index,
                                        prompt_adapter_model.prompt_embedding)
          return True

      def _deactivate_adapter(self, prompt_adapter_id: int):
          """Free the slot held by ``prompt_adapter_id``; no-op if it holds none."""
          try:
              index = self.prompt_adapter_index_to_id.index(prompt_adapter_id)
          except ValueError:
              # Not resident in any slot: nothing to free.
              return
          self.prompt_adapter_index_to_id[index] = None
          for module in self.modules.values():
              module.reset_prompt_adapter(index)

      def _add_adapter(self, prompt_adapter: PromptAdapterModel):
          """Store ``prompt_adapter`` in the registry under its own id."""
          self._registered_adapters[prompt_adapter.id] = prompt_adapter

      def _set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None:
          """Convert ``mapping`` to index tensors and push them to every module."""
          base_indices, base_embedding_indices = convert_mapping(
              mapping, self.prompt_adapter_index_to_id)
          for module in self.modules.values():
              module.set_mapping(base_indices, base_embedding_indices)

      def _create_prompt_adapter_modules(self):
          """Wrap the model's first ``VocabParallel*`` submodule.

          Only the first match in ``named_modules`` order is replaced. The
          loop breaks immediately after the replacement, so iteration never
          continues over the mutated module tree.
          """
          for module_name, module in self.model.named_modules(
                  remove_duplicate=False):
              if "VocabParallel" in module.__class__.__name__:
                  new_module = VocabParallelEmbeddingWithPromptAdapter(module)
                  new_module.create_prompt_adapter_weights(
                      self.prompt_adapter_config)
                  replaced_module = self.replace_submodule(
                      self.model, module_name, new_module)
                  # Keyed by the class name of the module that was wrapped.
                  self.register_module(module.__class__.__name__,
                                       replaced_module)
                  replaced_module.set_mapping(self.base_indices,
                                              self.base_embedding_indices)
                  break

      def replace_submodule(self, model: nn.Module, module_name: str,
                            new_module: nn.Module) -> nn.Module:
          """Replace a submodule in a model with a new module.

          ``module_name`` is a dotted path; a name without dots replaces a
          direct child of ``model``. Returns ``new_module``.
          """
          parent_name, _, target_name = module_name.rpartition(".")
          parent = model.get_submodule(parent_name)
          setattr(parent, target_name, new_module)
          return new_module

      def register_module(self, module_name: str, module: nn.Module):
          """Track ``module`` so mappings and adapters are pushed to it."""
          self.modules[module_name] = module

      def pin_adapter(self, prompt_adapter_id: int) -> bool:
          """Pin a PromptAdapterModel in the manager cache.

          Raises:
              NotImplementedError: always; only the LRU manager supports it.
          """
          raise NotImplementedError(
              "Pinning is not supported in PromptAdapterModelManager. "
              "Use LRUCachePromptAdapterModelManager for pinning")

      def remove_all_adapters(self):
          """Remove all PromptAdapterModel from the manager."""
          self._registered_adapters.clear()
          self.prompt_adapter_index_to_id = [None] * self.prompt_adapter_slots
          self._active_adapters.clear()

      def deactivate_adapter(self, adapter_id: int) -> bool:
          return deactivate_adapter(adapter_id, self._active_adapters,
                                    self._deactivate_adapter)

      def add_adapter(self, adapter: PromptAdapterModel) -> bool:
          return add_adapter(adapter, self._registered_adapters, self.capacity,
                             self._add_adapter)

      def set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None:
          self._last_mapping = set_adapter_mapping(mapping, self._last_mapping,
                                                   self._set_adapter_mapping)

      def remove_adapter(self, adapter_id: int) -> bool:
          return remove_adapter(adapter_id, self._registered_adapters,
                                self.deactivate_adapter)

      def list_adapters(self) -> Dict[int, Any]:
          return list_adapters(self._registered_adapters)

      def get_adapter(self, adapter_id: int) -> Optional[Any]:
          return get_adapter(adapter_id, self._registered_adapters)
  class PromptAdapterLRUCache(AdapterLRUCache[PromptAdapterModel]):
      """AdapterLRUCache specialised to PromptAdapterModel values.

      ``capacity`` and ``deactivate_prompt_adapter_fn`` are forwarded
      unchanged to AdapterLRUCache.
      """

      def __init__(
          self,
          capacity: int,
          deactivate_prompt_adapter_fn: Callable[[int], bool],
      ):
          super().__init__(capacity, deactivate_prompt_adapter_fn)
  class LRUCachePromptAdapterModelManager(PromptAdapterModelManager):
      """A model manager that manages multiple prompt_adapters with LRU cache.

      Registered and active adapters are both held in PromptAdapterLRUCache
      instances. Activating a new adapter while every slot is taken evicts
      the least recently used active adapter first. Eviction from the
      registered set is left to PromptAdapterLRUCache (presumably when
      ``capacity`` is exceeded; confirm in AdapterLRUCache).
      """

      def __init__(
          self,
          model: nn.Module,
          max_num_seqs: int,
          max_num_batched_tokens: int,
          prompt_adapter_config: PromptAdapterConfig,
      ):
          # Assigned before the base initializer, which reads
          # ``prompt_adapter_slots`` (derived from this config).
          self.prompt_adapter_config = prompt_adapter_config
          super().__init__(model, max_num_seqs, max_num_batched_tokens,
                           prompt_adapter_config)
          # Registry sized by ``capacity``; its callback is
          # ``deactivate_adapter`` (presumably invoked on eviction).
          self._registered_adapters = PromptAdapterLRUCache(
              self.capacity, self.deactivate_adapter)
          # Active set sized by the slot count; its callback frees the slot.
          self._active_adapters = PromptAdapterLRUCache(
              self.prompt_adapter_slots, self._deactivate_adapter)

      def list_adapters(self) -> Dict[int, PromptAdapterModel]:
          """List all registered PromptAdapterModel."""
          return dict(self._registered_adapters.cache)

      def add_adapter(self, prompt_adapter: PromptAdapterModel) -> bool:
          """Add a PromptAdapterModel to the manager.

          Returns True if it was newly added, False if it was already
          registered (in which case it is only marked as recently used).
          """
          already_registered = prompt_adapter.id in self._registered_adapters
          if already_registered:
              # We always touch to update the LRU cache order
              self._registered_adapters.touch(prompt_adapter.id)
          else:
              self._add_adapter(prompt_adapter)
          return not already_registered

      def activate_adapter(
          self,
          prompt_adapter_id: int,
      ) -> bool:
          """Activate an adapter, evicting the oldest active one if slots are full."""
          needs_room = (prompt_adapter_id not in self._active_adapters
                        and len(self._active_adapters)
                        >= self.prompt_adapter_slots)
          if needs_room:
              self._active_adapters.remove_oldest()
          activated = super().activate_adapter(prompt_adapter_id)
          # We always touch to update the LRU cache order
          self._active_adapters.touch(prompt_adapter_id)
          return activated

      def remove_oldest_adapter(self) -> bool:
          """Drop the least recently used registered adapter.

          Returns False if nothing is registered, True otherwise.
          """
          if len(self._registered_adapters) == 0:
              return False
          self._registered_adapters.remove_oldest()
          return True

      def pin_adapter(self, prompt_adapter_id: int) -> bool:
          """Pin a PromptAdapterModel in the manager cache."""
          self._pin_prompt_adapter_in_cpu_cache(prompt_adapter_id)
          self._pin_prompt_adapter_in_gpu_cache(prompt_adapter_id)
          return True

      def _pin_prompt_adapter_in_cpu_cache(self, prompt_adapter_id: int):
          """Pin in the registry, re-raising ValueError with a clearer message."""
          try:
              self._registered_adapters.pin(prompt_adapter_id)
          except ValueError as err:
              raise ValueError(
                  "Pinning failed. "
                  f"Prompt Adapter {prompt_adapter_id} is not registered."
              ) from err

      def _pin_prompt_adapter_in_gpu_cache(self, prompt_adapter_id: int):
          """Activate the adapter if needed, then pin it in the active set."""
          if prompt_adapter_id not in self._active_adapters:
              # move adapter to gpu if not already active
              self.activate_adapter(prompt_adapter_id)
          self._active_adapters.pin(prompt_adapter_id)
  def create_prompt_adapter_manager(
      model: nn.Module,
      max_num_seqs: int,
      max_num_batched_tokens: int,
      prompt_adapter_config: PromptAdapterConfig,
      prompt_adapter_manager_cls: Type[
          PromptAdapterModelManager] = PromptAdapterModelManager,
      **kwargs) -> PromptAdapterModelManager:
      """Instantiate ``prompt_adapter_manager_cls`` for ``model``.

      Every named argument is passed to the class as a keyword argument,
      together with any extra ``kwargs``. The new manager is returned.
      """
      return prompt_adapter_manager_cls(
          model=model,
          max_num_seqs=max_num_seqs,
          max_num_batched_tokens=max_num_batched_tokens,
          prompt_adapter_config=prompt_adapter_config,
          **kwargs,
      )