models.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723
  1. import copy
  2. import json
  3. import math
  4. import os
  5. import re
  6. from dataclasses import dataclass, field
  7. from typing import Any, Callable, Dict, List, Optional, Type
  8. import safetensors.torch
  9. import torch
  10. from loguru import logger
  11. from torch import nn
  12. from aphrodite.adapter_commons.models import (AdapterLRUCache, AdapterModel,
  13. AdapterModelManager)
  14. from aphrodite.adapter_commons.utils import (add_adapter, deactivate_adapter,
  15. get_adapter, list_adapters,
  16. remove_adapter,
  17. set_adapter_mapping)
  18. from aphrodite.common.config import LoRAConfig
  19. from aphrodite.common.utils import is_pin_memory_available
  20. from aphrodite.lora.layers import (BaseLayerWithLoRA,
  21. LinearScalingRotaryEmbeddingWithLora,
  22. LoRAMapping, ModulesToSaveWrapper)
  23. from aphrodite.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
  24. from aphrodite.lora.punica import PunicaWrapper
  25. from aphrodite.lora.utils import (from_layer, from_layer_logits_processor,
  26. parse_fine_tuned_lora_name,
  27. replace_submodule)
  28. from aphrodite.modeling.models.interfaces import SupportsLoRA
  29. from aphrodite.modeling.models.utils import PPMissingLayer
# Process-wide counter behind get_lora_id(): holds the most recently issued
# LoRA id. Starts at 0 so the first generated id is 1.
_GLOBAL_LORA_ID: int = 0
  31. @dataclass
  32. class LongContextLoRAContext:
  33. """Context for lora adapters that support long context."""
  34. # The scaling factors to support long context lora fine tuned models.
  35. scaling_factors: List[float]
  36. # dimension to apply rotary embedding.
  37. rot_dim: int
  38. # offsets to the sin_cos_cache for each lora_id loaded.
  39. # This value is dynamically modified.
  40. offsets_by_lora_id: Dict[int, int] = field(default_factory=dict)
  41. def get_lora_id():
  42. global _GLOBAL_LORA_ID
  43. _GLOBAL_LORA_ID += 1
  44. return _GLOBAL_LORA_ID
  45. class LoRAModel(AdapterModel):
  46. """A LoRA fine-tuned model."""
  47. def __init__(
  48. self,
  49. lora_model_id: int,
  50. rank: int,
  51. loras: Dict[str, LoRALayerWeights],
  52. scaling_factor: Optional[float] = None,
  53. ) -> None:
  54. """
  55. Args:
  56. lora_model_id: The integer id for the lora model.
  57. rank: lora rank.
  58. loras: module name -> weights for lora-replaced layers.
  59. scaling_factor: Scaling factor to support long context lora model.
  60. None if the lora is not tuned for long context support.
  61. """
  62. self.id = lora_model_id
  63. # Scaling factor for long context lora model. None if it is not
  64. # fine tuned for the long context.
  65. self.scaling_factor = scaling_factor
  66. assert (lora_model_id >
  67. 0), f"a valid lora id should be greater than 0, got {self.id}"
  68. self.rank = rank
  69. self.loras: Dict[str, LoRALayerWeights] = loras
  70. def clone(self, lora_model_id: int) -> "LoRAModel":
  71. """Return a copy of the object with different ids.
  72. Will share the underlying tensors."""
  73. return self.__class__(
  74. lora_model_id,
  75. rank=self.rank,
  76. loras=self.loras.copy(),
  77. )
  78. @property
  79. def extra_vocab_size(self) -> int:
  80. return max(lora.extra_vocab_size
  81. for lora in self.loras.values()) if self.loras else 0
  82. def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
  83. """Get LoRA for a given module by name"""
  84. return self.loras.get(module_name, None)
  85. # (yard1): TODO see if we can derive target_embedding_padding automatically
  86. @classmethod
  87. def from_lora_tensors(
  88. cls,
  89. lora_model_id: int,
  90. rank: int,
  91. lora_alpha: int,
  92. tensors: Dict[str, torch.Tensor],
  93. enable_lora_modules_to_save: bool = False,
  94. device: str = "cuda",
  95. dtype: Optional[torch.dtype] = None,
  96. embeddings: Optional[Dict[str, torch.Tensor]] = None,
  97. target_embedding_padding: Optional[int] = None,
  98. scaling_factor: Optional[float] = None,
  99. embedding_modules: Optional[Dict[str, str]] = None,
  100. embedding_padding_modules: Optional[List[str]] = None,
  101. ) -> "LoRAModel":
  102. """Create a LoRAModel from a dictionary of tensors."""
  103. pin_memory = str(device) == "cpu" and is_pin_memory_available()
  104. loras: Dict[str, LoRALayerWeights] = {}
  105. for tensor_name, tensor in tensors.items():
  106. module_name, is_lora_a = parse_fine_tuned_lora_name(
  107. tensor_name, enable_lora_modules_to_save)
  108. if module_name not in loras:
  109. lora_embeddings_tensor = None
  110. if embeddings:
  111. assert embedding_modules is not None
  112. embeddings_module = next(
  113. (k for k in embedding_modules if k in module_name),
  114. None)
  115. if embeddings_module:
  116. lora_embeddings_tensor = embeddings[
  117. embedding_modules[embeddings_module]].to(
  118. device=device, dtype=dtype)
  119. if pin_memory:
  120. lora_embeddings_tensor = (
  121. lora_embeddings_tensor.pin_memory())
  122. loras[module_name] = LoRALayerWeights(module_name, rank,
  123. lora_alpha, None, None,
  124. lora_embeddings_tensor)
  125. if is_lora_a:
  126. loras[module_name].lora_a = tensor.to(device=device,
  127. dtype=dtype).t()
  128. if pin_memory:
  129. loras[module_name].lora_a_pin_memory()
  130. elif is_lora_a is None: # this is modules_to_save tensor
  131. loras[module_name].lora_b = tensor.to(device=device,
  132. dtype=dtype)
  133. if pin_memory:
  134. loras[module_name].lora_b_pin_memory()
  135. else:
  136. loras[module_name].lora_b = tensor.to(device=device,
  137. dtype=dtype).t()
  138. assert embedding_padding_modules is not None
  139. if any(name in module_name
  140. for name in embedding_padding_modules
  141. ) and target_embedding_padding is not None:
  142. lora_b = loras[module_name].lora_b
  143. assert target_embedding_padding >= lora_b.shape[1]
  144. addition = target_embedding_padding - lora_b.shape[1]
  145. loras[module_name].lora_b = torch.nn.functional.pad(
  146. lora_b, (0, addition))
  147. if pin_memory:
  148. loras[module_name].lora_b = loras[
  149. module_name].lora_b.pin_memory()
  150. for lora in loras.values():
  151. lora.optimize()
  152. return cls(lora_model_id, rank, loras, scaling_factor=scaling_factor)
  153. @classmethod
  154. def from_local_checkpoint(
  155. cls,
  156. lora_dir: str,
  157. expected_lora_modules: List[str],
  158. expected_modules_to_save: List[str],
  159. *,
  160. max_position_embeddings: Optional[int] = None,
  161. lora_model_id: Optional[int] = None,
  162. device: str = "cuda",
  163. enable_lora_modules_to_save: bool = False,
  164. dtype: Optional[torch.dtype] = None,
  165. target_embedding_padding: Optional[int] = None,
  166. embedding_modules: Optional[Dict[str, str]] = None,
  167. embedding_padding_modules: Optional[List[str]] = None,
  168. ) -> "LoRAModel":
  169. """Create a LoRAModel from a local checkpoint.
  170. Args:
  171. lora_dir: The local path that has lora data.
  172. expected_lora_modules: Name of modules that are expected to be
  173. replaced by lora.
  174. max_position_embeddings: Max position embedding length. Used to
  175. scaling the largest context length. If None, the lora model's
  176. context length is not scaled.
  177. lora_model_id: Lora model id. If not given, automatically set by
  178. a global counter.
  179. device: Device where the lora model is loaded.
  180. dtype: dtype of the lora model weights.
  181. Returns:
  182. Loaded LoRA Model.
  183. """
  184. lora_config_path = os.path.join(lora_dir, "adapter_config.json")
  185. lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
  186. lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
  187. new_embeddings_tensor_path = os.path.join(
  188. lora_dir, "new_embeddings.safetensors")
  189. new_embeddings_bin_file_path = os.path.join(lora_dir,
  190. "new_embeddings.bin")
  191. with open(lora_config_path) as f:
  192. config = json.load(f)
  193. if os.path.isfile(lora_tensor_path):
  194. tensors: Dict[str, torch.Tensor] = {}
  195. # Find unexpected modules.
  196. # Use safetensor key as a source of truth to find expected modules.
  197. # in peft if you have target_modules A, B, C and C does not exist
  198. # in the model it won’t error and model will be trained with A, B
  199. # loraified. C won’t exist in the safetensor but it will exist in
  200. # the target_modules of the adapter_config.json.
  201. unexpected_modules = []
  202. with safetensors.safe_open(lora_tensor_path,
  203. framework="pt") as f: # type: ignore
  204. for lora_module in f.keys(): # noqa
  205. module_name, is_lora_a = parse_fine_tuned_lora_name(
  206. lora_module, enable_lora_modules_to_save)
  207. part_name = module_name.split(".")[-1]
  208. is_expected_module_to_save = (is_lora_a is None) and (
  209. part_name in expected_modules_to_save)
  210. if (part_name not in expected_lora_modules
  211. ) and not is_expected_module_to_save:
  212. unexpected_modules.append(module_name)
  213. if unexpected_modules:
  214. raise ValueError(
  215. f"While loading {lora_dir}, expected"
  216. f" target modules in {expected_lora_modules}"
  217. f" but received {unexpected_modules}."
  218. f" Please verify that the loaded LoRA module is correct"
  219. )
  220. # Load tensors if there are only expected modules.
  221. for module in f.keys(): # noqa
  222. tensors[module] = f.get_tensor(module)
  223. elif os.path.isfile(lora_bin_file_path):
  224. # When a bin file is provided, we rely on config to find unexpected
  225. # modules.
  226. unexpected_modules = []
  227. target_modules = config["target_modules"]
  228. for module in target_modules:
  229. # Compatible with more modules,
  230. # such as:layers.11.self_attn.k_proj
  231. part_name = module.split(".")[-1]
  232. if part_name not in expected_lora_modules:
  233. unexpected_modules.append(module)
  234. # loaded lora's target modules must be a subset of
  235. # expected_lora_modules. It is not reliable. See
  236. # https://github.com/vllm-project/vllm/pull/5909. But there's no
  237. # other better mechanism.
  238. if unexpected_modules:
  239. print(unexpected_modules, "modules")
  240. raise ValueError(
  241. f"While loading {lora_dir}, expected"
  242. f" target modules in {expected_lora_modules}"
  243. f" but received {unexpected_modules}."
  244. f" Please verify that the loaded LoRA module is correct")
  245. tensors = torch.load(lora_bin_file_path, map_location=device)
  246. else:
  247. raise ValueError(f"{lora_dir} doesn't contain tensors")
  248. embeddings = None
  249. if os.path.isfile(new_embeddings_tensor_path):
  250. embeddings = safetensors.torch.load_file(
  251. new_embeddings_tensor_path)
  252. elif os.path.isfile(new_embeddings_bin_file_path):
  253. embeddings = torch.load(new_embeddings_bin_file_path,
  254. map_location=device)
  255. rank = config["r"]
  256. lora_alpha = config["lora_alpha"] * math.sqrt(rank) if config.get(
  257. "use_rslora", False) else config["lora_alpha"]
  258. context_length = config.get("context_length", None)
  259. scaling_factor = None
  260. if context_length:
  261. if max_position_embeddings is None:
  262. max_position_embeddings = context_length
  263. scaling_factor = float(
  264. math.ceil(context_length / max_position_embeddings))
  265. return cls.from_lora_tensors(
  266. lora_model_id=get_lora_id()
  267. if lora_model_id is None else lora_model_id,
  268. rank=rank,
  269. lora_alpha=lora_alpha,
  270. tensors=tensors,
  271. enable_lora_modules_to_save=enable_lora_modules_to_save,
  272. device=device,
  273. dtype=dtype,
  274. embeddings=embeddings,
  275. target_embedding_padding=target_embedding_padding,
  276. scaling_factor=scaling_factor,
  277. embedding_modules=embedding_modules,
  278. embedding_padding_modules=embedding_padding_modules,
  279. )
  280. class LoRAModelManager(AdapterModelManager):
  281. """A manager that manages multiple LoRA-fine-tuned models."""
  282. def __init__(
  283. self,
  284. model: SupportsLoRA,
  285. max_num_seqs: int,
  286. max_num_batched_tokens: int,
  287. vocab_size: int,
  288. lora_config: LoRAConfig,
  289. ):
  290. """Create a LoRAModelManager and adapter for a given model.
  291. Args:
  292. model: the model to be adapted.
  293. max_num_seqs: the maximum number of sequences model can run in a
  294. single batch.
  295. max_num_batched_tokens: the maximum number of tokens model can run
  296. in a single batch.
  297. vocab_size: the vocab size of the model.
  298. lora_config: the LoRA configuration.
  299. """
  300. self.lora_config = lora_config
  301. self.max_num_seqs = max_num_seqs
  302. assert self.capacity >= self.lora_slots
  303. self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
  304. self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
  305. self.vocab_size = vocab_size
  306. self.long_lora_context: Optional[LongContextLoRAContext] = None
  307. self.punica_wrapper = PunicaWrapper(max_num_batched_tokens,
  308. max_batches=self.max_num_seqs,
  309. device="cuda")
  310. # Scaling factor -> offset to the sin_cos_cache to it.
  311. # Used for long context lora.
  312. self.scaling_factor_to_offset: Dict[float, int] = {}
  313. super().__init__(model)
  314. if hasattr(self.model, "supported_lora_modules"):
  315. self.supported_lora_modules = copy.deepcopy(
  316. self.model.supported_lora_modules)
  317. if lora_config.long_lora_scaling_factors:
  318. # We need to replace rotary emb layer to do batch computation
  319. # for long lora.
  320. self.supported_lora_modules.append("rotary_emb")
  321. self.packed_modules_mapping = copy.deepcopy(
  322. self.model.packed_modules_mapping)
  323. self.packed_modules: Dict[str, List[str]] = {}
  324. self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
  325. # Dict instead of a Set for compatibility with LRUCache.
  326. self._last_mapping: Optional[LoRAMapping] = None
  327. self._create_lora_modules()
  328. self.model.lora_manager = self
  329. self.adapter_type = 'LoRa'
  330. @property
  331. def capacity(self) -> int:
  332. return self.lora_config.max_cpu_loras
  333. @property
  334. def lora_slots(self) -> int:
  335. return self.lora_config.max_loras
  336. @property
  337. def adapter_slots(self) -> int:
  338. return self.lora_slots
  339. def activate_adapter(
  340. self,
  341. lora_id: int,
  342. ) -> bool:
  343. """Move LoRA into a GPU buffer to be used in the forward pass."""
  344. if lora_id in self._active_adapters:
  345. return False
  346. first_free_slot = next(
  347. ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id)
  348. if lora_id is None), None)
  349. if first_free_slot is None:
  350. raise ValueError("No free lora slots")
  351. index, _ = first_free_slot
  352. self._active_adapters[lora_id] = None
  353. lora_model = self._registered_adapters[lora_id]
  354. logger.debug(f"Activating LoRA. int id: {lora_model.id}, "
  355. f"slot index: {index}")
  356. self.lora_index_to_id[index] = lora_model.id
  357. for module_name, module in self.modules.items():
  358. module_lora = lora_model.get_lora(module_name)
  359. if module_lora:
  360. module_lora.optimize()
  361. module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
  362. module_lora.embeddings_tensor)
  363. else:
  364. module.reset_lora(index)
  365. return True
  366. def _deactivate_adapter(self, lora_id: int):
  367. try:
  368. index = self.lora_index_to_id.index(lora_id)
  369. self.lora_index_to_id[index] = None
  370. except ValueError:
  371. pass
  372. def _set_long_lora_context(self, lora: LoRAModel):
  373. if self.long_lora_context is None:
  374. return
  375. if lora.scaling_factor is None:
  376. return
  377. if (lora.scaling_factor not in self.scaling_factor_to_offset):
  378. raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}"
  379. " has not been initialized.")
  380. offsets = self.scaling_factor_to_offset.get(lora.scaling_factor)
  381. if offsets:
  382. self.long_lora_context.offsets_by_lora_id[lora.id] = offsets
  383. def _add_adapter(self, lora: LoRAModel):
  384. self._create_merged_loras_inplace(lora)
  385. self._registered_adapters[lora.id] = lora
  386. self._set_long_lora_context(lora)
  387. def pin_adapter(self, lora_id: int) -> bool:
  388. """Pin a LoRAModel in the manager cache."""
  389. raise NotImplementedError(
  390. "Pinning is not supported in LoRAModelManager."
  391. "Use LRUCacheLoRAModelManager for pinning") # type: ignore
  392. def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
  393. # update lora states
  394. self.punica_wrapper.update_metadata(
  395. mapping,
  396. self.lora_index_to_id,
  397. self.lora_slots + 1,
  398. self.vocab_size,
  399. self.lora_config.lora_extra_vocab_size,
  400. self.long_lora_context,
  401. )
  402. def remove_all_adapters(self):
  403. """Remove all LoRAModels from the manager."""
  404. self._registered_adapters.clear()
  405. self.lora_index_to_id = [None] * self.lora_slots
  406. self._active_adapters.clear()
  407. def _create_lora_modules(self):
  408. for module_name, module in self.model.named_modules(
  409. remove_duplicate=False):
  410. if isinstance(module, PPMissingLayer):
  411. continue
  412. if not self._match_target_modules(module_name):
  413. continue
  414. parts = module_name.split(".")[-1]
  415. packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
  416. new_module = replace_submodule(
  417. self.model, module_name,
  418. from_layer(module, self.lora_slots, self.lora_config,
  419. packed_moduled_lst, self.model.config))
  420. # LinearScalingRotaryEmbeddingWithLora is used to handle
  421. # long context lora. Register relevant metadata.
  422. if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora):
  423. self.long_lora_context = LongContextLoRAContext(
  424. new_module.scaling_factors, new_module.rotary_dim)
  425. self.scaling_factor_to_offset = \
  426. new_module.scaling_factor_to_offset
  427. # (yard1): TODO make this more robust
  428. # replace lm_head by A*B lora if needed
  429. if (("lm_head" in module_name)
  430. and not self.lora_config.enable_lora_modules_to_save):
  431. logits_processor_module = self.model.get_submodule(
  432. "logits_processor")
  433. new_module = replace_submodule(
  434. self.model, "logits_processor",
  435. from_layer_logits_processor(logits_processor_module,
  436. module, self.lora_slots,
  437. self.lora_config,
  438. self.model.config))
  439. self.register_module(module_name, new_module)
  440. self._register_packed_modules(module_name)
  441. # All lora layers share the same punica_wrapper based on reference.
  442. new_module.set_mapping(self.punica_wrapper)
  443. def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
  444. assert isinstance(module, BaseLayerWithLoRA)
  445. self.modules[module_name] = module
  446. def create_dummy_lora(
  447. self,
  448. lora_id: int,
  449. rank: int,
  450. scaling_factor: Optional[float],
  451. embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
  452. """Create zero-initialized LoRAModel for warmup."""
  453. model = LoRAModel(lora_id, rank, {}, scaling_factor)
  454. for module_name, module in self.model.named_modules():
  455. if not self._match_target_modules(module_name) or not isinstance(
  456. module, BaseLayerWithLoRA) or isinstance(
  457. module, LinearScalingRotaryEmbeddingWithLora):
  458. continue
  459. parts = module_name.split(".")
  460. if module_name not in self.packed_modules:
  461. assert embedding_modules is not None
  462. if parts[-1] in embedding_modules:
  463. input_dim = (module.base_layer.org_vocab_size +
  464. self.lora_config.lora_extra_vocab_size if
  465. hasattr(module.base_layer, "org_vocab_size")
  466. else module.base_layer.weight.shape[1])
  467. output_dim = module.base_layer.embedding_dim if hasattr(
  468. module.base_layer,
  469. "embedding_dim") else module.base_layer.weight.shape[0]
  470. embeddings_tensor_dim = (module.base_layer.embedding_dim if
  471. hasattr(module.base_layer,
  472. "embedding_dim") else
  473. module.base_layer.weight.shape[1])
  474. rank_ = None if isinstance(module,
  475. ModulesToSaveWrapper) else rank
  476. lora = LoRALayerWeights.create_dummy_lora_weights(
  477. module_name,
  478. input_dim,
  479. output_dim,
  480. rank_,
  481. module.dtype,
  482. "cpu",
  483. embeddings_tensor_dim=embeddings_tensor_dim)
  484. else:
  485. lora = LoRALayerWeights.create_dummy_lora_weights(
  486. module_name,
  487. module.lora_a_stacked.shape[-1],
  488. module.lora_b_stacked.shape[-2],
  489. rank,
  490. module.lora_a_stacked.dtype,
  491. "cpu",
  492. )
  493. lora.optimize()
  494. else:
  495. parts = module_name.split(".")
  496. replacements = self.packed_modules_mapping[parts[-1]]
  497. subloras: List[Optional["LoRALayerWeights"]] = []
  498. for i, r in enumerate(replacements):
  499. lora = LoRALayerWeights.create_dummy_lora_weights(
  500. module_name + "." + r,
  501. module.lora_a_stacked[i].shape[-1],
  502. module.lora_b_stacked[i].shape[-2],
  503. rank,
  504. module.lora_a_stacked[i].dtype,
  505. "cpu",
  506. )
  507. lora.optimize()
  508. subloras.append(lora)
  509. lora = PackedLoRALayerWeights.pack(subloras)
  510. model.loras[module_name] = lora
  511. return model
  512. def _match_target_modules(self, module_name: str):
  513. return any(
  514. re.match(
  515. r".*\.{target_module}$".format(target_module=target_module),
  516. module_name) or target_module == module_name
  517. for target_module in self.supported_lora_modules)
  518. def _register_packed_modules(self, module_full_name: str) -> None:
  519. parts = module_full_name.split(".")
  520. module_name = parts[-1]
  521. replacements = self.packed_modules_mapping.get(module_name, [])
  522. # When replacements is less than or equal to 1, it indicates that this
  523. # module is not a packed module.
  524. if len(replacements) <= 1:
  525. return
  526. prefix = ".".join(parts[:-1])
  527. self.packed_modules[module_full_name] = [
  528. prefix + "." + r if prefix else r for r in replacements
  529. ]
  530. def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
  531. for module_name, new_module_names in self.packed_modules.items():
  532. replacement_loras: List[Optional[LoRALayerWeights]] = []
  533. has_replacement = False
  534. for r in new_module_names:
  535. lora = lora_model.get_lora(r)
  536. replacement_loras.append(lora)
  537. if lora:
  538. has_replacement = True
  539. if not has_replacement:
  540. continue
  541. for i in range(len(replacement_loras)):
  542. if replacement_loras[i]:
  543. continue
  544. replacement_loras[i] = None
  545. lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
  546. replacement_loras)
  547. def deactivate_adapter(self, adapter_id: int) -> bool:
  548. return deactivate_adapter(adapter_id, self._active_adapters,
  549. self._deactivate_adapter)
  550. def add_adapter(self, adapter: LoRAModel) -> bool:
  551. logger.debug(
  552. f"Adding lora. Model id: {adapter.id}, "
  553. f"int id: {adapter.id}, "
  554. f"scaling factor: {adapter.scaling_factor}")
  555. return add_adapter(adapter, self._registered_adapters, self.capacity,
  556. self._add_adapter)
  557. def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
  558. self._last_mapping = set_adapter_mapping(mapping, self._last_mapping,
  559. self._set_adapter_mapping)
  560. def remove_adapter(self, adapter_id: int) -> bool:
  561. return remove_adapter(adapter_id, self._registered_adapters,
  562. self.deactivate_adapter)
  563. def list_adapters(self) -> Dict[int, Any]:
  564. return list_adapters(self._registered_adapters)
  565. def get_adapter(self, adapter_id: int) -> Optional[Any]:
  566. return get_adapter(adapter_id, self._registered_adapters)
  567. class LoRALRUCache(AdapterLRUCache[LoRAModel]):
  568. def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int],
  569. bool]):
  570. super().__init__(capacity, deactivate_lora_fn)
  571. class LRUCacheLoRAModelManager(LoRAModelManager):
  572. """A model manager that manages multiple LoRAs with LRU cache."""
  573. def __init__(
  574. self,
  575. model: nn.Module,
  576. max_num_seqs: int,
  577. max_num_batched_tokens: int,
  578. vocab_size: int,
  579. lora_config: LoRAConfig,
  580. ):
  581. super().__init__(model, max_num_seqs, max_num_batched_tokens,
  582. vocab_size, lora_config)
  583. self._registered_adapters: LoRALRUCache = LoRALRUCache(
  584. self.capacity, self.deactivate_adapter)
  585. self._active_adapters: LoRALRUCache = LoRALRUCache(
  586. self.lora_slots, self._deactivate_adapter)
  587. def list_adapters(self) -> Dict[int, LoRAModel]:
  588. """List all registered LoRAModels."""
  589. return dict(self._registered_adapters.cache)
  590. def add_adapter(self, lora: LoRAModel) -> bool:
  591. """Add a LoRAModel to the manager."""
  592. logger.debug(
  593. f"Adding lora. Model id: {lora.id}, "
  594. f"int id: {lora.id}, "
  595. f"scaling factor: {lora.scaling_factor}")
  596. if lora.id not in self._registered_adapters:
  597. self._add_adapter(lora)
  598. was_added = True
  599. else:
  600. # We always touch to update the LRU cache order
  601. self._registered_adapters.touch(lora.id)
  602. was_added = False
  603. return was_added
  604. def activate_adapter(
  605. self,
  606. lora_id: int,
  607. ) -> bool:
  608. if lora_id not in self._active_adapters and len(
  609. self._active_adapters) >= self.lora_slots:
  610. self._active_adapters.remove_oldest()
  611. result = super().activate_adapter(lora_id)
  612. # We always touch to update the LRU cache order
  613. self._active_adapters.touch(lora_id)
  614. return result
  615. def remove_oldest_adapter(self) -> bool:
  616. if len(self._registered_adapters) > 0:
  617. self._registered_adapters.remove_oldest()
  618. return True
  619. return False
  620. def pin_adapter(self, lora_id: int) -> bool:
  621. """Pin a LoRAModel in the manager cache."""
  622. self._pin_lora_in_cpu_cache(lora_id)
  623. self._pin_lora_in_gpu_cache(lora_id)
  624. return True
  625. def _pin_lora_in_cpu_cache(self, lora_id: int):
  626. try:
  627. self._registered_adapters.pin(lora_id)
  628. except ValueError as err:
  629. raise ValueError("Pinning failed. "
  630. f"LoRA {lora_id} is not registered.") from err
  631. def _pin_lora_in_gpu_cache(self, lora_id: int):
  632. if lora_id not in self._active_adapters:
  633. # move lora to gpu if not already active
  634. self.activate_adapter(lora_id)
  635. self._active_adapters.pin(lora_id)
  636. def create_lora_manager(
  637. model: nn.Module,
  638. max_num_seqs: int,
  639. max_num_batched_tokens: int,
  640. vocab_size: int,
  641. lora_config: LoRAConfig,
  642. lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
  643. **kwargs) -> LoRAModelManager:
  644. """Create a LoRA adapter for a given model."""
  645. if not hasattr(model, "supported_lora_modules"):
  646. raise ValueError(f"Model {type(model)} is not supported for LoRA.")
  647. lora_manager = lora_manager_cls(
  648. model=model,
  649. max_num_seqs=max_num_seqs,
  650. max_num_batched_tokens=max_num_batched_tokens,
  651. vocab_size=vocab_size,
  652. lora_config=lora_config,
  653. **kwargs)
  654. return lora_manager