models.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626
  1. import copy
  2. import json
  3. import logging
  4. import math
  5. import os
  6. import re
  7. from typing import (Callable, Dict, Hashable, List, Optional, Tuple, Type)
  8. import safetensors.torch
  9. import torch
  10. from torch import nn
  11. from aphrodite.common.config import LoRAConfig
  12. from aphrodite.common.utils import LRUCache, is_pin_memory_available
  13. from aphrodite.lora.layers import (BaseLayerWithLoRA, LoRAMapping, from_layer,
  14. from_layer_logits_processor)
  15. from aphrodite.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
  16. from aphrodite.lora.utils import parse_fine_tuned_lora_name, replace_submodule
  17. logger = logging.getLogger(__name__)
  18. _GLOBAL_LORA_ID = 0
  19. def convert_mapping(
  20. mapping: LoRAMapping, lora_index_to_id: List[Optional[int]],
  21. max_loras: int, vocab_size: int, extra_vocab_size: int
  22. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
  23. """Converts LoRAMapping to index tensors.
  24. Args:
  25. mapping: LoRAMapping mapping rows in a batch to LoRA ids.
  26. lora_index_to_id: List mapping LoRA ids to LoRA indices.
  27. max_loras: Maximum number of LoRAs.
  28. vocab_size: Model vocab size.
  29. extra_vocab_size: Extra vocab size each LoRA can have.
  30. Returns:
  31. A tuple of tensors:
  32. base_indices: Tensor of shape [batch_size] mapping batch rows to
  33. LoRA indices.
  34. sampler_indices: Tensor of shape [batch_size] mapping requests to
  35. LoRA indices for sampler. For generation, this will be the
  36. same as base_indicies. For prefill, this will map requests
  37. to LoRA indices.
  38. sampler_indices_padded: Tensor of shape [batch_size] mapping
  39. requests to LoRA indices for sampler with padding.
  40. Same as sampler_indicies, but -1 is replaced with
  41. max_loras.
  42. embeddings_indices: Tensor of shape [2, batch_size] mapping
  43. requests to embedding indices. First row is for embeddings
  44. added by the LoRAs, second row is for the LoRA.lora_a
  45. embeddings.
  46. indices_len: List of lengths of the above tensors.
  47. """
  48. indices = list(mapping.index_mapping).copy()
  49. embedding_indices = indices.copy()
  50. lora_indices = indices.copy()
  51. prompt_mapping = [
  52. lora_index_to_id.index(x) if x > 0 else -1
  53. for x in mapping.prompt_mapping
  54. ]
  55. lora_idx = None
  56. for i in range(len(indices)):
  57. # TODO index can be slow. optimize
  58. lora_idx = (lora_index_to_id.index(indices[i])
  59. if indices[i] > 0 else -1)
  60. embedding_indices[i] = lora_idx if indices[i] > 0 else 0
  61. indices[i] = i
  62. lora_indices[i] = lora_idx
  63. indices = torch.tensor([indices, lora_indices, embedding_indices],
  64. dtype=torch.long,
  65. device="cuda")
  66. prompt_mapping = torch.tensor(prompt_mapping,
  67. device="cuda",
  68. dtype=torch.long)
  69. embeddings_indices = torch.stack([
  70. indices[2] * extra_vocab_size,
  71. indices[2] * (vocab_size + extra_vocab_size)
  72. ])
  73. embeddings_indices[embeddings_indices == -1] = max_loras - 1
  74. base_indices = indices[1]
  75. sampler_indices = prompt_mapping
  76. sampler_indices_padded = sampler_indices.clone()
  77. sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
  78. sampler_indices_padded = (
  79. torch.arange(
  80. 0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
  81. (sampler_indices_padded * len(sampler_indices_padded)))
  82. indices_len = (base_indices.shape[-1], sampler_indices.shape[-1],
  83. sampler_indices_padded.shape[-1],
  84. embeddings_indices.shape[-1])
  85. return (base_indices, sampler_indices, sampler_indices_padded,
  86. embeddings_indices, indices_len)
  87. def get_lora_id():
  88. global _GLOBAL_LORA_ID
  89. _GLOBAL_LORA_ID += 1
  90. return _GLOBAL_LORA_ID
  91. class LoRAModel:
  92. """A LoRA fine-tuned model."""
  93. def __init__(
  94. self,
  95. lora_model_id: int,
  96. rank: int,
  97. loras: Dict[str, LoRALayerWeights],
  98. ) -> None:
  99. self.id = lora_model_id
  100. assert (lora_model_id >
  101. 0), f"a valid lora id should be greater than 0, got {self.id}"
  102. self.rank = rank
  103. self.loras: Dict[str, LoRALayerWeights] = loras
  104. @property
  105. def extra_vocab_size(self) -> int:
  106. return max(lora.extra_vocab_size
  107. for lora in self.loras.values()) if self.loras else 0
  108. def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
  109. """Get LoRA for a given module by name"""
  110. return self.loras.get(module_name, None)
  111. # (yard1): TODO see if we can derive target_embedding_padding automatically
  112. @classmethod
  113. def from_lora_tensors(
  114. cls,
  115. lora_model_id: int,
  116. rank: int,
  117. lora_alpha: int,
  118. tensors: Dict[str, torch.Tensor],
  119. device: str = "cuda",
  120. dtype: Optional[torch.dtype] = None,
  121. embeddings: Optional[Dict[str, torch.Tensor]] = None,
  122. target_embedding_padding: Optional[int] = None,
  123. embedding_modules: Optional[Dict[str, str]] = None,
  124. embedding_padding_modules: Optional[List[str]] = None,
  125. ) -> "LoRAModel":
  126. """Create a LoRAModel from a dictionary of tensors."""
  127. pin_memory = str(device) == "cpu" and not is_pin_memory_available()
  128. loras: Dict[str, LoRALayerWeights] = {}
  129. for tensor_name, tensor in tensors.items():
  130. result = parse_fine_tuned_lora_name(tensor_name)
  131. if result is not None:
  132. module_name, is_lora_a = result
  133. if module_name not in loras:
  134. lora_embeddings_tensor = None
  135. if embeddings:
  136. embeddings_module = next(
  137. (k for k in embedding_modules if k in module_name),
  138. None)
  139. if embeddings_module:
  140. lora_embeddings_tensor = embeddings[
  141. embedding_modules[embeddings_module]].to(
  142. device=device, dtype=dtype)
  143. if pin_memory:
  144. lora_embeddings_tensor = (
  145. lora_embeddings_tensor.pin_memory())
  146. loras[module_name] = LoRALayerWeights(
  147. module_name, rank, lora_alpha, None, None,
  148. lora_embeddings_tensor)
  149. if is_lora_a:
  150. loras[module_name].lora_a = tensor.to(device=device,
  151. dtype=dtype).t()
  152. if pin_memory:
  153. loras[module_name].lora_a = loras[
  154. module_name].lora_a.pin_memory()
  155. else:
  156. loras[module_name].lora_b = tensor.to(device=device,
  157. dtype=dtype).t()
  158. if any(name in module_name
  159. for name in embedding_padding_modules
  160. ) and target_embedding_padding is not None:
  161. lora_b = loras[module_name].lora_b
  162. assert target_embedding_padding >= lora_b.shape[1]
  163. addition = target_embedding_padding - lora_b.shape[1]
  164. loras[module_name].lora_b = torch.nn.functional.pad(
  165. lora_b, (0, addition))
  166. if pin_memory:
  167. loras[module_name].lora_b = loras[
  168. module_name].lora_b.pin_memory()
  169. for lora in loras.values():
  170. lora.optimize()
  171. return cls(lora_model_id, rank, loras)
  172. @classmethod
  173. def from_local_checkpoint(
  174. cls,
  175. lora_dir: str,
  176. lora_model_id: Optional[int] = None,
  177. device: str = "cuda",
  178. dtype: Optional[torch.dtype] = None,
  179. target_embedding_padding: Optional[int] = None,
  180. embedding_modules: Optional[Dict[str, str]] = None,
  181. embedding_padding_modules: Optional[List[str]] = None,
  182. ) -> "LoRAModel":
  183. """Create a LoRAModel from a local checkpoint."""
  184. lora_config_path = os.path.join(lora_dir, "adapter_config.json")
  185. lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
  186. lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
  187. new_embeddings_tensor_path = os.path.join(
  188. lora_dir, "new_embeddings.safetensors")
  189. new_embeddings_bin_file_path = os.path.join(lora_dir,
  190. "new_embeddings.bin")
  191. if os.path.isfile(lora_tensor_path):
  192. tensors = safetensors.torch.load_file(lora_tensor_path)
  193. elif os.path.isfile(lora_bin_file_path):
  194. tensors = torch.load(lora_bin_file_path)
  195. else:
  196. raise ValueError(f"{lora_dir} doesn't contain tensors")
  197. embeddings = None
  198. if os.path.isfile(new_embeddings_tensor_path):
  199. embeddings = safetensors.torch.load_file(
  200. new_embeddings_tensor_path)
  201. elif os.path.isfile(new_embeddings_bin_file_path):
  202. embeddings = torch.load(new_embeddings_bin_file_path)
  203. with open(lora_config_path) as f:
  204. config = json.load(f)
  205. rank = config["r"]
  206. lora_alpha = config["lora_alpha"]
  207. return cls.from_lora_tensors(
  208. lora_model_id=get_lora_id()
  209. if lora_model_id is None else lora_model_id,
  210. rank=rank,
  211. lora_alpha=lora_alpha,
  212. tensors=tensors,
  213. device=device,
  214. dtype=dtype,
  215. embeddings=embeddings,
  216. target_embedding_padding=target_embedding_padding,
  217. embedding_modules=embedding_modules,
  218. embedding_padding_modules=embedding_padding_modules,
  219. )
  220. class LoRAModelManager:
  221. """A manager that manages multiple LoRA-fine-tuned models."""
  222. def __init__(
  223. self,
  224. model: nn.Module,
  225. max_num_seqs: int,
  226. max_num_batched_tokens: int,
  227. vocab_size: int,
  228. lora_config: LoRAConfig,
  229. ):
  230. """Create a LoRAModelManager and adapter for a given model.
  231. Args:
  232. model: the model to be adapted.
  233. max_num_seqs: the maximum number of sequences model can run in a
  234. single batch.
  235. max_num_batched_tokens: the maximum number of tokens model can run
  236. in a single batch.
  237. vocab_size: the vocab size of the model.
  238. lora_config: the LoRA configuration.
  239. """
  240. self.lora_config = lora_config
  241. self.max_num_seqs = max_num_seqs
  242. assert self.capacity >= self.lora_slots
  243. self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
  244. self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
  245. self.vocab_size = vocab_size
  246. self.base_indices = torch.empty(self.max_num_batched_tokens,
  247. dtype=torch.long,
  248. device="cuda")
  249. self.sampler_indices = torch.empty(self.max_num_batched_tokens,
  250. dtype=torch.long,
  251. device="cuda")
  252. self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens,
  253. dtype=torch.long,
  254. device="cuda")
  255. self.embeddings_indices = torch.empty(2,
  256. self.max_num_batched_tokens,
  257. dtype=torch.long,
  258. device="cuda")
  259. self.offsets = []
  260. # 4 is the number of indices tensors defined above
  261. # base_indices, sampler_indices, sampler_indices_padded,
  262. # embeddings_indices
  263. self.indices_len = [None] * 4
  264. self.model: nn.Module = model
  265. if hasattr(self.model, "supported_lora_modules"):
  266. self.supported_lora_modules = copy.deepcopy(
  267. self.model.supported_lora_modules)
  268. self.packed_modules_mapping = copy.deepcopy(
  269. self.model.packed_modules_mapping)
  270. self.packed_modules: Dict[str, List[str]] = {}
  271. self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
  272. self._registered_loras: Dict[int, LoRAModel] = {}
  273. # Dict instead of a Set for compatibility with LRUCache.
  274. self._active_loras: Dict[int, None] = {}
  275. self._last_mapping = None
  276. self._create_lora_modules()
  277. self.model.lora_manager = self
  278. @property
  279. def capacity(self) -> int:
  280. return self.lora_config.max_cpu_loras
  281. @property
  282. def lora_slots(self) -> int:
  283. return self.lora_config.max_loras
  284. def __len__(self) -> int:
  285. return len(self._registered_loras)
  286. def activate_lora(
  287. self,
  288. lora_id: int,
  289. ) -> bool:
  290. """Move LoRA into a GPU buffer to be used in the forward pass."""
  291. if lora_id in self._active_loras:
  292. return False
  293. first_free_slot = next(
  294. ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id)
  295. if lora_id is None), None)
  296. if first_free_slot is None:
  297. raise ValueError("No free lora slots")
  298. index, _ = first_free_slot
  299. self._active_loras[lora_id] = None
  300. lora_model = self._registered_loras[lora_id]
  301. logger.debug(
  302. f"Activating LoRA. int id: {lora_model.id}, slot index: {index}")
  303. self.lora_index_to_id[index] = lora_model.id
  304. for module_name, module in self.modules.items():
  305. module_lora = lora_model.get_lora(module_name)
  306. if module_lora:
  307. module_lora.optimize()
  308. module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
  309. module_lora.embeddings_tensor)
  310. else:
  311. module.reset_lora(index)
  312. return True
  313. def _deactivate_lora(self, lora_id: int):
  314. try:
  315. index = self.lora_index_to_id.index(lora_id)
  316. self.lora_index_to_id[index] = None
  317. except ValueError:
  318. pass
  319. def deactivate_lora(self, lora_id: int) -> bool:
  320. """Remove a LoRA from a GPU buffer."""
  321. if lora_id in self._active_loras:
  322. self._deactivate_lora(lora_id)
  323. self._active_loras.pop(lora_id)
  324. return True
  325. return False
  326. def _add_lora(self, lora: LoRAModel) -> bool:
  327. self._create_merged_loras_inplace(lora)
  328. self._registered_loras[lora.id] = lora
  329. def add_lora(self, lora: LoRAModel) -> bool:
  330. """Add a LoRAModel to the manager CPU cache."""
  331. if lora.id not in self._registered_loras:
  332. if len(self._registered_loras) >= self.capacity:
  333. raise RuntimeError("No free LoRA slots.")
  334. self._add_lora(lora)
  335. return True
  336. return False
  337. def remove_lora(self, lora_id: int) -> bool:
  338. """Remove a LoRAModel from the manager CPU cache."""
  339. # TODO: should we check active lora?
  340. self.deactivate_lora(lora_id)
  341. return bool(self._registered_loras.pop(lora_id, None))
  342. # TODO see if this can be vectorized
  343. def _set_lora_mapping(self, mapping: LoRAMapping) -> None:
  344. (base_indices, sampler_indices, sampler_indices_padded,
  345. embeddings_indices,
  346. indices_len) = convert_mapping(mapping, self.lora_index_to_id,
  347. self.lora_slots + 1, self.vocab_size,
  348. self.lora_config.lora_extra_vocab_size)
  349. self.base_indices[:base_indices.shape[0]].copy_(base_indices)
  350. self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
  351. self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
  352. sampler_indices_padded)
  353. self.embeddings_indices[:embeddings_indices.
  354. shape[0], :embeddings_indices.shape[1]].copy_(
  355. embeddings_indices)
  356. # Maintain the reference
  357. self.indices_len[:] = indices_len
  358. def set_lora_mapping(self, lora_mapping: LoRAMapping) -> None:
  359. if self._last_mapping != lora_mapping:
  360. self._set_lora_mapping(lora_mapping)
  361. self._last_mapping = lora_mapping
  362. def list_loras(self) -> Dict[int, LoRAModel]:
  363. """List all registered LoRAModels."""
  364. return dict(self._registered_loras)
  365. def get_lora(self, lora_id: int) -> Optional[LoRAModel]:
  366. return self._registered_loras.get(lora_id, None)
  367. def remove_all_loras(self) -> bool:
  368. """Remove all LoRAModels from the manager."""
  369. self._registered_loras.clear()
  370. self.lora_index_to_id = [None] * self.lora_slots
  371. self._active_loras.clear()
  372. def _create_lora_modules(self):
  373. for module_name, module in self.model.named_modules():
  374. if not self._match_target_modules(module_name):
  375. continue
  376. new_module = replace_submodule(
  377. self.model, module_name,
  378. from_layer(module, self.lora_slots, self.lora_config,
  379. self.model.config))
  380. # (yard1): TODO make this more robust
  381. if "lm_head" in module_name:
  382. logits_processor_module = self.model.get_submodule(
  383. "logits_processor")
  384. new_module = replace_submodule(
  385. self.model, "logits_processor",
  386. from_layer_logits_processor(logits_processor_module,
  387. module, self.lora_slots,
  388. self.lora_config,
  389. self.model.config))
  390. self.register_module(module_name, new_module)
  391. self._register_packed_modules(module_name)
  392. new_module.set_mapping(self.base_indices, self.sampler_indices,
  393. self.sampler_indices_padded,
  394. self.embeddings_indices, self.indices_len)
  395. def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
  396. assert isinstance(module, BaseLayerWithLoRA)
  397. self.modules[module_name] = module
  398. def create_dummy_lora(
  399. self,
  400. lora_id: int,
  401. rank: int,
  402. embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
  403. """Create zero-initialized LoRAModel for warmup."""
  404. model = LoRAModel(lora_id, rank, {})
  405. for module_name, module in self.model.named_modules():
  406. if not self._match_target_modules(module_name) or not isinstance(
  407. module, BaseLayerWithLoRA):
  408. continue
  409. parts = module_name.split(".")
  410. if module_name not in self.packed_modules:
  411. if parts[-1] in embedding_modules:
  412. input_dim = (module.base_layer.org_vocab_size +
  413. self.lora_config.lora_extra_vocab_size if
  414. hasattr(module.base_layer, "org_vocab_size")
  415. else module.base_layer.weight.shape[1])
  416. output_dim = module.base_layer.embedding_dim if hasattr(
  417. module.base_layer,
  418. "embedding_dim") else module.base_layer.weight.shape[0]
  419. embeddings_tensor_dim = (module.base_layer.embedding_dim if
  420. hasattr(module.base_layer,
  421. "embedding_dim") else
  422. module.base_layer.weight.shape[1])
  423. lora = LoRALayerWeights.create_dummy_lora_weights(
  424. module_name,
  425. input_dim,
  426. output_dim,
  427. rank,
  428. module.lora_a_stacked.dtype,
  429. "cpu",
  430. embeddings_tensor_dim=embeddings_tensor_dim)
  431. else:
  432. lora = LoRALayerWeights.create_dummy_lora_weights(
  433. module_name,
  434. module.lora_a_stacked.shape[-1],
  435. module.lora_b_stacked.shape[-2],
  436. rank,
  437. module.lora_a_stacked.dtype,
  438. "cpu",
  439. )
  440. lora.optimize()
  441. else:
  442. parts = module_name.split(".")
  443. replacements = self.packed_modules_mapping[parts[-1]]
  444. subloras = []
  445. for i, r in enumerate(replacements):
  446. lora = LoRALayerWeights.create_dummy_lora_weights(
  447. module_name + "." + r,
  448. module.lora_a_stacked[i].shape[-1],
  449. module.lora_b_stacked[i].shape[-2],
  450. rank,
  451. module.lora_a_stacked[i].dtype,
  452. "cpu",
  453. )
  454. lora.optimize()
  455. subloras.append(lora)
  456. lora = PackedLoRALayerWeights.pack(subloras)
  457. model.loras[module_name] = lora
  458. return model
  459. def _match_target_modules(self, module_name: str):
  460. return any(
  461. re.match(
  462. r".*\.{target_module}$".format(target_module=target_module),
  463. module_name) or target_module == module_name
  464. for target_module in self.supported_lora_modules)
  465. def _register_packed_modules(self, module_full_name: str) -> None:
  466. parts = module_full_name.split(".")
  467. module_name = parts[-1]
  468. replacements = self.packed_modules_mapping.get(module_name)
  469. if not replacements:
  470. return
  471. prefix = ".".join(parts[:-1])
  472. self.packed_modules[module_full_name] = [
  473. prefix + "." + r if prefix else r for r in replacements
  474. ]
  475. def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
  476. for module_name, new_module_names in self.packed_modules.items():
  477. replacement_loras = []
  478. has_replacement = False
  479. for r in new_module_names:
  480. lora = lora_model.get_lora(r)
  481. replacement_loras.append(lora)
  482. if lora:
  483. has_replacement = True
  484. if not has_replacement:
  485. continue
  486. for i in range(len(replacement_loras)):
  487. if replacement_loras[i]:
  488. continue
  489. replacement_loras[i] = None
  490. lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
  491. replacement_loras)
  492. class LoRALRUCache(LRUCache[LoRAModel]):
  493. def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable],
  494. None]):
  495. super().__init__(capacity)
  496. self.deactivate_lora_fn = deactivate_lora_fn
  497. def _on_remove(self, key: Hashable, value: LoRAModel):
  498. logger.debug(f"Removing LoRA. int id: {key}")
  499. self.deactivate_lora_fn(key)
  500. return super()._on_remove(key, value)
  501. class LRUCacheLoRAModelManager(LoRAModelManager):
  502. """A model manager that manages multiple LoRAs with LRU cache."""
  503. def __init__(
  504. self,
  505. model: nn.Module,
  506. max_num_seqs: int,
  507. max_num_batched_tokens: int,
  508. vocab_size: int,
  509. lora_config: LoRAConfig,
  510. ):
  511. super().__init__(model, max_num_seqs, max_num_batched_tokens,
  512. vocab_size, lora_config)
  513. self._registered_loras: LoRALRUCache = LoRALRUCache(
  514. self.capacity, self.deactivate_lora)
  515. self._active_loras: LoRALRUCache = LoRALRUCache(
  516. self.lora_slots, self._deactivate_lora)
  517. def list_loras(self) -> Dict[int, LoRAModel]:
  518. """List all registered LoRAModels."""
  519. return dict(self._registered_loras.cache)
  520. def add_lora(self, lora: LoRAModel) -> bool:
  521. """Add a LoRAModel to the manager."""
  522. if lora.id not in self._registered_loras:
  523. self._add_lora(lora)
  524. was_added = True
  525. else:
  526. # We always touch to update the LRU cache order
  527. self._registered_loras.touch(lora.id)
  528. was_added = False
  529. return was_added
  530. def activate_lora(
  531. self,
  532. lora_id: int,
  533. ) -> bool:
  534. if lora_id not in self._active_loras and len(
  535. self._active_loras) >= self.lora_slots:
  536. self._active_loras.remove_oldest()
  537. result = super().activate_lora(lora_id)
  538. # We always touch to update the LRU cache order
  539. self._active_loras.touch(lora_id)
  540. return result
  541. def remove_oldest_lora(self) -> bool:
  542. if len(self._registered_loras) > 0:
  543. self._registered_loras.remove_oldest()
  544. return True
  545. return False
  546. def create_lora_manager(
  547. model: nn.Module,
  548. max_num_seqs: int,
  549. max_num_batched_tokens: int,
  550. vocab_size: int,
  551. lora_config: LoRAConfig,
  552. lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
  553. **kwargs) -> LoRAModelManager:
  554. """Create a LoRA adapter for a given model."""
  555. if not hasattr(model, "supported_lora_modules"):
  556. raise ValueError(f"Model {type(model)} is not supported for LoRA.")
  557. lora_manager = lora_manager_cls(
  558. model=model,
  559. max_num_seqs=max_num_seqs,
  560. max_num_batched_tokens=max_num_batched_tokens,
  561. vocab_size=vocab_size,
  562. lora_config=lora_config,
  563. **kwargs)
  564. return lora_manager