worker_manager.py

from contextlib import contextmanager
from typing import Any, Dict, List, Literal, Optional, Set, Type, Union

import torch

from aphrodite.adapter_commons.utils import (add_adapter_worker,
                                             apply_adapters_worker,
                                             list_adapters_worker,
                                             set_active_adapters_worker)
from aphrodite.adapter_commons.worker_manager import AbstractWorkerManager
from aphrodite.common.config import LoRAConfig
from aphrodite.lora.models import (LoRAModel, LoRAModelManager,
                                   LRUCacheLoRAModelManager,
                                   create_lora_manager)
from aphrodite.lora.request import LoRARequest


class WorkerLoRAManager(AbstractWorkerManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    On every request, the requested LoRAs will be loaded (unless they are
    already loaded), and every other LoRA will be unloaded."""

    _manager_cls: Type[LoRAModelManager] = LoRAModelManager

    def __init__(
        self,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        embedding_modules: Dict[str, str],
        embedding_padding_modules: List[str],
        lora_model_cls: Type[LoRAModel] = LoRAModel,
        max_position_embeddings: Optional[int] = None,
    ):
        self._lora_model_cls = lora_model_cls
        self.embedding_modules = embedding_modules
        self.embedding_padding_modules = embedding_padding_modules
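        # Tri-state dummy-LoRA cache: False means caching is disabled,
        # None means caching is enabled but nothing has been cached yet,
        # and a LoRAModel instance is the cached dummy to clone from.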
        self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
        self.max_num_seqs = max_num_seqs
        self.max_num_batched_tokens = max_num_batched_tokens
        self.vocab_size = vocab_size
        self.lora_config = lora_config
        self.max_position_embeddings = max_position_embeddings
        super().__init__(device)
        # Lazily initialized by create_lora_manager.
        self._adapter_manager: LoRAModelManager

    @contextmanager
    def dummy_lora_cache(self):
        """Use this context manager to reuse the dummy lora model
        to avoid creating it repeatedly."""
        self._cached_dummy_lora = None
        yield
        self._cached_dummy_lora = False
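
    # Illustrative usage sketch (hypothetical names, not part of this module):
    # wrap repeated dummy-LoRA creation, e.g. during warmup, so the dummy
    # model is cloned from a cached copy instead of being rebuilt each time:
    #
    #     with worker_lora_manager.dummy_lora_cache():
    #         for req in warmup_lora_requests:
    #             worker_lora_manager.add_dummy_lora(req, rank=8)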

    @property
    def is_enabled(self) -> bool:
        return True

    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        lora_manager = create_lora_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            lora_manager_cls=self._manager_cls,
        )
        self._adapter_manager = lora_manager
        return lora_manager.model

    def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
        try:
            model = self._adapter_manager.model
            supported_lora_modules = model.supported_lora_modules
            packed_modules_mapping = model.packed_modules_mapping
            expected_lora_modules: List[str] = []
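            # A packed module is backed by several underlying modules; expand
            # it so the checkpoint's weights can be matched module-by-module.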
            for module in supported_lora_modules:
                if module in packed_modules_mapping:
                    expected_lora_modules.extend(
                        packed_modules_mapping[module])
                else:
                    expected_lora_modules.append(module)
            lora = self._lora_model_cls.from_local_checkpoint(
                lora_request.lora_local_path,
                expected_lora_modules,
                max_position_embeddings=self.max_position_embeddings,
                lora_model_id=lora_request.lora_int_id,
                device="cpu",
                dtype=self.lora_config.lora_dtype,
                target_embedding_padding=self.vocab_size +
                self.lora_config.lora_extra_vocab_size,
                embedding_modules=self.embedding_modules,
                embedding_padding_modules=self.embedding_padding_modules,
            )
        except Exception as e:
            raise RuntimeError(
                f"Loading lora {lora_request.lora_local_path} failed") from e
        if lora.rank > self.lora_config.max_lora_rank:
            raise ValueError(
                f"LoRA rank {lora.rank} is greater than max_lora_rank "
                f"{self.lora_config.max_lora_rank}.")
        if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
            raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} "
                             f"is greater than lora_extra_vocab_size "
                             f"{self.lora_config.lora_extra_vocab_size}.")
        return lora

    def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
        if lora_request.lora_int_id in self.list_adapters():
            return False
        if isinstance(self._cached_dummy_lora, LoRAModel):
            dummy_lora = self._cached_dummy_lora.clone(
                lora_request.lora_int_id)
        else:
            dummy_lora = self._adapter_manager.create_dummy_lora(
                lora_request.lora_int_id, rank, 1, self.embedding_modules)
            if self._cached_dummy_lora is None:
                self._cached_dummy_lora = dummy_lora
        return self._adapter_manager.add_adapter(dummy_lora)

    def pin_adapter(self, adapter_id: int) -> bool:
        return self._adapter_manager.pin_adapter(adapter_id)

    def set_active_adapters(self, requests: Set[Any],
                            mapping: Optional[Any]) -> None:
        set_active_adapters_worker(requests, mapping, self._apply_adapters,
                                   self._adapter_manager.set_adapter_mapping)

    def _apply_adapters(self, adapter_requests: Set[Any]) -> None:
        apply_adapters_worker(adapter_requests, self.list_adapters,
                              self._adapter_manager.adapter_slots,
                              self.remove_adapter, self.add_adapter)

    def add_adapter(self, adapter_request: Any) -> bool:
        return add_adapter_worker(adapter_request, self.list_adapters,
                                  self._load_adapter,
                                  self._adapter_manager.add_adapter,
                                  self._adapter_manager.activate_adapter)

    def remove_adapter(self, adapter_id: int) -> bool:
        return self._adapter_manager.remove_adapter(adapter_id)

    def remove_all_adapters(self):
        self._adapter_manager.remove_all_adapters()

    def list_adapters(self) -> Set[int]:
        return list_adapters_worker(self._adapter_manager.list_adapters)


class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Uses an LRU cache. On every request, the requested LoRAs will be loaded
    (unless they are already loaded) and the least recently used LoRAs will
    be unloaded if the cache is above capacity."""

    _manager_cls: Type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager

    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        lora_manager = create_lora_manager(
            model,
            lora_manager_cls=self._manager_cls,
            max_num_seqs=self.max_num_seqs,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            max_num_batched_tokens=self.max_num_batched_tokens,
        )
        self._adapter_manager = lora_manager
        return lora_manager.model

    def _apply_adapters(self, lora_requests: Set[LoRARequest]) -> None:
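        # Map each unique LoRA id to its request; falsy entries are skipped.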
        loras_map = {
            lora_request.lora_int_id: lora_request
            for lora_request in lora_requests if lora_request
        }
        if len(loras_map) > self._adapter_manager.lora_slots:
            raise RuntimeError(
                f"Number of requested LoRAs ({len(loras_map)}) is greater "
                "than the number of GPU LoRA slots "
                f"({self._adapter_manager.lora_slots}).")
        for lora in loras_map.values():
            self.add_adapter(lora)

    def add_adapter(self, lora_request: LoRARequest) -> bool:
        if lora_request.lora_int_id not in self.list_adapters():
            # Remove before we load the new lora to save memory
            if len(self._adapter_manager) + 1 > self._adapter_manager.capacity:
                assert isinstance(self._adapter_manager,
                                  LRUCacheLoRAModelManager)
                self._adapter_manager.remove_oldest_adapter()
            lora = self._load_adapter(lora_request)
            loaded = self._adapter_manager.add_adapter(lora)
        else:
            # If the lora is already loaded, just touch it to
            # update its position in the caches
            loaded = self._adapter_manager.get_adapter(
                lora_request.lora_int_id) is not None
        self._adapter_manager.activate_adapter(lora_request.lora_int_id)
        return loaded
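

# Illustrative usage sketch (hypothetical names and values, not part of this
# module): a worker would typically construct a manager, wrap its model, and
# then activate the adapters requested for each batch:
#
#     manager = LRUCacheWorkerLoRAManager(
#         max_num_seqs=256,
#         max_num_batched_tokens=8192,
#         vocab_size=32000,
#         lora_config=lora_config,
#         device=torch.device("cuda"),
#         embedding_modules=embedding_modules,
#         embedding_padding_modules=embedding_padding_modules,
#     )
#     model = manager.create_lora_manager(worker_model)
#     manager.set_active_adapters(lora_requests, lora_mapping)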