# worker_manager.py

from contextlib import contextmanager
from typing import Any, Dict, List, Literal, Optional, Set, Type, Union

import torch

from aphrodite.adapter_commons.utils import (add_adapter_worker,
                                             apply_adapters_worker,
                                             list_adapters_worker,
                                             set_active_adapters_worker)
from aphrodite.adapter_commons.worker_manager import AbstractWorkerManager
from aphrodite.common.config import LoRAConfig
from aphrodite.lora.models import (LoRAModel, LoRAModelManager,
                                   LRUCacheLoRAModelManager,
                                   create_lora_manager)
from aphrodite.lora.request import LoRARequest
from aphrodite.lora.utils import get_adapter_absolute_path


class WorkerLoRAManager(AbstractWorkerManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Every request, the requested LoRAs will be loaded (unless they are already
    loaded), and every other LoRA will be unloaded."""

    _manager_cls: Type[LoRAModelManager] = LoRAModelManager

    def __init__(
        self,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        embedding_modules: Dict[str, str],
        embedding_padding_modules: List[str],
        lora_model_cls: Type[LoRAModel] = LoRAModel,
        max_position_embeddings: Optional[int] = None,
    ):
        self._lora_model_cls = lora_model_cls
        self.embedding_modules = embedding_modules
        self.embedding_padding_modules = embedding_padding_modules
        self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
        self.max_num_seqs = max_num_seqs
        self.max_num_batched_tokens = max_num_batched_tokens
        self.vocab_size = vocab_size
        self.lora_config = lora_config
        self.max_position_embeddings = max_position_embeddings
        super().__init__(device)
        # Lazily initialized by create_lora_manager.
        self._adapter_manager: LoRAModelManager

    @contextmanager
    def dummy_lora_cache(self):
        """Use this context manager to reuse the dummy lora model
        to avoid creating it repeatedly."""
        self._cached_dummy_lora = None
        yield
        self._cached_dummy_lora = False

    @property
    def is_enabled(self) -> bool:
        return True

    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        lora_manager = create_lora_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            lora_manager_cls=self._manager_cls,
        )
        self._adapter_manager = lora_manager
        return lora_manager.model

    def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
        try:
            model = self._adapter_manager.model
            supported_lora_modules = model.supported_lora_modules
            packed_modules_mapping = model.packed_modules_mapping
            expected_lora_modules: List[str] = []
            for module in supported_lora_modules:
                if module in packed_modules_mapping:
                    expected_lora_modules.extend(
                        packed_modules_mapping[module])
                else:
                    expected_lora_modules.append(module)
            lora_path = get_adapter_absolute_path(lora_request.lora_path)
            lora = self._lora_model_cls.from_local_checkpoint(
                lora_path,
                expected_lora_modules,
                max_position_embeddings=self.max_position_embeddings,
                lora_model_id=lora_request.lora_int_id,
                device="cpu",
                dtype=self.lora_config.lora_dtype,
                target_embedding_padding=self.vocab_size +
                self.lora_config.lora_extra_vocab_size,
                embedding_modules=self.embedding_modules,
                embedding_padding_modules=self.embedding_padding_modules,
            )
        except Exception as e:
            # Report the path from the request itself: `lora_path` may not be
            # bound yet if resolving it is what raised.
            raise RuntimeError(
                f"Loading lora {lora_request.lora_path} failed") from e
        if lora.rank > self.lora_config.max_lora_rank:
            raise ValueError(
                f"LoRA rank {lora.rank} is greater than max_lora_rank "
                f"{self.lora_config.max_lora_rank}. Please launch the "
                "engine with a higher max_lora_rank (--max-lora-rank).")
        if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
            raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} "
                             f"is greater than lora_extra_vocab_size "
                             f"{self.lora_config.lora_extra_vocab_size}.")
        return lora

    def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
        if lora_request.lora_int_id in self.list_adapters():
            return False
        if isinstance(self._cached_dummy_lora, LoRAModel):
            dummy_lora = self._cached_dummy_lora.clone(
                lora_request.lora_int_id)
        else:
            dummy_lora = self._adapter_manager.create_dummy_lora(
                lora_request.lora_int_id, rank, 1, self.embedding_modules)
            if self._cached_dummy_lora is None:
                self._cached_dummy_lora = dummy_lora
        return self._adapter_manager.add_adapter(dummy_lora)

    def pin_adapter(self, adapter_id: int) -> bool:
        return self._adapter_manager.pin_adapter(adapter_id)

    def set_active_adapters(self, requests: Set[Any],
                            mapping: Optional[Any]) -> None:
        set_active_adapters_worker(requests, mapping, self._apply_adapters,
                                   self._adapter_manager.set_adapter_mapping)

    def _apply_adapters(self, adapter_requests: Set[Any]) -> None:
        apply_adapters_worker(adapter_requests, self.list_adapters,
                              self._adapter_manager.adapter_slots,
                              self.remove_adapter, self.add_adapter)

    def add_adapter(self, adapter_request: Any) -> bool:
        return add_adapter_worker(adapter_request, self.list_adapters,
                                  self._load_adapter,
                                  self._adapter_manager.add_adapter,
                                  self._adapter_manager.activate_adapter)

    def remove_adapter(self, adapter_id: int) -> bool:
        return self._adapter_manager.remove_adapter(adapter_id)

    def remove_all_adapters(self):
        self._adapter_manager.remove_all_adapters()

    def list_adapters(self) -> Set[int]:
        return list_adapters_worker(self._adapter_manager.list_adapters)
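

# Illustrative usage sketch, not part of the original module: `model`,
# `lora_config`, `lora_request`, and `mapping` are assumed to be supplied by
# the caller (in practice the model runner builds them). It shows the typical
# worker-side flow: construct the manager, wrap the model with
# create_lora_manager(), optionally reuse one dummy LoRA while profiling
# memory, then load and activate the requested adapters for a batch.
def _worker_lora_manager_sketch(model: torch.nn.Module,
                                lora_config: LoRAConfig,
                                lora_request: LoRARequest,
                                mapping: Any) -> None:
    manager = WorkerLoRAManager(
        max_num_seqs=256,
        max_num_batched_tokens=8192,
        vocab_size=32000,
        lora_config=lora_config,
        device=torch.device("cuda"),
        embedding_modules={},          # model-specific in practice
        embedding_padding_modules=[],  # model-specific in practice
    )
    # Wraps the model's supported layers for LoRA and returns the wrapped
    # model.
    model = manager.create_lora_manager(model)
    # Reuse a single dummy LoRA across requests (e.g. during memory
    # profiling) instead of recreating it each time.
    with manager.dummy_lora_cache():
        manager.add_dummy_lora(lora_request, rank=lora_config.max_lora_rank)
    # Per scheduling step: load the requested LoRAs, unload the rest, and set
    # the adapter mapping used by the LoRA layers.
    manager.set_active_adapters({lora_request}, mapping)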


class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Uses an LRU Cache. Every request, the requested LoRAs will be loaded
    (unless they are already loaded) and least recently used LoRAs will
    be unloaded if the cache is above capacity."""

    _manager_cls: Type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager

    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        lora_manager = create_lora_manager(
            model,
            lora_manager_cls=self._manager_cls,
            max_num_seqs=self.max_num_seqs,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            max_num_batched_tokens=self.max_num_batched_tokens,
        )
        self._adapter_manager = lora_manager
        return lora_manager.model

    def _apply_adapters(self, lora_requests: Set[LoRARequest]) -> None:
        loras_map = {
            lora_request.lora_int_id: lora_request
            for lora_request in lora_requests if lora_request
        }
        if len(loras_map) > self._adapter_manager.lora_slots:
            raise RuntimeError(
                f"Number of requested LoRAs ({len(loras_map)}) is greater "
                "than the number of GPU LoRA slots "
                f"({self._adapter_manager.lora_slots}).")
        for lora in loras_map.values():
            self.add_adapter(lora)

    def add_adapter(self, lora_request: LoRARequest) -> bool:
        if lora_request.lora_int_id not in self.list_adapters():
            # Remove before we load the new lora to save memory.
            if len(self._adapter_manager) + 1 > self._adapter_manager.capacity:
                assert isinstance(self._adapter_manager,
                                  LRUCacheLoRAModelManager)
                self._adapter_manager.remove_oldest_adapter()
            lora = self._load_adapter(lora_request)
            loaded = self._adapter_manager.add_adapter(lora)
        else:
            # If the lora is already loaded, just touch it to
            # update its position in the caches.
            loaded = self._adapter_manager.get_adapter(
                lora_request.lora_int_id) is not None
        self._adapter_manager.activate_adapter(lora_request.lora_int_id)
        return loaded
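

# Illustrative sketch only, not part of the original module: `manager` is an
# already-initialized LRUCacheWorkerLoRAManager (see the sketch above) and
# `lora_requests` a list of LoRARequest objects. It highlights the LRU
# behavior described in the class docstring: once the cache is at capacity, a
# new adapter evicts the least recently used one, while re-adding a resident
# adapter only refreshes its position and re-activates it.
def _lru_cache_worker_sketch(manager: LRUCacheWorkerLoRAManager,
                             lora_requests: List[LoRARequest]) -> None:
    for request in lora_requests:
        # A new adapter may first evict the oldest resident adapter; adapters
        # that are already loaded are just "touched" and re-activated.
        manager.add_adapter(request)
    # Note: requesting more LoRAs in a single batch than there are GPU LoRA
    # slots is rejected by _apply_adapters with a RuntimeError rather than
    # being evicted incrementally.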