# models.py
import copy
import json
import math
import os
import re
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Type

import safetensors.torch
import torch
from loguru import logger
from torch import nn

from aphrodite.adapter_commons.models import (AdapterLRUCache, AdapterModel,
                                               AdapterModelManager)
from aphrodite.adapter_commons.utils import (add_adapter, deactivate_adapter,
                                             get_adapter, list_adapters,
                                             remove_adapter,
                                             set_adapter_mapping)
from aphrodite.common.config import LoRAConfig
from aphrodite.common.utils import is_pin_memory_available
from aphrodite.lora.layers import (BaseLayerWithLoRA,
                                   LinearScalingRotaryEmbeddingWithLora,
                                   LoRAMapping)
from aphrodite.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from aphrodite.lora.punica import PunicaWrapper
from aphrodite.lora.utils import (from_layer, from_layer_logits_processor,
                                  parse_fine_tuned_lora_name,
                                  replace_submodule)
from aphrodite.modeling.models.interfaces import SupportsLoRA
from aphrodite.modeling.models.utils import PPMissingLayer

_GLOBAL_LORA_ID = 0

@dataclass
class LongContextLoRAContext:
    """Context for lora adapters that support long context."""
    # The scaling factors to support long context lora fine tuned models.
    scaling_factors: List[float]
    # Dimension to apply rotary embedding.
    rot_dim: int
    # Offsets to the sin_cos_cache for each lora_id loaded.
    # This value is dynamically modified.
    offsets_by_lora_id: Dict[int, int] = field(default_factory=dict)


def get_lora_id():
    global _GLOBAL_LORA_ID
    _GLOBAL_LORA_ID += 1
    return _GLOBAL_LORA_ID

class LoRAModel(AdapterModel):
    """A LoRA fine-tuned model."""

    def __init__(
        self,
        lora_model_id: int,
        rank: int,
        loras: Dict[str, LoRALayerWeights],
        scaling_factor: Optional[float] = None,
    ) -> None:
        """
        Args:
            lora_model_id: The integer id for the lora model.
            rank: lora rank.
            loras: module name -> weights for lora-replaced layers.
            scaling_factor: Scaling factor to support long context lora model.
                None if the lora is not tuned for long context support.
        """
        self.id = lora_model_id
        # Scaling factor for long context lora model. None if it is not
        # fine tuned for the long context.
        self.scaling_factor = scaling_factor
        assert (lora_model_id >
                0), f"a valid lora id should be greater than 0, got {self.id}"
        self.rank = rank
        self.loras: Dict[str, LoRALayerWeights] = loras

    def clone(self, lora_model_id: int) -> "LoRAModel":
        """Return a copy of the object with different ids.

        Will share the underlying tensors."""
        return self.__class__(
            lora_model_id,
            rank=self.rank,
            loras=self.loras.copy(),
        )

    @property
    def extra_vocab_size(self) -> int:
        return max(lora.extra_vocab_size
                   for lora in self.loras.values()) if self.loras else 0

    def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
        """Get LoRA for a given module by name"""
        return self.loras.get(module_name, None)

    # (yard1): TODO see if we can derive target_embedding_padding automatically
    @classmethod
    def from_lora_tensors(
        cls,
        lora_model_id: int,
        rank: int,
        lora_alpha: int,
        tensors: Dict[str, torch.Tensor],
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        embeddings: Optional[Dict[str, torch.Tensor]] = None,
        target_embedding_padding: Optional[int] = None,
        scaling_factor: Optional[float] = None,
        embedding_modules: Optional[Dict[str, str]] = None,
        embedding_padding_modules: Optional[List[str]] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors."""
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        loras: Dict[str, LoRALayerWeights] = {}
        skipped_count = 0
        for tensor_name, tensor in tensors.items():
            parsed = parse_fine_tuned_lora_name(tensor_name)
            if parsed is None:
                skipped_count += 1
                print(f"\rSkipping {skipped_count} "
                      "unsupported LoRA weight tensors... Latest: "
                      f"{tensor_name}", end="", flush=True)
                continue
            module_name, is_lora_a = parsed
            if module_name not in loras:
                lora_embeddings_tensor = None
                if embeddings:
                    assert embedding_modules is not None
                    embeddings_module = next(
                        (k for k in embedding_modules if k in module_name),
                        None)
                    if embeddings_module:
                        lora_embeddings_tensor = embeddings[
                            embedding_modules[embeddings_module]].to(
                                device=device, dtype=dtype)
                        if pin_memory:
                            lora_embeddings_tensor = (
                                lora_embeddings_tensor.pin_memory())
                loras[module_name] = LoRALayerWeights(module_name, rank,
                                                      lora_alpha, None, None,
                                                      lora_embeddings_tensor)
            if is_lora_a:
                loras[module_name].lora_a = tensor.to(device=device,
                                                      dtype=dtype).t()
                if pin_memory:
                    loras[module_name].lora_a = loras[
                        module_name].lora_a.pin_memory()
            else:
                loras[module_name].lora_b = tensor.to(device=device,
                                                      dtype=dtype).t()
                assert embedding_padding_modules is not None
                if any(name in module_name
                       for name in embedding_padding_modules
                       ) and target_embedding_padding is not None:
                    lora_b = loras[module_name].lora_b
                    assert target_embedding_padding >= lora_b.shape[1]
                    addition = target_embedding_padding - lora_b.shape[1]
                    loras[module_name].lora_b = torch.nn.functional.pad(
                        lora_b, (0, addition))
                if pin_memory:
                    loras[module_name].lora_b = loras[
                        module_name].lora_b.pin_memory()

        for lora in loras.values():
            lora.optimize()
        if skipped_count > 0:
            # Print final count and move to next line
            print(f"\rSkipped {skipped_count} unsupported LoRA weight tensors.")
        return cls(lora_model_id, rank, loras, scaling_factor=scaling_factor)
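
    # Note: `from_lora_tensors` relies on `parse_fine_tuned_lora_name` (from
    # aphrodite.lora.utils) to map checkpoint keys to module names. A minimal
    # sketch of that mapping, assuming the usual PEFT naming scheme (the
    # exact format is defined by that helper, not here):
    #
    #     "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"
    #         -> ("model.layers.0.self_attn.q_proj", is_lora_a=True)
    #     "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight"
    #         -> ("model.layers.0.self_attn.q_proj", is_lora_a=False)
    #
    # Keys that do not parse (parsed is None above) are counted and skipped.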

    @classmethod
    def from_local_checkpoint(
        cls,
        lora_dir: str,
        expected_lora_modules: List[str],
        *,
        max_position_embeddings: Optional[int] = None,
        lora_model_id: Optional[int] = None,
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        target_embedding_padding: Optional[int] = None,
        embedding_modules: Optional[Dict[str, str]] = None,
        embedding_padding_modules: Optional[List[str]] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint.

        Args:
            lora_dir: The local path that has lora data.
            expected_lora_modules: Name of modules that are expected to be
                replaced by lora.
            max_position_embeddings: Max position embedding length. Used to
                scale the largest context length. If None, the lora model's
                context length is not scaled.
            lora_model_id: Lora model id. If not given, automatically set by
                a global counter.
            device: Device where the lora model is loaded.
            dtype: dtype of the lora model weights.

        Returns:
            Loaded LoRA Model.
        """
        lora_config_path = os.path.join(lora_dir, "adapter_config.json")
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
        new_embeddings_tensor_path = os.path.join(
            lora_dir, "new_embeddings.safetensors")
        new_embeddings_bin_file_path = os.path.join(lora_dir,
                                                    "new_embeddings.bin")
        with open(lora_config_path) as f:
            config = json.load(f)
        if os.path.isfile(lora_tensor_path):
            tensors: Dict[str, torch.Tensor] = {}
            # Find unexpected modules.
            # Use safetensor keys as the source of truth to find expected
            # modules. In peft, if you have target_modules A, B, C and C does
            # not exist in the model, it won't error and the model will be
            # trained with A, B loraified. C won't exist in the safetensor but
            # it will exist in the target_modules of the adapter_config.json.
            unexpected_modules = []
            skipped_count = 0
            with safetensors.safe_open(lora_tensor_path,
                                       framework="pt") as f:  # type: ignore
                for lora_module in f.keys():  # noqa
                    parsed = parse_fine_tuned_lora_name(lora_module)
                    if parsed is None:
                        skipped_count += 1
                        print(f"\rSkipping {skipped_count} unsupported LoRA "
                              "weight tensors... Latest: "
                              f"{lora_module}", end="", flush=True)
                        continue
                    module_name, _ = parsed
                    part_name = module_name.split(".")[-1]
                    if part_name not in expected_lora_modules:
                        unexpected_modules.append(module_name)
                if skipped_count > 0:
                    print(f"\rSkipped {skipped_count} unsupported LoRA "
                          "weight tensors.")
                if unexpected_modules:
                    raise ValueError(
                        f"While loading {lora_dir}, expected"
                        f" target modules in {expected_lora_modules}"
                        f" but received {unexpected_modules}."
                        f" Please verify that the loaded LoRA module is correct"
                    )
                # Load tensors if there are only expected modules.
                for module in f.keys():  # noqa
                    tensors[module] = f.get_tensor(module)
        elif os.path.isfile(lora_bin_file_path):
            # When a bin file is provided, we rely on config to find unexpected
            # modules.
            unexpected_modules = []
            target_modules = config["target_modules"]
            skipped_count = 0
            for module in target_modules:
                # Compatible with more modules,
                # such as: layers.11.self_attn.k_proj
                part_name = module.split(".")[-1]
                if part_name not in expected_lora_modules:
                    skipped_count += 1
                    print(f"\rSkipping {skipped_count} unexpected modules... "
                          f"Latest: {module}", end="", flush=True)
                    unexpected_modules.append(module)
            if skipped_count > 0:
                print(f"\rSkipped {skipped_count} unexpected modules.")
            # loaded lora's target modules must be a subset of
            # expected_lora_modules. It is not reliable. See
            # https://github.com/vllm-project/vllm/pull/5909. But there's no
            # other better mechanism.
            if unexpected_modules:
                raise ValueError(
                    f"While loading {lora_dir}, expected"
                    f" target modules in {expected_lora_modules}"
                    f" but received {unexpected_modules}."
                    f" Please verify that the loaded LoRA module is correct")
            tensors = torch.load(lora_bin_file_path, map_location=device)
        else:
            raise ValueError(f"{lora_dir} doesn't contain tensors")

        embeddings = None
        if os.path.isfile(new_embeddings_tensor_path):
            embeddings = safetensors.torch.load_file(
                new_embeddings_tensor_path)
        elif os.path.isfile(new_embeddings_bin_file_path):
            embeddings = torch.load(new_embeddings_bin_file_path,
                                    map_location=device)

        rank = config["r"]
        lora_alpha = config["lora_alpha"] * math.sqrt(rank) if config.get(
            "use_rslora", False) else config["lora_alpha"]
        context_length = config.get("context_length", None)
        scaling_factor = None
        if context_length:
            if max_position_embeddings is None:
                max_position_embeddings = context_length
            scaling_factor = float(
                math.ceil(context_length / max_position_embeddings))
        return cls.from_lora_tensors(
            lora_model_id=get_lora_id()
            if lora_model_id is None else lora_model_id,
            rank=rank,
            lora_alpha=lora_alpha,
            tensors=tensors,
            device=device,
            dtype=dtype,
            embeddings=embeddings,
            target_embedding_padding=target_embedding_padding,
            scaling_factor=scaling_factor,
            embedding_modules=embedding_modules,
            embedding_padding_modules=embedding_padding_modules,
        )
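
# `from_local_checkpoint` above expects the usual PEFT adapter layout on disk.
# Everything listed here is read by that method; nothing else is required:
#
#     adapter_dir/
#         adapter_config.json         (must contain "r", "lora_alpha",
#                                      "target_modules"; "use_rslora" and
#                                      "context_length" are honored if present)
#         adapter_model.safetensors   (or adapter_model.bin)
#         new_embeddings.safetensors  (optional; or new_embeddings.bin)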


class LoRAModelManager(AdapterModelManager):
    """A manager that manages multiple LoRA-fine-tuned models."""

    def __init__(
        self,
        model: SupportsLoRA,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences model can run in a
                single batch.
            max_num_batched_tokens: the maximum number of tokens model can run
                in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
        """
        self.lora_config = lora_config
        self.max_num_seqs = max_num_seqs
        assert self.capacity >= self.lora_slots
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.long_lora_context: Optional[LongContextLoRAContext] = None
        self.punica_wrapper = PunicaWrapper(max_num_batched_tokens,
                                            max_batches=self.max_num_seqs,
                                            device="cuda")
        # Scaling factor -> offset to the sin_cos_cache to it.
        # Used for long context lora.
        self.scaling_factor_to_offset: Dict[float, int] = {}
        super().__init__(model)
        if hasattr(self.model, "supported_lora_modules"):
            self.supported_lora_modules = copy.deepcopy(
                self.model.supported_lora_modules)
            if lora_config.long_lora_scaling_factors:
                # We need to replace rotary emb layer to do batch computation
                # for long lora.
                self.supported_lora_modules.append("rotary_emb")
            self.packed_modules_mapping = copy.deepcopy(
                self.model.packed_modules_mapping)
        self.packed_modules: Dict[str, List[str]] = {}
        self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
        # Dict instead of a Set for compatibility with LRUCache.
        self._last_mapping: Optional[LoRAMapping] = None
        self._create_lora_modules()
        self.model.lora_manager = self
        self.adapter_type = 'LoRa'

    @property
    def capacity(self) -> int:
        return self.lora_config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        return self.lora_config.max_loras

    @property
    def adapter_slots(self) -> int:
        return self.lora_slots

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass."""
        if lora_id in self._active_adapters:
            return False
        first_free_slot = next(
            ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id)
             if lora_id is None), None)
        if first_free_slot is None:
            raise ValueError("No free lora slots")
        index, _ = first_free_slot
        self._active_adapters[lora_id] = None
        lora_model = self._registered_adapters[lora_id]
        logger.debug(f"Activating LoRA. int id: {lora_model.id}, "
                     f"slot index: {index}")
        self.lora_index_to_id[index] = lora_model.id
        for module_name, module in self.modules.items():
            module_lora = lora_model.get_lora(module_name)
            if module_lora:
                module_lora.optimize()
                module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
                                module_lora.embeddings_tensor)
            else:
                module.reset_lora(index)
        return True
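
    # Slot bookkeeping sketch: with max_loras == 2, `lora_index_to_id` starts
    # as [None, None]; activating lora id 7 fills slot 0 -> [7, None], and
    # activating lora id 3 then fills slot 1 -> [7, 3]. Deactivation simply
    # puts None back into the slot (see `_deactivate_adapter` below).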

    def _deactivate_adapter(self, lora_id: int):
        try:
            index = self.lora_index_to_id.index(lora_id)
            self.lora_index_to_id[index] = None
        except ValueError:
            pass

    def _set_long_lora_context(self, lora: LoRAModel):
        if self.long_lora_context is None:
            return

        if lora.scaling_factor is None:
            return

        if (lora.scaling_factor not in self.scaling_factor_to_offset):
            raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}"
                             " has not been initialized.")

        offsets = self.scaling_factor_to_offset.get(lora.scaling_factor)
        if offsets:
            self.long_lora_context.offsets_by_lora_id[lora.id] = offsets

    def _add_adapter(self, lora: LoRAModel):
        self._create_merged_loras_inplace(lora)
        self._registered_adapters[lora.id] = lora
        self._set_long_lora_context(lora)

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        raise NotImplementedError(
            "Pinning is not supported in LoRAModelManager. "
            "Use LRUCacheLoRAModelManager for pinning")  # type: ignore

    def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        # update lora states
        self.punica_wrapper.update_metadata(
            mapping,
            self.lora_index_to_id,
            self.lora_slots + 1,
            self.vocab_size,
            self.lora_config.lora_extra_vocab_size,
            self.long_lora_context,
        )

    def remove_all_adapters(self):
        """Remove all LoRAModels from the manager."""
        self._registered_adapters.clear()
        self.lora_index_to_id = [None] * self.lora_slots
        self._active_adapters.clear()

    def _create_lora_modules(self):
        for module_name, module in self.model.named_modules(
                remove_duplicate=False):
            if isinstance(module, PPMissingLayer):
                continue
            if not self._match_target_modules(module_name):
                continue
            parts = module_name.split(".")[-1]
            packed_modules_lst = self.packed_modules_mapping.get(parts, [])
            new_module = replace_submodule(
                self.model, module_name,
                from_layer(module, self.lora_slots, self.lora_config,
                           packed_modules_lst, self.model.config))
            # LinearScalingRotaryEmbeddingWithLora is used to handle
            # long context lora. Register relevant metadata.
            if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora):
                self.long_lora_context = LongContextLoRAContext(
                    new_module.scaling_factors, new_module.rotary_dim)
                self.scaling_factor_to_offset = \
                    new_module.scaling_factor_to_offset
            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                logits_processor_module = self.model.get_submodule(
                    "logits_processor")
                new_module = replace_submodule(
                    self.model, "logits_processor",
                    from_layer_logits_processor(logits_processor_module,
                                                module, self.lora_slots,
                                                self.lora_config,
                                                self.model.config))
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            # All lora layers share the same punica_wrapper based on reference.
            new_module.set_mapping(self.punica_wrapper)

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        assert isinstance(module, BaseLayerWithLoRA)
        self.modules[module_name] = module

    def create_dummy_lora(
            self,
            lora_id: int,
            rank: int,
            scaling_factor: Optional[float],
            embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup."""
        model = LoRAModel(lora_id, rank, {}, scaling_factor)
        for module_name, module in self.model.named_modules():
            if not self._match_target_modules(module_name) or not isinstance(
                    module, BaseLayerWithLoRA) or isinstance(
                        module, LinearScalingRotaryEmbeddingWithLora):
                continue
            parts = module_name.split(".")
            if module_name not in self.packed_modules:
                assert embedding_modules is not None
                if parts[-1] in embedding_modules:
                    input_dim = (module.base_layer.org_vocab_size +
                                 self.lora_config.lora_extra_vocab_size if
                                 hasattr(module.base_layer, "org_vocab_size")
                                 else module.base_layer.weight.shape[1])
                    output_dim = module.base_layer.embedding_dim if hasattr(
                        module.base_layer,
                        "embedding_dim") else module.base_layer.weight.shape[0]
                    embeddings_tensor_dim = (module.base_layer.embedding_dim if
                                             hasattr(module.base_layer,
                                                     "embedding_dim") else
                                             module.base_layer.weight.shape[1])
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        input_dim,
                        output_dim,
                        rank,
                        module.lora_a_stacked.dtype,
                        "cpu",
                        embeddings_tensor_dim=embeddings_tensor_dim)
                else:
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.lora_a_stacked.shape[-1],
                        module.lora_b_stacked.shape[-2],
                        rank,
                        module.lora_a_stacked.dtype,
                        "cpu",
                    )
                lora.optimize()
            else:
                parts = module_name.split(".")
                replacements = self.packed_modules_mapping[parts[-1]]
                subloras: List[Optional["LoRALayerWeights"]] = []
                for i, r in enumerate(replacements):
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
                    )
                    lora.optimize()
                    subloras.append(lora)
                lora = PackedLoRALayerWeights.pack(subloras)
            model.loras[module_name] = lora
        return model

    def _match_target_modules(self, module_name: str):
        return any(
            re.match(
                r".*\.{target_module}$".format(target_module=target_module),
                module_name) or target_module == module_name
            for target_module in self.supported_lora_modules)
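
    # Matching example: with "q_proj" in `supported_lora_modules`, the regex
    # r".*\.q_proj$" matches "model.layers.0.self_attn.q_proj"; a top-level
    # module whose full name equals the target (e.g. "lm_head") is accepted
    # via the exact-equality branch.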

    def _register_packed_modules(self, module_full_name: str) -> None:
        parts = module_full_name.split(".")
        module_name = parts[-1]
        replacements = self.packed_modules_mapping.get(module_name, [])
        # When replacements is less than or equal to 1, it indicates that this
        # module is not a packed module.
        if len(replacements) <= 1:
            return
        prefix = ".".join(parts[:-1])
        self.packed_modules[module_full_name] = [
            prefix + "." + r if prefix else r for r in replacements
        ]
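
    # Packing example (assuming a Llama-style `packed_modules_mapping` that
    # maps "qkv_proj" to ["q_proj", "k_proj", "v_proj"]): registering
    # "model.layers.0.self_attn.qkv_proj" records
    #     packed_modules["model.layers.0.self_attn.qkv_proj"] =
    #         ["model.layers.0.self_attn.q_proj",
    #          "model.layers.0.self_attn.k_proj",
    #          "model.layers.0.self_attn.v_proj"]
    # which `_create_merged_loras_inplace` below uses to pack the per-module
    # LoRA weights into a single PackedLoRALayerWeights.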

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        for module_name, new_module_names in self.packed_modules.items():
            replacement_loras: List[Optional[LoRALayerWeights]] = []
            has_replacement = False
            for r in new_module_names:
                lora = lora_model.get_lora(r)
                replacement_loras.append(lora)
                if lora:
                    has_replacement = True
            if not has_replacement:
                continue
            for i in range(len(replacement_loras)):
                if replacement_loras[i]:
                    continue
                replacement_loras[i] = None
            lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                replacement_loras)

    def deactivate_adapter(self, adapter_id: int) -> bool:
        return deactivate_adapter(adapter_id, self._active_adapters,
                                  self._deactivate_adapter)

    def add_adapter(self, adapter: LoRAModel) -> bool:
        logger.debug(
            f"Adding lora. Model id: {adapter.id}, "
            f"int id: {adapter.id}, "
            f"scaling factor: {adapter.scaling_factor}")
        return add_adapter(adapter, self._registered_adapters, self.capacity,
                           self._add_adapter)

    def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        self._last_mapping = set_adapter_mapping(mapping, self._last_mapping,
                                                 self._set_adapter_mapping)

    def remove_adapter(self, adapter_id: int) -> bool:
        return remove_adapter(adapter_id, self._registered_adapters,
                              self.deactivate_adapter)

    def list_adapters(self) -> Dict[int, Any]:
        return list_adapters(self._registered_adapters)

    def get_adapter(self, adapter_id: int) -> Optional[Any]:
        return get_adapter(adapter_id, self._registered_adapters)


class LoRALRUCache(AdapterLRUCache[LoRAModel]):

    def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int],
                                                                    bool]):
        super().__init__(capacity, deactivate_lora_fn)


class LRUCacheLoRAModelManager(LoRAModelManager):
    """A model manager that manages multiple LoRAs with LRU cache."""

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
    ):
        super().__init__(model, max_num_seqs, max_num_batched_tokens,
                         vocab_size, lora_config)
        self._registered_adapters: LoRALRUCache = LoRALRUCache(
            self.capacity, self.deactivate_adapter)
        self._active_adapters: LoRALRUCache = LoRALRUCache(
            self.lora_slots, self._deactivate_adapter)

    def list_adapters(self) -> Dict[int, LoRAModel]:
        """List all registered LoRAModels."""
        return dict(self._registered_adapters.cache)

    def add_adapter(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager."""
        logger.debug(
            f"Adding lora. Model id: {lora.id}, "
            f"int id: {lora.id}, "
            f"scaling factor: {lora.scaling_factor}")
        if lora.id not in self._registered_adapters:
            self._add_adapter(lora)
            was_added = True
        else:
            # We always touch to update the LRU cache order
            self._registered_adapters.touch(lora.id)
            was_added = False
        return was_added

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        if lora_id not in self._active_adapters and len(
                self._active_adapters) >= self.lora_slots:
            self._active_adapters.remove_oldest()
        result = super().activate_adapter(lora_id)
        # We always touch to update the LRU cache order
        self._active_adapters.touch(lora_id)
        return result

    def remove_oldest_adapter(self) -> bool:
        if len(self._registered_adapters) > 0:
            self._registered_adapters.remove_oldest()
            return True
        return False

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        self._pin_lora_in_cpu_cache(lora_id)
        self._pin_lora_in_gpu_cache(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int):
        try:
            self._registered_adapters.pin(lora_id)
        except ValueError as err:
            raise ValueError("Pinning failed. "
                             f"LoRA {lora_id} is not registered.") from err

    def _pin_lora_in_gpu_cache(self, lora_id: int):
        if lora_id not in self._active_adapters:
            # move lora to gpu if not already active
            self.activate_adapter(lora_id)
        self._active_adapters.pin(lora_id)


def create_lora_manager(
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
        **kwargs) -> LoRAModelManager:
    """Create a LoRA adapter for a given model."""
    if not hasattr(model, "supported_lora_modules"):
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    lora_manager = lora_manager_cls(
        model=model,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        vocab_size=vocab_size,
        lora_config=lora_config,
        **kwargs)
    return lora_manager
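
# A minimal end-to-end usage sketch, kept as a comment. Assumptions: a model
# exposing `supported_lora_modules` and `packed_modules_mapping`, an adapter
# checkpoint directory, and LoRAConfig keyword arguments as named below (the
# path, sizes, and module names are illustrative, not prescriptive):
#
#     manager = create_lora_manager(
#         model=model,
#         max_num_seqs=256,
#         max_num_batched_tokens=8192,
#         vocab_size=32000,
#         lora_config=lora_config,
#         lora_manager_cls=LRUCacheLoRAModelManager,
#     )
#     lora = LoRAModel.from_local_checkpoint(
#         "/path/to/adapter_dir",
#         expected_lora_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
#         embedding_modules={},
#         embedding_padding_modules=[],
#     )
#     manager.add_adapter(lora)
#     manager.activate_adapter(lora.id)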