import copy
import json
import math
import os
import re
from typing import Callable, Dict, Hashable, List, Optional, Tuple, Type

import safetensors.torch
import torch
from loguru import logger
from torch import nn

from aphrodite.common.config import LoRAConfig
from aphrodite.common.utils import LRUCache, is_pin_memory_available
from aphrodite.lora.layers import (BaseLayerWithLoRA, LoRAMapping, from_layer,
                                   from_layer_logits_processor)
from aphrodite.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from aphrodite.lora.utils import parse_fine_tuned_lora_name, replace_submodule

_GLOBAL_LORA_ID = 0


def convert_mapping(
    mapping: LoRAMapping, lora_index_to_id: List[Optional[int]],
    max_loras: int, vocab_size: int, extra_vocab_size: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
    """Converts LoRAMapping to index tensors.

    Args:
        mapping: LoRAMapping mapping rows in a batch to LoRA ids.
        lora_index_to_id: List mapping LoRA ids to LoRA indices.
        max_loras: Maximum number of LoRAs.
        vocab_size: Model vocab size.
        extra_vocab_size: Extra vocab size each LoRA can have.

    Returns:
        A tuple of tensors:
            base_indices: Tensor of shape [batch_size] mapping batch rows to
                LoRA indices.
            sampler_indices: Tensor of shape [batch_size] mapping requests to
                LoRA indices for the sampler. For generation, this will be the
                same as base_indices. For prefill, this will map requests
                to LoRA indices.
            sampler_indices_padded: Tensor of shape [batch_size] mapping
                requests to LoRA indices for the sampler with padding.
                Same as sampler_indices, but -1 is replaced with
                max_loras.
            embeddings_indices: Tensor of shape [2, batch_size] mapping
                requests to embedding indices. First row is for embeddings
                added by the LoRAs, second row is for the LoRA.lora_a
                embeddings.
            indices_len: List of lengths of the above tensors.
    """
    indices = list(mapping.index_mapping).copy()
    embedding_indices = indices.copy()
    lora_indices = indices.copy()
    prompt_mapping = [
        lora_index_to_id.index(x) if x > 0 else -1
        for x in mapping.prompt_mapping
    ]
    lora_idx = None
    for i in range(len(indices)):
        # TODO index can be slow. optimize
        lora_idx = (lora_index_to_id.index(indices[i])
                    if indices[i] > 0 else -1)
        embedding_indices[i] = lora_idx if indices[i] > 0 else 0
        indices[i] = i
        lora_indices[i] = lora_idx
    indices = torch.tensor([indices, lora_indices, embedding_indices],
                           dtype=torch.long,
                           device="cuda")
    prompt_mapping = torch.tensor(prompt_mapping,
                                  device="cuda",
                                  dtype=torch.long)
    embeddings_indices = torch.stack([
        indices[2] * extra_vocab_size,
        indices[2] * (vocab_size + extra_vocab_size)
    ])
    embeddings_indices[embeddings_indices == -1] = max_loras - 1
    base_indices = indices[1]
    sampler_indices = prompt_mapping
    sampler_indices_padded = sampler_indices.clone()
    sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
    sampler_indices_padded = (
        torch.arange(
            0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
        (sampler_indices_padded * len(sampler_indices_padded)))
    indices_len = (base_indices.shape[-1], sampler_indices.shape[-1],
                   sampler_indices_padded.shape[-1],
                   embeddings_indices.shape[-1])
    return (base_indices, sampler_indices, sampler_indices_padded,
            embeddings_indices, indices_len)
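
# Illustrative sketch of what `convert_mapping` produces (the values below are
# assumed inputs, not taken from this module): with
# lora_index_to_id == [None, 2, 3] and mapping.index_mapping == (2, 2, 3, 0),
# each row is looked up by LoRA id, so base_indices becomes [1, 1, 2, -1]
# (id 2 lives in slot 1, id 3 in slot 2, and 0 means "no LoRA").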


def get_lora_id():
    global _GLOBAL_LORA_ID
    _GLOBAL_LORA_ID += 1
    return _GLOBAL_LORA_ID


class LoRAModel:
    """A LoRA fine-tuned model."""

    def __init__(
        self,
        lora_model_id: int,
        rank: int,
        loras: Dict[str, LoRALayerWeights],
    ) -> None:
        self.id = lora_model_id
        assert (lora_model_id >
                0), f"a valid lora id should be greater than 0, got {self.id}"
        self.rank = rank
        self.loras: Dict[str, LoRALayerWeights] = loras

    @property
    def extra_vocab_size(self) -> int:
        return max(lora.extra_vocab_size
                   for lora in self.loras.values()) if self.loras else 0

    def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
        """Get LoRA for a given module by name"""
        return self.loras.get(module_name, None)

    # (yard1): TODO see if we can derive target_embedding_padding automatically
    @classmethod
    def from_lora_tensors(
        cls,
        lora_model_id: int,
        rank: int,
        lora_alpha: int,
        tensors: Dict[str, torch.Tensor],
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        embeddings: Optional[Dict[str, torch.Tensor]] = None,
        target_embedding_padding: Optional[int] = None,
        embedding_modules: Optional[Dict[str, str]] = None,
        embedding_padding_modules: Optional[List[str]] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors."""
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        loras: Dict[str, LoRALayerWeights] = {}
        for tensor_name, tensor in tensors.items():
            module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name)
            if module_name not in loras:
                lora_embeddings_tensor = None
                if embeddings:
                    embeddings_module = next(
                        (k for k in embedding_modules if k in module_name),
                        None)
                    if embeddings_module:
                        lora_embeddings_tensor = embeddings[
                            embedding_modules[embeddings_module]].to(
                                device=device, dtype=dtype)
                        if pin_memory:
                            lora_embeddings_tensor = (
                                lora_embeddings_tensor.pin_memory())
                loras[module_name] = LoRALayerWeights(module_name, rank,
                                                      lora_alpha, None, None,
                                                      lora_embeddings_tensor)
            if is_lora_a:
                loras[module_name].lora_a = tensor.to(device=device,
                                                      dtype=dtype).t()
                if pin_memory:
                    loras[module_name].lora_a = loras[
                        module_name].lora_a.pin_memory()
            else:
                loras[module_name].lora_b = tensor.to(device=device,
                                                      dtype=dtype).t()
                if any(name in module_name
                       for name in embedding_padding_modules
                       ) and target_embedding_padding is not None:
                    lora_b = loras[module_name].lora_b
                    assert target_embedding_padding >= lora_b.shape[1]
                    addition = target_embedding_padding - lora_b.shape[1]
                    loras[module_name].lora_b = torch.nn.functional.pad(
                        lora_b, (0, addition))
                if pin_memory:
                    loras[module_name].lora_b = loras[
                        module_name].lora_b.pin_memory()
        # Filter out LoRALayerWeights instances where lora_a or lora_b is None
        loras = {
            k: v
            for k, v in loras.items()
            if v.lora_a is not None and v.lora_b is not None
        }
        for lora in loras.values():
            lora.optimize()
        return cls(lora_model_id, rank, loras)

    @classmethod
    def from_local_checkpoint(
        cls,
        lora_dir: str,
        expected_lora_modules: List[str],
        lora_model_id: Optional[int] = None,
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        target_embedding_padding: Optional[int] = None,
        embedding_modules: Optional[Dict[str, str]] = None,
        embedding_padding_modules: Optional[List[str]] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint."""
        lora_config_path = os.path.join(lora_dir, "adapter_config.json")
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
        new_embeddings_tensor_path = os.path.join(
            lora_dir, "new_embeddings.safetensors")
        new_embeddings_bin_file_path = os.path.join(lora_dir,
                                                    "new_embeddings.bin")
        with open(lora_config_path) as f:
            config = json.load(f)
        target_modules = config["target_modules"]
        unexpected_modules = []
        for module in target_modules:
            # Compatible with more modules, such as: layers.11.self_attn.k_proj
            part_name = module.split(".")[-1]
            if part_name not in expected_lora_modules:
                unexpected_modules.append(module)
        # The loaded LoRA's target modules must be a subset of
        # expected_lora_modules.
        if unexpected_modules:
            raise ValueError(
                f"While loading {lora_dir}, expected"
                f" target modules in {expected_lora_modules}"
                f" but received {unexpected_modules}."
                f" Please verify that the loaded LoRA module is correct")
        if os.path.isfile(lora_tensor_path):
            tensors = safetensors.torch.load_file(lora_tensor_path)
        elif os.path.isfile(lora_bin_file_path):
            tensors = torch.load(lora_bin_file_path)
        else:
            raise ValueError(f"{lora_dir} doesn't contain tensors")
        embeddings = None
        if os.path.isfile(new_embeddings_tensor_path):
            embeddings = safetensors.torch.load_file(
                new_embeddings_tensor_path)
        elif os.path.isfile(new_embeddings_bin_file_path):
            embeddings = torch.load(new_embeddings_bin_file_path)
        rank = config["r"]
        lora_alpha = config["lora_alpha"]
        return cls.from_lora_tensors(
            lora_model_id=get_lora_id()
            if lora_model_id is None else lora_model_id,
            rank=rank,
            lora_alpha=lora_alpha,
            tensors=tensors,
            device=device,
            dtype=dtype,
            embeddings=embeddings,
            target_embedding_padding=target_embedding_padding,
            embedding_modules=embedding_modules,
            embedding_padding_modules=embedding_padding_modules,
        )
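
# Illustrative usage sketch for loading a PEFT-style adapter directory (the
# path and module names below are assumptions, not part of this module):
#
#     lora = LoRAModel.from_local_checkpoint(
#         "/path/to/adapter",
#         expected_lora_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
#         device="cpu",
#         dtype=torch.float16,
#         embedding_modules={},          # no new-embedding handling here
#         embedding_padding_modules=[],
#     )
#     print(lora.id, lora.rank, sorted(lora.loras))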


class LoRAModelManager:
    """A manager that manages multiple LoRA-fine-tuned models."""

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences the model can run
                in a single batch.
            max_num_batched_tokens: the maximum number of tokens the model
                can run in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
        """
        self.lora_config = lora_config
        self.max_num_seqs = max_num_seqs
        assert self.capacity >= self.lora_slots
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.base_indices = torch.empty(self.max_num_batched_tokens,
                                        dtype=torch.long,
                                        device="cuda")
        self.sampler_indices = torch.empty(self.max_num_batched_tokens,
                                           dtype=torch.long,
                                           device="cuda")
        self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens,
                                                  dtype=torch.long,
                                                  device="cuda")
        self.embeddings_indices = torch.empty(2,
                                              self.max_num_batched_tokens,
                                              dtype=torch.long,
                                              device="cuda")
        self.offsets = []
        # 4 is the number of indices tensors defined above:
        # base_indices, sampler_indices, sampler_indices_padded,
        # embeddings_indices
        self.indices_len = [None] * 4
        self.model: nn.Module = model
        if hasattr(self.model, "supported_lora_modules"):
            self.supported_lora_modules = copy.deepcopy(
                self.model.supported_lora_modules)
            self.packed_modules_mapping = copy.deepcopy(
                self.model.packed_modules_mapping)
        self.packed_modules: Dict[str, List[str]] = {}
        self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
        self._registered_loras: Dict[int, LoRAModel] = {}
        # Dict instead of a Set for compatibility with LRUCache.
        self._active_loras: Dict[int, None] = {}
        self._last_mapping = None
        self._create_lora_modules()
        self.model.lora_manager = self

    @property
    def capacity(self) -> int:
        return self.lora_config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        return self.lora_config.max_loras

    def __len__(self) -> int:
        return len(self._registered_loras)

    def activate_lora(
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass."""
        if lora_id in self._active_loras:
            return False
        first_free_slot = next(
            ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id)
             if lora_id is None), None)
        if first_free_slot is None:
            raise ValueError("No free lora slots")
        index, _ = first_free_slot
        self._active_loras[lora_id] = None
        lora_model = self._registered_loras[lora_id]
        logger.debug(
            f"Activating LoRA. int id: {lora_model.id}, slot index: {index}")
        self.lora_index_to_id[index] = lora_model.id
        for module_name, module in self.modules.items():
            module_lora = lora_model.get_lora(module_name)
            if module_lora:
                module_lora.optimize()
                module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
                                module_lora.embeddings_tensor)
            else:
                module.reset_lora(index)
        return True

    def _deactivate_lora(self, lora_id: int):
        try:
            index = self.lora_index_to_id.index(lora_id)
            self.lora_index_to_id[index] = None
        except ValueError:
            pass

    def deactivate_lora(self, lora_id: int) -> bool:
        """Remove a LoRA from a GPU buffer."""
        if lora_id in self._active_loras:
            self._deactivate_lora(lora_id)
            self._active_loras.pop(lora_id)
            return True
        return False

    def _add_lora(self, lora: LoRAModel) -> bool:
        self._create_merged_loras_inplace(lora)
        self._registered_loras[lora.id] = lora

    def add_lora(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager CPU cache."""
        if lora.id not in self._registered_loras:
            if len(self._registered_loras) >= self.capacity:
                raise RuntimeError("No free LoRA slots.")
            self._add_lora(lora)
            return True
        return False

    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRAModel from the manager CPU cache."""
        # TODO: should we check active lora?
        self.deactivate_lora(lora_id)
        return bool(self._registered_loras.pop(lora_id, None))

    # TODO see if this can be vectorized
    def _set_lora_mapping(self, mapping: LoRAMapping) -> None:
        (base_indices, sampler_indices, sampler_indices_padded,
         embeddings_indices,
         indices_len) = convert_mapping(mapping, self.lora_index_to_id,
                                        self.lora_slots + 1, self.vocab_size,
                                        self.lora_config.lora_extra_vocab_size)
        self.base_indices[:base_indices.shape[0]].copy_(base_indices)
        self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
        self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
            sampler_indices_padded)
        self.embeddings_indices[:embeddings_indices.shape[0],
                                :embeddings_indices.shape[1]].copy_(
                                    embeddings_indices)
        # Maintain the reference
        self.indices_len[:] = indices_len

    def set_lora_mapping(self, lora_mapping: LoRAMapping) -> None:
        if self._last_mapping != lora_mapping:
            self._set_lora_mapping(lora_mapping)
        self._last_mapping = lora_mapping

    def list_loras(self) -> Dict[int, LoRAModel]:
        """List all registered LoRAModels."""
        return dict(self._registered_loras)

    def get_lora(self, lora_id: int) -> Optional[LoRAModel]:
        return self._registered_loras.get(lora_id, None)

    def remove_all_loras(self) -> bool:
        """Remove all LoRAModels from the manager."""
        self._registered_loras.clear()
        self.lora_index_to_id = [None] * self.lora_slots
        self._active_loras.clear()

    def _create_lora_modules(self):
        for module_name, module in self.model.named_modules():
            if not self._match_target_modules(module_name):
                continue
            parts = module_name.split(".")[-1]
            packed_modules_lst = self.packed_modules_mapping.get(parts, [])
            new_module = replace_submodule(
                self.model, module_name,
                from_layer(module, self.lora_slots, self.lora_config,
                           packed_modules_lst, self.model.config))
            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                logits_processor_module = self.model.get_submodule(
                    "logits_processor")
                new_module = replace_submodule(
                    self.model, "logits_processor",
                    from_layer_logits_processor(logits_processor_module,
                                                module, self.lora_slots,
                                                self.lora_config,
                                                self.model.config))
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            new_module.set_mapping(self.base_indices, self.sampler_indices,
                                   self.sampler_indices_padded,
                                   self.embeddings_indices, self.indices_len)

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        assert isinstance(module, BaseLayerWithLoRA)
        self.modules[module_name] = module

    def create_dummy_lora(
            self,
            lora_id: int,
            rank: int,
            embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup."""
        model = LoRAModel(lora_id, rank, {})
        for module_name, module in self.model.named_modules():
            if not self._match_target_modules(module_name) or not isinstance(
                    module, BaseLayerWithLoRA):
                continue
            parts = module_name.split(".")
            if module_name not in self.packed_modules:
                if parts[-1] in embedding_modules:
                    input_dim = (module.base_layer.org_vocab_size +
                                 self.lora_config.lora_extra_vocab_size if
                                 hasattr(module.base_layer, "org_vocab_size")
                                 else module.base_layer.weight.shape[1])
                    output_dim = module.base_layer.embedding_dim if hasattr(
                        module.base_layer,
                        "embedding_dim") else module.base_layer.weight.shape[0]
                    embeddings_tensor_dim = (module.base_layer.embedding_dim if
                                             hasattr(module.base_layer,
                                                     "embedding_dim") else
                                             module.base_layer.weight.shape[1])
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        input_dim,
                        output_dim,
                        rank,
                        module.lora_a_stacked.dtype,
                        "cpu",
                        embeddings_tensor_dim=embeddings_tensor_dim)
                else:
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.lora_a_stacked.shape[-1],
                        module.lora_b_stacked.shape[-2],
                        rank,
                        module.lora_a_stacked.dtype,
                        "cpu",
                    )
                lora.optimize()
            else:
                parts = module_name.split(".")
                replacements = self.packed_modules_mapping[parts[-1]]
                subloras = []
                for i, r in enumerate(replacements):
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
                    )
                    lora.optimize()
                    subloras.append(lora)
                lora = PackedLoRALayerWeights.pack(subloras)
            model.loras[module_name] = lora
        return model

    def _match_target_modules(self, module_name: str):
        return any(
            re.match(
                r".*\.{target_module}$".format(target_module=target_module),
                module_name) or target_module == module_name
            for target_module in self.supported_lora_modules)

    def _register_packed_modules(self, module_full_name: str) -> None:
        parts = module_full_name.split(".")
        module_name = parts[-1]
        replacements = self.packed_modules_mapping.get(module_name, [])
        # When replacements has at most one entry, this module is not a
        # packed module.
        if len(replacements) <= 1:
            return
        prefix = ".".join(parts[:-1])
        self.packed_modules[module_full_name] = [
            prefix + "." + r if prefix else r for r in replacements
        ]
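
    # Illustrative packed-module sketch (the mapping and module names are
    # assumptions, not part of this module): with
    # packed_modules_mapping == {"qkv_proj": ["q_proj", "k_proj", "v_proj"]},
    # registering "model.layers.0.self_attn.qkv_proj" records the three fully
    # qualified per-projection names under that key in self.packed_modules,
    # so _create_merged_loras_inplace below can pack their LoRAs into one
    # PackedLoRALayerWeights.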

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        for module_name, new_module_names in self.packed_modules.items():
            replacement_loras = []
            has_replacement = False
            for r in new_module_names:
                lora = lora_model.get_lora(r)
                replacement_loras.append(lora)
                if lora:
                    has_replacement = True
            if not has_replacement:
                continue
            for i in range(len(replacement_loras)):
                if replacement_loras[i]:
                    continue
                replacement_loras[i] = None
            lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                replacement_loras)
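
# Illustrative lifecycle sketch (the `manager`, `lora`, `num_tokens`, and
# `num_seqs` values are assumptions, not part of this module): an adapter is
# first registered in the CPU cache, then activated into a GPU slot, and the
# per-batch mapping is then pushed to the adapted layers.
#
#     manager.add_lora(lora)              # register in the CPU cache
#     manager.activate_lora(lora.id)      # copy weights into a free GPU slot
#     manager.set_lora_mapping(
#         LoRAMapping(index_mapping=(lora.id,) * num_tokens,
#                     prompt_mapping=(lora.id,) * num_seqs))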


class LoRALRUCache(LRUCache[LoRAModel]):

    def __init__(self, capacity: int,
                 deactivate_lora_fn: Callable[[Hashable], None]):
        super().__init__(capacity)
        self.deactivate_lora_fn = deactivate_lora_fn

    def _on_remove(self, key: Hashable, value: LoRAModel):
        logger.debug(f"Removing LoRA. int id: {key}")
        self.deactivate_lora_fn(key)
        return super()._on_remove(key, value)


class LRUCacheLoRAModelManager(LoRAModelManager):
    """A model manager that manages multiple LoRAs with LRU cache."""

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
    ):
        super().__init__(model, max_num_seqs, max_num_batched_tokens,
                         vocab_size, lora_config)
        self._registered_loras: LoRALRUCache = LoRALRUCache(
            self.capacity, self.deactivate_lora)
        self._active_loras: LoRALRUCache = LoRALRUCache(
            self.lora_slots, self._deactivate_lora)

    def list_loras(self) -> Dict[int, LoRAModel]:
        """List all registered LoRAModels."""
        return dict(self._registered_loras.cache)

    def add_lora(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager."""
        if lora.id not in self._registered_loras:
            self._add_lora(lora)
            was_added = True
        else:
            # We always touch to update the LRU cache order
            self._registered_loras.touch(lora.id)
            was_added = False
        return was_added

    def activate_lora(
        self,
        lora_id: int,
    ) -> bool:
        if lora_id not in self._active_loras and len(
                self._active_loras) >= self.lora_slots:
            self._active_loras.remove_oldest()
        result = super().activate_lora(lora_id)
        # We always touch to update the LRU cache order
        self._active_loras.touch(lora_id)
        return result

    def remove_oldest_lora(self) -> bool:
        if len(self._registered_loras) > 0:
            self._registered_loras.remove_oldest()
            return True
        return False
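
# Illustrative eviction sketch (the capacity values and adapter objects are
# assumptions, not part of this module): with max_loras=1 (one GPU slot) and
# max_cpu_loras=2, activating a second adapter reclaims the first one's GPU
# slot, and registering a third would also evict the least recently used
# adapter from the CPU cache.
#
#     lru_manager.add_lora(lora_a); lru_manager.activate_lora(lora_a.id)
#     lru_manager.add_lora(lora_b); lru_manager.activate_lora(lora_b.id)
#     # lora_a is no longer active; its slot now holds lora_b's weights.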


def create_lora_manager(
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
        **kwargs) -> LoRAModelManager:
    """Create a LoRA adapter manager for a given model."""
    if not hasattr(model, "supported_lora_modules"):
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    lora_manager = lora_manager_cls(
        model=model,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        vocab_size=vocab_size,
        lora_config=lora_config,
        **kwargs)
    return lora_manager
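
# Illustrative construction sketch (the `model` and `lora_config` objects and
# the batch sizes below are assumptions, not part of this module):
#
#     manager = create_lora_manager(
#         model,                        # must define supported_lora_modules
#         max_num_seqs=16,
#         max_num_batched_tokens=4096,
#         vocab_size=32000,
#         lora_config=lora_config,      # an existing LoRAConfig instance
#         lora_manager_cls=LRUCacheLoRAModelManager,
#     )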