models.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622
  1. import copy
  2. import json
  3. import logging
  4. import math
  5. import os
  6. import re
  7. from typing import (Any, Callable, Dict, Hashable, List, Optional, Tuple, Type)
  8. import safetensors.torch
  9. import torch
  10. from torch import nn
  11. from aphrodite.common.config import LoRAConfig
  12. from aphrodite.common.utils import LRUCache, in_wsl
  13. from aphrodite.lora.layers import BaseLayerWithLoRA, LoRAMapping, from_layer, from_layer_sampler
  14. from aphrodite.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
  15. from aphrodite.lora.utils import parse_fine_tuned_lora_name, replace_submodule
  16. logger = logging.getLogger(__name__)
  17. _GLOBAL_LORA_ID = 0
  18. def convert_mapping(
  19. mapping: LoRAMapping, lora_index_to_id: List[Optional[int]],
  20. max_loras: int, vocab_size: int, extra_vocab_size: int
  21. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
  22. """Converts LoRAMapping to index tensors.
  23. Args:
  24. mapping: LoRAMapping mapping rows in a batch to LoRA ids.
  25. lora_index_to_id: List mapping LoRA ids to LoRA indices.
  26. max_loras: Maximum number of LoRAs.
  27. vocab_size: Model vocab size.
  28. extra_vocab_size: Extra vocab size each LoRA can have.
  29. Returns:
  30. A tuple of tensors:
  31. base_indices: Tensor of shape [batch_size] mapping batch rows to
  32. LoRA indices.
  33. sampler_indices: Tensor of shape [batch_size] mapping requests to
  34. LoRA indices for sampler. For generation, this will be the
  35. same as base_indicies. For prefill, this will map requests
  36. to LoRA indices.
  37. sampler_indices_padded: Tensor of shape [batch_size] mapping
  38. requests to LoRA indices for sampler with padding.
  39. Same as sampler_indicies, but -1 is replaced with
  40. max_loras.
  41. embeddings_indices: Tensor of shape [2, batch_size] mapping
  42. requests to embedding indices. First row is for embeddings
  43. added by the LoRAs, second row is for the LoRA.lora_a
  44. embeddings.
  45. indices_len: List of lengths of the above tensors.
  46. """
  47. indices = list(mapping.index_mapping).copy()
  48. embedding_indices = indices.copy()
  49. lora_indices = indices.copy()
  50. prompt_mapping = [
  51. lora_index_to_id.index(x) if x > 0 else -1
  52. for x in mapping.prompt_mapping
  53. ]
  54. lora_idx = None
  55. for i in range(len(indices)):
  56. # TODO index can be slow. optimize
  57. lora_idx = (lora_index_to_id.index(indices[i])
  58. if indices[i] > 0 else -1)
  59. embedding_indices[i] = lora_idx if indices[i] > 0 else 0
  60. indices[i] = i
  61. lora_indices[i] = lora_idx
  62. indices = torch.tensor([indices, lora_indices, embedding_indices],
  63. dtype=torch.long,
  64. device="cuda")
  65. prompt_mapping = torch.tensor(prompt_mapping,
  66. device="cuda",
  67. dtype=torch.long)
  68. embeddings_indices = torch.stack([
  69. indices[2] * extra_vocab_size,
  70. indices[2] * (vocab_size + extra_vocab_size)
  71. ])
  72. embeddings_indices[embeddings_indices == -1] = max_loras - 1
  73. base_indices = indices[1]
  74. sampler_indices = prompt_mapping
  75. sampler_indices_padded = sampler_indices.clone()
  76. sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
  77. sampler_indices_padded = (
  78. torch.arange(
  79. 0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
  80. (sampler_indices_padded * len(sampler_indices_padded)))
  81. indices_len = (base_indices.shape[-1], sampler_indices.shape[-1],
  82. sampler_indices_padded.shape[-1],
  83. embeddings_indices.shape[-1])
  84. return (base_indices, sampler_indices, sampler_indices_padded,
  85. embeddings_indices, indices_len)
  86. def get_lora_id():
  87. global _GLOBAL_LORA_ID
  88. _GLOBAL_LORA_ID += 1
  89. return _GLOBAL_LORA_ID
  90. class LoRAModel:
  91. """A LoRA fine-tuned model."""
  92. def __init__(
  93. self,
  94. lora_model_id: int,
  95. rank: int,
  96. loras: Dict[str, LoRALayerWeights],
  97. ) -> None:
  98. self.id = lora_model_id
  99. assert (lora_model_id >
  100. 0), f"a valid lora id should be greater than 0, got {self.id}"
  101. self.rank = rank
  102. self.loras: Dict[str, LoRALayerWeights] = loras
  103. @property
  104. def extra_vocab_size(self) -> int:
  105. return max(lora.extra_vocab_size
  106. for lora in self.loras.values()) if self.loras else 0
  107. def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
  108. """Get LoRA for a given module by name"""
  109. return self.loras.get(module_name, None)
  110. # (yard1): TODO see if we can derive target_embedding_padding automatically
  111. @classmethod
  112. def from_lora_tensors(
  113. cls,
  114. lora_model_id: int,
  115. rank: int,
  116. lora_alpha: int,
  117. tensors: Dict[str, torch.Tensor],
  118. device: str = "cuda",
  119. dtype: Optional[torch.dtype] = None,
  120. embeddings: Optional[Dict[str, torch.Tensor]] = None,
  121. target_embedding_padding: Optional[int] = None,
  122. embedding_modules: Optional[Dict[str, str]] = None,
  123. embedding_padding_modules: Optional[List[str]] = None,
  124. ) -> "LoRAModel":
  125. """Create a LoRAModel from a dictionary of tensors."""
  126. pin_memory = str(device) == "cpu" and not in_wsl()
  127. loras: Dict[str, LoRALayerWeights] = {}
  128. for tensor_name, tensor in tensors.items():
  129. result = parse_fine_tuned_lora_name(tensor_name)
  130. if result is not None:
  131. module_name, is_lora_a = result
  132. if module_name not in loras:
  133. lora_embeddings_tensor = None
  134. if embeddings:
  135. embeddings_module = next(
  136. (k for k in embedding_modules if k in module_name),
  137. None)
  138. if embeddings_module:
  139. lora_embeddings_tensor = embeddings[
  140. embedding_modules[embeddings_module]].to(
  141. device=device, dtype=dtype)
  142. if pin_memory:
  143. lora_embeddings_tensor = (
  144. lora_embeddings_tensor.pin_memory())
  145. loras[module_name] = LoRALayerWeights(
  146. module_name, rank, lora_alpha, None, None,
  147. lora_embeddings_tensor)
  148. if is_lora_a:
  149. loras[module_name].lora_a = tensor.to(device=device,
  150. dtype=dtype).t()
  151. if pin_memory:
  152. loras[module_name].lora_a = loras[
  153. module_name].lora_a.pin_memory()
  154. else:
  155. loras[module_name].lora_b = tensor.to(device=device,
  156. dtype=dtype).t()
  157. if any(name in module_name
  158. for name in embedding_padding_modules
  159. ) and target_embedding_padding is not None:
  160. lora_b = loras[module_name].lora_b
  161. assert target_embedding_padding >= lora_b.shape[1]
  162. addition = target_embedding_padding - lora_b.shape[1]
  163. loras[module_name].lora_b = torch.nn.functional.pad(
  164. lora_b, (0, addition))
  165. if pin_memory:
  166. loras[module_name].lora_b = loras[
  167. module_name].lora_b.pin_memory()
  168. for lora in loras.values():
  169. lora.optimize()
  170. return cls(lora_model_id, rank, loras)
  171. @classmethod
  172. def from_local_checkpoint(
  173. cls,
  174. lora_dir: str,
  175. lora_model_id: Optional[int] = None,
  176. device: str = "cuda",
  177. dtype: Optional[torch.dtype] = None,
  178. target_embedding_padding: Optional[int] = None,
  179. embedding_modules: Optional[Dict[str, str]] = None,
  180. embedding_padding_modules: Optional[List[str]] = None,
  181. ) -> "LoRAModel":
  182. """Create a LoRAModel from a local checkpoint."""
  183. lora_config_path = os.path.join(lora_dir, "adapter_config.json")
  184. lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
  185. lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
  186. new_embeddings_tensor_path = os.path.join(
  187. lora_dir, "new_embeddings.safetensors")
  188. new_embeddings_bin_file_path = os.path.join(lora_dir,
  189. "new_embeddings.bin")
  190. if os.path.isfile(lora_tensor_path):
  191. tensors = safetensors.torch.load_file(lora_tensor_path)
  192. elif os.path.isfile(lora_bin_file_path):
  193. tensors = torch.load(lora_bin_file_path)
  194. else:
  195. raise ValueError(f"{lora_dir} doesn't contain tensors")
  196. embeddings = None
  197. if os.path.isfile(new_embeddings_tensor_path):
  198. embeddings = safetensors.torch.load_file(
  199. new_embeddings_tensor_path)
  200. elif os.path.isfile(new_embeddings_bin_file_path):
  201. embeddings = torch.load(new_embeddings_bin_file_path)
  202. with open(lora_config_path) as f:
  203. config = json.load(f)
  204. rank = config["r"]
  205. lora_alpha = config["lora_alpha"]
  206. return cls.from_lora_tensors(
  207. lora_model_id=get_lora_id()
  208. if lora_model_id is None else lora_model_id,
  209. rank=rank,
  210. lora_alpha=lora_alpha,
  211. tensors=tensors,
  212. device=device,
  213. dtype=dtype,
  214. embeddings=embeddings,
  215. target_embedding_padding=target_embedding_padding,
  216. embedding_modules=embedding_modules,
  217. embedding_padding_modules=embedding_padding_modules,
  218. )
  219. class LoRAModelManager:
  220. """A manager that manages multiple LoRA-fine-tuned models."""
  221. def __init__(
  222. self,
  223. model: nn.Module,
  224. max_num_seqs: int,
  225. max_num_batched_tokens: int,
  226. vocab_size: int,
  227. lora_config: LoRAConfig,
  228. ):
  229. """Create a LoRAModelManager and adapter for a given model.
  230. Args:
  231. model: the model to be adapted.
  232. max_num_seqs: the maximum number of sequences model can run in a
  233. single batch.
  234. max_num_batched_tokens: the maximum number of tokens model can run
  235. in a single batch.
  236. vocab_size: the vocab size of the model.
  237. lora_config: the LoRA configuration.
  238. """
  239. self.lora_config = lora_config
  240. self.max_num_seqs = max_num_seqs
  241. assert self.capacity >= self.lora_slots
  242. self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
  243. self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
  244. self.vocab_size = vocab_size
  245. self.base_indices = torch.empty(self.max_num_batched_tokens,
  246. dtype=torch.long,
  247. device="cuda")
  248. self.sampler_indices = torch.empty(self.max_num_batched_tokens,
  249. dtype=torch.long,
  250. device="cuda")
  251. self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens,
  252. dtype=torch.long,
  253. device="cuda")
  254. self.embeddings_indices = torch.empty(2,
  255. self.max_num_batched_tokens,
  256. dtype=torch.long,
  257. device="cuda")
  258. self.offsets = []
  259. # 4 is the number of indicies tensors defined above
  260. # base_indices, sampler_indices, sampler_indices_padded,
  261. # embeddings_indices
  262. self.indices_len = [None] * 4
  263. self.model: nn.Module = model
  264. if hasattr(self.model, "supported_lora_modules"):
  265. self.supported_lora_modules = copy.deepcopy(
  266. self.model.supported_lora_modules)
  267. self.packed_modules_mapping = copy.deepcopy(
  268. self.model.packed_modules_mapping)
  269. self.packed_modules: Dict[str, List[str]] = {}
  270. self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
  271. self._registered_loras: Dict[int, LoRAModel] = {}
  272. # Dict instead of a Set for compatibility with LRUCache.
  273. self._active_loras: Dict[int, None] = {}
  274. self._last_mapping = None
  275. self._create_lora_modules()
  276. self.model.lora_manager = self
  277. @property
  278. def capacity(self) -> int:
  279. return self.lora_config.max_cpu_loras
  280. @property
  281. def lora_slots(self) -> int:
  282. return self.lora_config.max_loras
  283. def __len__(self) -> int:
  284. return len(self._registered_loras)
  285. def activate_lora(
  286. self,
  287. lora_id: int,
  288. ) -> bool:
  289. """Move LoRA into a GPU buffer to be used in the forward pass."""
  290. if lora_id in self._active_loras:
  291. return False
  292. first_free_slot = next(
  293. ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id)
  294. if lora_id is None), None)
  295. if first_free_slot is None:
  296. raise ValueError("No free lora slots")
  297. index, _ = first_free_slot
  298. self._active_loras[lora_id] = None
  299. lora_model = self._registered_loras[lora_id]
  300. logger.debug(
  301. f"Activating LoRA. int id: {lora_model.id}, slot index: {index}")
  302. self.lora_index_to_id[index] = lora_model.id
  303. for module_name, module in self.modules.items():
  304. module_lora = lora_model.get_lora(module_name)
  305. if module_lora:
  306. module_lora.optimize()
  307. module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
  308. module_lora.embeddings_tensor)
  309. else:
  310. module.reset_lora(index)
  311. return True
  312. def _deactivate_lora(self, lora_id: int):
  313. try:
  314. index = self.lora_index_to_id.index(lora_id)
  315. self.lora_index_to_id[index] = None
  316. except ValueError:
  317. pass
  318. def deactivate_lora(self, lora_id: int) -> bool:
  319. """Remove a LoRA from a GPU buffer."""
  320. if lora_id in self._active_loras:
  321. self._deactivate_lora(lora_id)
  322. self._active_loras.pop(lora_id)
  323. return True
  324. return False
  325. def _add_lora(self, lora: LoRAModel) -> bool:
  326. self._create_merged_loras_inplace(lora)
  327. self._registered_loras[lora.id] = lora
  328. def add_lora(self, lora: LoRAModel) -> bool:
  329. """Add a LoRAModel to the manager CPU cache."""
  330. if lora.id not in self._registered_loras:
  331. if len(self._registered_loras) >= self.capacity:
  332. raise RuntimeError("No free LoRA slots.")
  333. self._add_lora(lora)
  334. return True
  335. return False
  336. def remove_lora(self, lora_id: int) -> bool:
  337. """Remove a LoRAModel from the manager CPU cache."""
  338. # TODO: should we check active lora?
  339. self.deactivate_lora(lora_id)
  340. return bool(self._registered_loras.pop(lora_id, None))
  341. # TODO see if this can be vectorized
  342. def _set_lora_mapping(self, mapping: LoRAMapping) -> None:
  343. (base_indices, sampler_indices, sampler_indices_padded,
  344. embeddings_indices,
  345. indices_len) = convert_mapping(mapping, self.lora_index_to_id,
  346. self.lora_slots + 1, self.vocab_size,
  347. self.lora_config.lora_extra_vocab_size)
  348. self.base_indices[:base_indices.shape[0]].copy_(base_indices)
  349. self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
  350. self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
  351. sampler_indices_padded)
  352. self.embeddings_indices[:embeddings_indices.
  353. shape[0], :embeddings_indices.shape[1]].copy_(
  354. embeddings_indices)
  355. # Maintain the reference
  356. self.indices_len[:] = indices_len
  357. def set_lora_mapping(self, lora_mapping: LoRAMapping) -> None:
  358. if self._last_mapping != lora_mapping:
  359. self._set_lora_mapping(lora_mapping)
  360. self._last_mapping = lora_mapping
  361. def list_loras(self) -> Dict[int, LoRAModel]:
  362. """List all registered LoRAModels."""
  363. return dict(self._registered_loras)
  364. def get_lora(self, lora_id: int) -> Optional[LoRAModel]:
  365. return self._registered_loras.get(lora_id, None)
  366. def remove_all_loras(self) -> bool:
  367. """Remove all LoRAModels from the manager."""
  368. self._registered_loras.clear()
  369. self.lora_index_to_id = [None] * self.lora_slots
  370. self._active_loras.clear()
  371. def _create_lora_modules(self):
  372. for module_name, module in self.model.named_modules():
  373. if not self._match_target_modules(module_name):
  374. continue
  375. new_module = replace_submodule(
  376. self.model, module_name,
  377. from_layer(module, self.lora_slots, self.lora_config,
  378. self.model.config))
  379. # (yard1): TODO make this more robust
  380. if "lm_head" in module_name:
  381. sampler_module = self.model.get_submodule("sampler")
  382. new_module = replace_submodule(
  383. self.model, "sampler",
  384. from_layer_sampler(sampler_module, module, self.lora_slots,
  385. self.lora_config, self.model.config))
  386. self.register_module(module_name, new_module)
  387. self._register_packed_modules(module_name)
  388. new_module.set_mapping(self.base_indices, self.sampler_indices,
  389. self.sampler_indices_padded,
  390. self.embeddings_indices, self.indices_len)
  391. def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
  392. assert isinstance(module, BaseLayerWithLoRA)
  393. self.modules[module_name] = module
  394. def create_dummy_lora(
  395. self,
  396. lora_id: int,
  397. rank: int,
  398. embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
  399. """Create zero-initialized LoRAModel for warmup."""
  400. model = LoRAModel(lora_id, rank, {})
  401. for module_name, module in self.model.named_modules():
  402. if not self._match_target_modules(module_name) or not isinstance(
  403. module, BaseLayerWithLoRA):
  404. continue
  405. parts = module_name.split(".")
  406. if module_name not in self.packed_modules:
  407. if parts[-1] in embedding_modules:
  408. input_dim = (module.base_layer.org_vocab_size +
  409. self.lora_config.lora_extra_vocab_size if
  410. hasattr(module.base_layer, "org_vocab_size")
  411. else module.base_layer.weight.shape[1])
  412. output_dim = module.base_layer.embedding_dim if hasattr(
  413. module.base_layer,
  414. "embedding_dim") else module.base_layer.weight.shape[0]
  415. embeddings_tensor_dim = (module.base_layer.embedding_dim if
  416. hasattr(module.base_layer,
  417. "embedding_dim") else
  418. module.base_layer.weight.shape[1])
  419. lora = LoRALayerWeights.create_dummy_lora_weights(
  420. module_name,
  421. input_dim,
  422. output_dim,
  423. rank,
  424. module.lora_a_stacked.dtype,
  425. "cpu",
  426. embeddings_tensor_dim=embeddings_tensor_dim)
  427. else:
  428. lora = LoRALayerWeights.create_dummy_lora_weights(
  429. module_name,
  430. module.lora_a_stacked.shape[-1],
  431. module.lora_b_stacked.shape[-2],
  432. rank,
  433. module.lora_a_stacked.dtype,
  434. "cpu",
  435. )
  436. lora.optimize()
  437. else:
  438. parts = module_name.split(".")
  439. replacements = self.packed_modules_mapping[parts[-1]]
  440. subloras = []
  441. for i, r in enumerate(replacements):
  442. lora = LoRALayerWeights.create_dummy_lora_weights(
  443. module_name + "." + r,
  444. module.lora_a_stacked[i].shape[-1],
  445. module.lora_b_stacked[i].shape[-2],
  446. rank,
  447. module.lora_a_stacked[i].dtype,
  448. "cpu",
  449. )
  450. lora.optimize()
  451. subloras.append(lora)
  452. lora = PackedLoRALayerWeights.pack(subloras)
  453. model.loras[module_name] = lora
  454. return model
  455. def _match_target_modules(self, module_name: str):
  456. return any(
  457. re.match(
  458. r".*\.{target_module}$".format(target_module=target_module),
  459. module_name) or target_module == module_name
  460. for target_module in self.supported_lora_modules)
  461. def _register_packed_modules(self, module_full_name: str) -> None:
  462. parts = module_full_name.split(".")
  463. module_name = parts[-1]
  464. replacements = self.packed_modules_mapping.get(module_name)
  465. if not replacements:
  466. return
  467. prefix = ".".join(parts[:-1])
  468. self.packed_modules[module_full_name] = [
  469. prefix + "." + r if prefix else r for r in replacements
  470. ]
  471. def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
  472. for module_name, new_module_names in self.packed_modules.items():
  473. replacement_loras = []
  474. has_replacement = False
  475. for r in new_module_names:
  476. lora = lora_model.get_lora(r)
  477. replacement_loras.append(lora)
  478. if lora:
  479. has_replacement = True
  480. if not has_replacement:
  481. continue
  482. for i in range(len(replacement_loras)):
  483. if replacement_loras[i]:
  484. continue
  485. replacement_loras[i] = None
  486. lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
  487. replacement_loras)
  488. class LoRALRUCache(LRUCache):
  489. def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable],
  490. None]):
  491. super().__init__(capacity)
  492. self.deactivate_lora_fn = deactivate_lora_fn
  493. def _on_remove(self, key: Hashable, value: Any):
  494. logger.debug(f"Removing LoRA. int id: {key}")
  495. self.deactivate_lora_fn(key)
  496. return super()._on_remove(key, value)
  497. class LRUCacheLoRAModelManager(LoRAModelManager):
  498. """A model manager that manages multiple LoRAs with LRU cache."""
  499. def __init__(
  500. self,
  501. model: nn.Module,
  502. max_num_seqs: int,
  503. max_num_batched_tokens: int,
  504. vocab_size: int,
  505. lora_config: LoRAConfig,
  506. ):
  507. super().__init__(model, max_num_seqs, max_num_batched_tokens,
  508. vocab_size, lora_config)
  509. self._registered_loras: LoRALRUCache = LoRALRUCache(
  510. self.capacity, self.deactivate_lora)
  511. self._active_loras: LoRALRUCache = LoRALRUCache(
  512. self.lora_slots, self._deactivate_lora)
  513. def list_loras(self) -> Dict[int, LoRAModel]:
  514. """List all registered LoRAModels."""
  515. return dict(self._registered_loras.cache)
  516. def add_lora(self, lora: LoRAModel) -> bool:
  517. """Add a LoRAModel to the manager."""
  518. if lora.id not in self._registered_loras:
  519. self._add_lora(lora)
  520. was_added = True
  521. else:
  522. # We always touch to update the LRU cache order
  523. self._registered_loras.touch(lora.id)
  524. was_added = False
  525. return was_added
  526. def activate_lora(
  527. self,
  528. lora_id: int,
  529. ) -> bool:
  530. if lora_id not in self._active_loras and len(
  531. self._active_loras) >= self.lora_slots:
  532. self._active_loras.remove_oldest()
  533. result = super().activate_lora(lora_id)
  534. # We always touch to update the LRU cache order
  535. self._active_loras.touch(lora_id)
  536. return result
  537. def remove_oldest_lora(self) -> bool:
  538. if len(self._registered_loras) > 0:
  539. self._registered_loras.remove_oldest()
  540. return True
  541. return False
  542. def create_lora_manager(
  543. model: nn.Module,
  544. max_num_seqs: int,
  545. max_num_batched_tokens: int,
  546. vocab_size: int,
  547. lora_config: LoRAConfig,
  548. lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
  549. **kwargs) -> LoRAModelManager:
  550. """Create a LoRA adapter for a given model."""
  551. if not hasattr(model, "supported_lora_modules"):
  552. raise ValueError(f"Model {type(model)} is not supported for LoRA.")
  553. lora_manager = lora_manager_cls(
  554. model=model,
  555. max_num_seqs=max_num_seqs,
  556. max_num_batched_tokens=max_num_batched_tokens,
  557. vocab_size=vocab_size,
  558. lora_config=lora_config,
  559. **kwargs)
  560. return lora_manager