worker_manager.py

from contextlib import contextmanager
from typing import Any, Dict, List, Literal, Optional, Set, Type, Union

import torch

from aphrodite.adapter_commons.utils import (add_adapter_worker,
                                             apply_adapters_worker,
                                             list_adapters_worker,
                                             set_active_adapters_worker)
from aphrodite.adapter_commons.worker_manager import AbstractWorkerManager
from aphrodite.common.config import LoRAConfig
from aphrodite.lora.models import (LoRAModel, LoRAModelManager,
                                   LRUCacheLoRAModelManager,
                                   create_lora_manager)
from aphrodite.lora.request import LoRARequest
from aphrodite.lora.utils import get_adapter_absolute_path


class WorkerLoRAManager(AbstractWorkerManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    On every request, the requested LoRAs will be loaded (unless they are
    already loaded), and every other LoRA will be unloaded. An illustrative
    call sequence is sketched in the comment block after this class."""

    _manager_cls: Type[LoRAModelManager] = LoRAModelManager

    def __init__(
        self,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        embedding_modules: Dict[str, str],
        embedding_padding_modules: List[str],
        lora_model_cls: Type[LoRAModel] = LoRAModel,
        max_position_embeddings: Optional[int] = None,
    ):
        self._lora_model_cls = lora_model_cls
        self.embedding_modules = embedding_modules
        self.embedding_padding_modules = embedding_padding_modules
        self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
        self.max_num_seqs = max_num_seqs
        self.max_num_batched_tokens = max_num_batched_tokens
        self.vocab_size = vocab_size
        self.lora_config = lora_config
        self.max_position_embeddings = max_position_embeddings
        super().__init__(device)
        # Lazily initialized by create_lora_manager.
        self._adapter_manager: LoRAModelManager

    @contextmanager
    def dummy_lora_cache(self):
        """Use this context manager to reuse the dummy lora model
        to avoid creating it repeatedly.
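
        A minimal usage sketch (``manager`` and ``max_loras`` stand in for a
        constructed WorkerLoRAManager and the configured LoRA limit; the
        LoRARequest keyword names, the rank and the placeholder path are
        assumptions for illustration; dummy LoRAs are generated in memory,
        so the path is never read):

            with manager.dummy_lora_cache():
                for lora_id in range(1, max_loras + 1):
                    manager.add_dummy_lora(
                        LoRARequest(lora_name=f"warmup_{lora_id}",
                                    lora_int_id=lora_id,
                                    lora_path="/placeholder"),
                        rank=8)
        """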
        self._cached_dummy_lora = None
        yield
        self._cached_dummy_lora = False

    @property
    def is_enabled(self) -> bool:
        return True

    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        lora_manager = create_lora_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            lora_manager_cls=self._manager_cls,
        )
        self._adapter_manager = lora_manager
        return lora_manager.model

    def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
        try:
            model = self._adapter_manager.model
            supported_lora_modules = model.supported_lora_modules
            packed_modules_mapping = model.packed_modules_mapping
            expected_lora_modules: List[str] = []
            # Expand packed modules (e.g. a fused qkv_proj) into the
            # individual projections a LoRA checkpoint stores.
            for module in supported_lora_modules:
                if module in packed_modules_mapping:
                    expected_lora_modules.extend(
                        packed_modules_mapping[module])
                else:
                    expected_lora_modules.append(module)
            expected_modules_to_save: List[str] = model.modules_to_save
            lora_path = get_adapter_absolute_path(lora_request.lora_path)
            lora = self._lora_model_cls.from_local_checkpoint(
                lora_path,
                expected_lora_modules,
                expected_modules_to_save,
                max_position_embeddings=self.max_position_embeddings,
                lora_model_id=lora_request.lora_int_id,
                device="cpu",
                enable_lora_modules_to_save=self._adapter_manager.lora_config.
                enable_lora_modules_to_save,
                dtype=self.lora_config.lora_dtype,
                target_embedding_padding=self.vocab_size +
                self.lora_config.lora_extra_vocab_size,
                embedding_modules=self.embedding_modules,
                embedding_padding_modules=self.embedding_padding_modules,
            )
        except Exception as e:
            # ``lora_path`` may be unbound if the failure happened before it
            # was resolved, so report the path from the request instead.
            raise RuntimeError(
                f"Loading lora {lora_request.lora_path} failed") from e
        if lora.rank > self.lora_config.max_lora_rank:
            raise ValueError(
                f"LoRA rank {lora.rank} is greater than max_lora_rank "
                f"{self.lora_config.max_lora_rank}. Please launch the "
                "engine with a higher max_lora_rank (--max-lora-rank).")
        if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
            raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} "
                             f"is greater than lora_extra_vocab_size "
                             f"{self.lora_config.lora_extra_vocab_size}.")
        return lora

    def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
        if lora_request.lora_int_id in self.list_adapters():
            return False
        if isinstance(self._cached_dummy_lora, LoRAModel):
            dummy_lora = self._cached_dummy_lora.clone(
                lora_request.lora_int_id)
        else:
            dummy_lora = self._adapter_manager.create_dummy_lora(
                lora_request.lora_int_id, rank, 1, self.embedding_modules)
            if self._cached_dummy_lora is None:
                # Inside dummy_lora_cache(): remember the first dummy LoRA so
                # subsequent calls can clone it instead of rebuilding it.
                self._cached_dummy_lora = dummy_lora
        return self._adapter_manager.add_adapter(dummy_lora)

    def pin_adapter(self, adapter_id: int) -> bool:
        return self._adapter_manager.pin_adapter(adapter_id)

    def set_active_adapters(self, requests: Set[Any],
                            mapping: Optional[Any]) -> None:
        set_active_adapters_worker(requests, mapping, self._apply_adapters,
                                   self._adapter_manager.set_adapter_mapping)

    def _apply_adapters(self, adapter_requests: Set[Any]) -> None:
        apply_adapters_worker(adapter_requests, self.list_adapters,
                              self._adapter_manager.adapter_slots,
                              self.remove_adapter, self.add_adapter)

    def add_adapter(self, adapter_request: Any) -> bool:
        return add_adapter_worker(adapter_request, self.list_adapters,
                                  self._load_adapter,
                                  self._adapter_manager.add_adapter,
                                  self._adapter_manager.activate_adapter)

    def remove_adapter(self, adapter_id: int) -> bool:
        return self._adapter_manager.remove_adapter(adapter_id)

    def remove_all_adapters(self):
        self._adapter_manager.remove_all_adapters()

    def list_adapters(self) -> Set[int]:
        return list_adapters_worker(self._adapter_manager.list_adapters)
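

# A minimal sketch of the call sequence a worker is assumed to follow (the
# config values, ``device``, ``model``, the per-step ``lora_requests`` and
# ``mapping`` all come from the surrounding worker code and are not defined
# here):
#
#   manager = WorkerLoRAManager(max_num_seqs, max_num_batched_tokens,
#                               vocab_size, lora_config, device,
#                               embedding_modules, embedding_padding_modules)
#   model = manager.create_lora_manager(model)           # wraps the base model
#   manager.set_active_adapters(lora_requests, mapping)  # once per step
#   manager.remove_all_adapters()                        # e.g. on shutdown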


class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Uses an LRU cache. On every request, the requested LoRAs will be loaded
    (unless they are already loaded) and the least recently used LoRAs will
    be unloaded if the cache is above capacity. A usage sketch follows at
    the end of this module."""

    _manager_cls: Type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager

    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        lora_manager = create_lora_manager(
            model,
            lora_manager_cls=self._manager_cls,
            max_num_seqs=self.max_num_seqs,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            max_num_batched_tokens=self.max_num_batched_tokens,
        )
        self._adapter_manager = lora_manager
        return lora_manager.model

    def _apply_adapters(self, lora_requests: Set[LoRARequest]) -> None:
        loras_map = {
            lora_request.lora_int_id: lora_request
            for lora_request in lora_requests if lora_request
        }
        if len(loras_map) > self._adapter_manager.lora_slots:
            raise RuntimeError(
                f"Number of requested LoRAs ({len(loras_map)}) is greater "
                "than the number of GPU LoRA slots "
                f"({self._adapter_manager.lora_slots}).")
        for lora in loras_map.values():
            self.add_adapter(lora)

    def add_adapter(self, lora_request: LoRARequest) -> bool:
        if lora_request.lora_int_id not in self.list_adapters():
            # Remove before we load the new lora to save memory
            if len(self._adapter_manager) + 1 > self._adapter_manager.capacity:
                assert isinstance(self._adapter_manager,
                                  LRUCacheLoRAModelManager)
                self._adapter_manager.remove_oldest_adapter()
            lora = self._load_adapter(lora_request)
            loaded = self._adapter_manager.add_adapter(lora)
        else:
            # If the lora is already loaded, just touch it to
            # update its position in the caches
            loaded = self._adapter_manager.get_adapter(
                lora_request.lora_int_id) is not None
        self._adapter_manager.activate_adapter(lora_request.lora_int_id)
        return loaded
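

# Usage sketch for LRUCacheWorkerLoRAManager (illustrative only; ``model``,
# the LoRAConfig and the per-step requests/mapping are assumed to be provided
# by the surrounding worker code):
#
#   manager = LRUCacheWorkerLoRAManager(max_num_seqs, max_num_batched_tokens,
#                                       vocab_size, lora_config, device,
#                                       embedding_modules,
#                                       embedding_padding_modules)
#   model = manager.create_lora_manager(model)
#   # Per step: at most ``lora_slots`` distinct LoRAs may be requested.
#   # Already-loaded adapters are touched to refresh their LRU position;
#   # the oldest adapter is evicted when the cache would exceed capacity.
#   manager.set_active_adapters(step_lora_requests, mapping)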