# worker_manager.py
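"""Worker-side LoRA management for Aphrodite.

Defines `AbstractWorkerLoRAManager` (the worker-side interface) and two
implementations: `WorkerLoRAManager`, which keeps only the LoRAs requested by
the current batch loaded, and `LRUCacheWorkerLoRAManager`, which keeps
recently used LoRAs cached and evicts the least recently used ones when the
cache exceeds capacity.
"""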


import logging
from abc import ABC, abstractmethod, abstractproperty
from typing import Any, Dict, List, Optional, Set, Type

import torch

from aphrodite.lora.models import (LoRAModel, LoRAModelManager,
                                   LRUCacheLoRAModelManager,
                                   create_lora_manager)
from aphrodite.lora.request import LoRARequest
from aphrodite.lora.layers import LoRAMapping
from aphrodite.common.config import LoRAConfig

logger = logging.getLogger(__name__)


class AbstractWorkerLoRAManager(ABC):
    """Abstract class for managing LoRA models on the worker side."""

    def __init__(self, max_num_seqs: int, max_num_batched_tokens: int,
                 vocab_size: int, lora_config: LoRAConfig,
                 device: torch.device):
        self.max_num_seqs = max_num_seqs
        self.max_num_batched_tokens = max_num_batched_tokens
        self.vocab_size = vocab_size
        self.device = device
        self.lora_config = lora_config

    @abstractproperty
    def is_enabled(self) -> bool:
        ...

    @abstractmethod
    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        ...

    @abstractmethod
    def set_active_loras(self, lora_requests: List[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        ...

    @abstractmethod
    def add_lora(self, lora_request: LoRARequest) -> bool:
        ...

    @abstractmethod
    def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
        ...

    @abstractmethod
    def remove_lora(self, lora_id: int) -> bool:
        ...

    @abstractmethod
    def remove_all_loras(self) -> bool:
        ...

    @abstractmethod
    def list_loras(self) -> Set[int]:
        ...


# pylint: disable=function-redefined
class WorkerLoRAManager(AbstractWorkerLoRAManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    On every request, the requested LoRAs will be loaded (unless they are
    already loaded), and every other LoRA will be unloaded."""

    _lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager

    def __init__(
        self,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        embedding_modules: Dict[str, str],
        embedding_padding_modules: List[str],
        lora_model_cls: Type[LoRAModel] = LoRAModel,
    ):
        self._lora_manager: Optional[LoRAModelManager] = None
        self._lora_model_cls = lora_model_cls
        self.embedding_modules = embedding_modules
        self.embedding_padding_modules = embedding_padding_modules
        super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size,
                         lora_config, device)

    @property
    # pylint: disable=invalid-overridden-method
    def is_enabled(self) -> bool:
        return True

    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        # Build the model-side LoRA manager and keep a handle to it; the
        # (possibly wrapped) model is returned for the worker to use.
        lora_manager = create_lora_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            lora_manager_cls=self._lora_manager_cls,
        )
        self._lora_manager: LoRAModelManager = lora_manager
        return lora_manager.model

    def set_active_loras(self, lora_requests: List[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        self._apply_loras(lora_requests)
        self._lora_manager.set_lora_mapping(lora_mapping)

    def _apply_loras(self, lora_requests: List[LoRARequest]) -> None:
        loras_that_exist = self.list_loras()
        loras_map = {
            lora_request.lora_int_id: lora_request
            for lora_request in lora_requests if lora_request
        }
        if len(loras_map) > self._lora_manager.lora_slots:
            raise RuntimeError(
                f"Number of requested LoRAs ({len(loras_map)}) is greater "
                "than the number of GPU LoRA slots "
                f"({self._lora_manager.lora_slots}).")

        # Diff the currently loaded set against the requested set so that,
        # after this call, exactly the requested LoRAs are resident.
        new_loras = set(loras_map)
        loras_to_add = new_loras - loras_that_exist
        loras_to_remove = loras_that_exist - new_loras

        for lora_id in loras_to_remove:
            self.remove_lora(lora_id)

        for lora_id in loras_to_add:
            self.add_lora(loras_map[lora_id])

    def _load_lora(self, lora_request: LoRARequest) -> LoRAModel:
        try:
            # Load the adapter weights on the CPU; they are validated against
            # the LoRA config below before being handed to the manager.
            lora = self._lora_model_cls.from_local_checkpoint(
                lora_request.lora_local_path,
                lora_model_id=lora_request.lora_int_id,
                device="cpu",
                dtype=self.lora_config.lora_dtype,
                target_embedding_padding=self.vocab_size +
                self.lora_config.lora_extra_vocab_size,
                embedding_modules=self.embedding_modules,
                embedding_padding_modules=self.embedding_padding_modules,
            )
        except Exception as e:
            raise RuntimeError(
                f"Loading LoRA {lora_request.lora_local_path} failed") from e
        if lora.rank > self.lora_config.max_lora_rank:
            raise ValueError(
                f"LoRA rank {lora.rank} is greater than max_lora_rank "
                f"{self.lora_config.max_lora_rank}.")
        if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
            raise ValueError(
                f"LoRA added vocab size {lora.extra_vocab_size} is greater "
                "than lora_extra_vocab_size "
                f"{self.lora_config.lora_extra_vocab_size}.")
        return lora

    def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
        # Registers a placeholder LoRA of the given rank (used e.g. when
        # profiling memory usage) without loading anything from disk.
        if lora_request.lora_int_id in self.list_loras():
            return False
        return self._lora_manager.add_lora(
            self._lora_manager.create_dummy_lora(lora_request.lora_int_id,
                                                 rank, self.embedding_modules))

    def add_lora(self, lora_request: LoRARequest) -> bool:
        if lora_request.lora_int_id in self.list_loras():
            return False
        lora = self._load_lora(lora_request)
        loaded = self._lora_manager.add_lora(lora)
        self._lora_manager.activate_lora(lora.id)
        return loaded

    def remove_lora(self, lora_id: int) -> bool:
        return self._lora_manager.remove_lora(lora_id)

    def remove_all_loras(self) -> bool:
        self._lora_manager.remove_all_loras()

    def list_loras(self) -> Set[int]:
        return set(self._lora_manager.list_loras())


class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Uses an LRU cache. On every request, the requested LoRAs will be loaded
    (unless they are already loaded) and the least recently used LoRAs will
    be unloaded if the cache is above capacity."""

    _lora_manager_cls: Type[
        LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager

    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        lora_manager = create_lora_manager(
            model,
            lora_manager_cls=self._lora_manager_cls,
            max_num_seqs=self.max_num_seqs,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            max_num_batched_tokens=self.max_num_batched_tokens,
        )
        self._lora_manager: LRUCacheLoRAModelManager = lora_manager
        return lora_manager.model

    def _apply_loras(self, lora_requests: List[LoRARequest]) -> None:
        loras_map = {
            lora_request.lora_int_id: lora_request
            for lora_request in lora_requests if lora_request
        }
        if len(loras_map) > self._lora_manager.lora_slots:
            raise RuntimeError(
                f"Number of requested LoRAs ({len(loras_map)}) is greater "
                "than the number of GPU LoRA slots "
                f"({self._lora_manager.lora_slots}).")
        # Unlike the base class, nothing is removed eagerly here; the LRU
        # cache evicts the least recently used LoRAs once it is over capacity.
        for lora in loras_map.values():
            self.add_lora(lora)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        if lora_request.lora_int_id not in self.list_loras():
            # Remove the oldest LoRA before loading the new one to save memory.
            if len(self._lora_manager) + 1 > self._lora_manager.capacity:
                self._lora_manager.remove_oldest_lora()
            lora = self._load_lora(lora_request)
            loaded = self._lora_manager.add_lora(lora)
        else:
            # If the LoRA is already loaded, just touch it to update its
            # position in the caches.
            loaded = self._lora_manager.get_lora(lora_request.lora_int_id)
        self._lora_manager.activate_lora(lora_request.lora_int_id)
        return loaded
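

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal illustration of
# how a worker might wire up an LRUCacheWorkerLoRAManager, assuming a loaded
# `model`, a populated `LoRAConfig`, and the usual LoRARequest / LoRAMapping
# constructors (name, int id, local path / index and prompt mappings). The
# values, field order, and adapter path below are illustrative assumptions,
# not a guaranteed API.
def _example_worker_setup(model: torch.nn.Module,
                          lora_config: LoRAConfig) -> torch.nn.Module:
    manager = LRUCacheWorkerLoRAManager(
        max_num_seqs=256,                # illustrative scheduler limits
        max_num_batched_tokens=8192,
        vocab_size=32000,                # base model vocabulary size
        lora_config=lora_config,
        device=torch.device("cuda"),
        embedding_modules={},            # normally taken from the model class
        embedding_padding_modules=[],
    )
    # Wrap the model with LoRA-capable layers and keep the returned handle.
    model = manager.create_lora_manager(model)
    # Before each batch, activate exactly the LoRAs that batch needs.
    request = LoRARequest("my-adapter", 1, "/path/to/adapter")  # hypothetical
    manager.set_active_loras([request], LoRAMapping((1,), (1,)))
    return model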