models.py 34 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798
  1. import copy
  2. import json
  3. import math
  4. import os
  5. import re
  6. from dataclasses import dataclass, field
  7. from typing import Callable, Dict, List, Optional, Tuple, Type, Union
  8. import safetensors.torch
  9. import torch
  10. from loguru import logger
  11. from torch import nn
  12. from aphrodite.common.config import LoRAConfig
  13. from aphrodite.common.utils import LRUCache, is_pin_memory_available
  14. from aphrodite.lora.layers import (BaseLayerWithLoRA,
  15. LinearScalingRotaryEmbeddingWithLora,
  16. LoRAMapping)
  17. from aphrodite.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
  18. from aphrodite.lora.utils import (from_layer, from_layer_logits_processor,
  19. parse_fine_tuned_lora_name,
  20. replace_submodule)
  21. _GLOBAL_LORA_ID = 0
  22. @dataclass
  23. class LongContextLoRAContext:
  24. """Context for lora adapters that support long context."""
  25. # The scaling factors to support long context lora fine tuned models.
  26. scaling_factors: List[float]
  27. # dimension to apply rotary embedding.
  28. rot_dim: int
  29. # offsets to the sin_cos_cache for each lora_id loaded.
  30. # This value is dynamically modified.
  31. offsets_by_lora_id: Dict[int, int] = field(default_factory=dict)
  32. def convert_mapping(
  33. mapping: LoRAMapping,
  34. lora_index_to_id: List[Optional[int]],
  35. max_loras: int,
  36. vocab_size: int,
  37. extra_vocab_size: int,
  38. long_lora_context: Optional[LongContextLoRAContext] = None,
  39. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
  40. Optional[torch.Tensor], List[int]]:
  41. """Converts LoRAMapping to index tensors.
  42. Args:
  43. mapping: LoRAMapping mapping rows in a batch to LoRA ids.
  44. lora_index_to_id: List mapping LoRA ids to LoRA indices.
  45. max_loras: Maximum number of LoRAs.
  46. vocab_size: Model vocab size.
  47. extra_vocab_size: Extra vocab size each LoRA can have.
  48. long_lora_context: Passed if there are long context lora in a batch.
  49. Returns:
  50. A tuple of tensors:
  51. base_indices: Tensor of shape [batch_size] mapping batch rows to
  52. LoRA indices.
  53. sampler_indices: Tensor of shape [batch_size] mapping requests to
  54. LoRA indices for sampler. For generation, this will be the
  55. same as base_indicies. For prefill, this will map requests
  56. to LoRA indices.
  57. sampler_indices_padded: Tensor of shape [batch_size] mapping
  58. requests to LoRA indices for sampler with padding.
  59. Same as sampler_indicies, but -1 is replaced with
  60. max_loras.
  61. embeddings_indices: Tensor of shape [2, batch_size] mapping
  62. requests to embedding indices. First row is for embeddings
  63. added by the LoRAs, second row is for the LoRA.lora_a
  64. embeddings.
  65. long_lora_indices: Tensor of shape [batch_size] mapping
  66. requests to RoPE offsets and rot dims for long LoRAs.
  67. None if long context lora doesn't exist.
  68. indices_len: List of lengths of the above tensors.
  69. Used to index into each tensor. It contains length for
  70. (base_indices, sampler_indices, sampler_indices_padded,
  71. embeddings_indices, long_lora_indices). If long_lora doesn't
  72. exist, it only contains first 4 entries.
  73. """
  74. index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
  75. embedding_indices = index_mapping_indices.copy()
  76. lora_indices = index_mapping_indices.copy()
  77. long_lora_offsets: Optional[torch.Tensor] = None
  78. if long_lora_context:
  79. long_lora_offsets = torch.zeros(len(index_mapping_indices),
  80. device="cuda",
  81. dtype=torch.long)
  82. prompt_mapping: List[int] = [
  83. lora_index_to_id.index(x) if x > 0 else -1
  84. for x in mapping.prompt_mapping
  85. ]
  86. lora_idx = None
  87. for i in range(len(index_mapping_indices)):
  88. # TODO index can be slow. optimize
  89. lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
  90. if index_mapping_indices[i] > 0 else -1)
  91. embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
  92. lora_indices[i] = lora_idx
  93. if long_lora_context:
  94. assert long_lora_offsets is not None
  95. lora_offset: int = long_lora_context.offsets_by_lora_id.get(
  96. index_mapping_indices[i], 0)
  97. long_lora_offsets[i] = lora_offset
  98. # SANG-TODO
  99. # index_mapping_indices[i] = i
  100. indices_list: List[Union[List[int], torch.Tensor]] = [
  101. index_mapping_indices, lora_indices, embedding_indices
  102. ]
  103. if long_lora_context:
  104. assert long_lora_offsets is not None
  105. indices_list.append(long_lora_offsets)
  106. indices = torch.tensor(indices_list, dtype=torch.long, device="cuda")
  107. prompt_mapping_tensor = torch.tensor(prompt_mapping,
  108. device="cuda",
  109. dtype=torch.long)
  110. embeddings_indices = torch.stack([
  111. indices[2] * extra_vocab_size,
  112. indices[2] * (vocab_size + extra_vocab_size)
  113. ])
  114. embeddings_indices[embeddings_indices == -1] = max_loras - 1
  115. base_indices = indices[1]
  116. sampler_indices = prompt_mapping_tensor
  117. sampler_indices_padded = sampler_indices.clone()
  118. sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
  119. sampler_indices_padded = (
  120. torch.arange(
  121. 0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
  122. (sampler_indices_padded * len(sampler_indices_padded)))
  123. long_lora_indices = None
  124. long_lora_indices_len: Optional[int] = None
  125. if long_lora_context:
  126. long_lora_indices = indices[3]
  127. long_lora_indices_len = long_lora_indices.shape[-1]
  128. # Contain length of indices tensors. Used to index into each tensor.
  129. indices_len = [
  130. base_indices.shape[-1], sampler_indices.shape[-1],
  131. sampler_indices_padded.shape[-1], embeddings_indices.shape[-1]
  132. ]
  133. if long_lora_indices_len is not None:
  134. indices_len.append(long_lora_indices_len)
  135. return (base_indices, sampler_indices, sampler_indices_padded,
  136. embeddings_indices, long_lora_indices, indices_len)
  137. def get_lora_id():
  138. global _GLOBAL_LORA_ID
  139. _GLOBAL_LORA_ID += 1
  140. return _GLOBAL_LORA_ID
  141. class LoRAModel:
  142. """A LoRA fine-tuned model."""
  143. def __init__(
  144. self,
  145. lora_model_id: int,
  146. rank: int,
  147. loras: Dict[str, LoRALayerWeights],
  148. scaling_factor: Optional[float] = None,
  149. ) -> None:
  150. """
  151. Args:
  152. lora_model_id: The integer id for the lora model.
  153. rank: lora rank.
  154. loras: module name -> weights for lora-replaced layers.
  155. scaling_factor: Scaling factor to support long context lora model.
  156. None if the lora is not tuned for long context support.
  157. """
  158. self.id = lora_model_id
  159. # Scaling factor for long context lora model. None if it is not
  160. # fine tuned for the long context.
  161. self.scaling_factor = scaling_factor
  162. assert (lora_model_id >
  163. 0), f"a valid lora id should be greater than 0, got {self.id}"
  164. self.rank = rank
  165. self.loras: Dict[str, LoRALayerWeights] = loras
  166. def clone(self, lora_model_id: int) -> "LoRAModel":
  167. """Return a copy of the object with different ids.
  168. Will share the underlying tensors."""
  169. return self.__class__(
  170. lora_model_id,
  171. rank=self.rank,
  172. loras=self.loras.copy(),
  173. )
  174. @property
  175. def extra_vocab_size(self) -> int:
  176. return max(lora.extra_vocab_size
  177. for lora in self.loras.values()) if self.loras else 0
  178. def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
  179. """Get LoRA for a given module by name"""
  180. return self.loras.get(module_name, None)
  181. # (yard1): TODO see if we can derive target_embedding_padding automatically
  182. @classmethod
  183. def from_lora_tensors(
  184. cls,
  185. lora_model_id: int,
  186. rank: int,
  187. lora_alpha: int,
  188. tensors: Dict[str, torch.Tensor],
  189. device: str = "cuda",
  190. dtype: Optional[torch.dtype] = None,
  191. embeddings: Optional[Dict[str, torch.Tensor]] = None,
  192. target_embedding_padding: Optional[int] = None,
  193. scaling_factor: Optional[float] = None,
  194. embedding_modules: Optional[Dict[str, str]] = None,
  195. embedding_padding_modules: Optional[List[str]] = None,
  196. ) -> "LoRAModel":
  197. """Create a LoRAModel from a dictionary of tensors."""
  198. pin_memory = str(device) == "cpu" and is_pin_memory_available()
  199. loras: Dict[str, LoRALayerWeights] = {}
  200. for tensor_name, tensor in tensors.items():
  201. module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name)
  202. if module_name not in loras:
  203. lora_embeddings_tensor = None
  204. if embeddings:
  205. assert embedding_modules is not None
  206. embeddings_module = next(
  207. (k for k in embedding_modules if k in module_name),
  208. None)
  209. if embeddings_module:
  210. lora_embeddings_tensor = embeddings[
  211. embedding_modules[embeddings_module]].to(
  212. device=device, dtype=dtype)
  213. if pin_memory:
  214. lora_embeddings_tensor = (
  215. lora_embeddings_tensor.pin_memory())
  216. loras[module_name] = LoRALayerWeights(module_name, rank,
  217. lora_alpha, None, None,
  218. lora_embeddings_tensor)
  219. if is_lora_a:
  220. loras[module_name].lora_a = tensor.to(device=device,
  221. dtype=dtype).t()
  222. if pin_memory:
  223. loras[module_name].lora_a = loras[
  224. module_name].lora_a.pin_memory()
  225. else:
  226. loras[module_name].lora_b = tensor.to(device=device,
  227. dtype=dtype).t()
  228. assert embedding_padding_modules is not None
  229. if any(name in module_name
  230. for name in embedding_padding_modules
  231. ) and target_embedding_padding is not None:
  232. lora_b = loras[module_name].lora_b
  233. assert target_embedding_padding >= lora_b.shape[1]
  234. addition = target_embedding_padding - lora_b.shape[1]
  235. loras[module_name].lora_b = torch.nn.functional.pad(
  236. lora_b, (0, addition))
  237. if pin_memory:
  238. loras[module_name].lora_b = loras[
  239. module_name].lora_b.pin_memory()
  240. for lora in loras.values():
  241. lora.optimize()
  242. return cls(lora_model_id, rank, loras, scaling_factor=scaling_factor)
  243. @classmethod
  244. def from_local_checkpoint(
  245. cls,
  246. lora_dir: str,
  247. expected_lora_modules: List[str],
  248. *,
  249. max_position_embeddings: Optional[int] = None,
  250. lora_model_id: Optional[int] = None,
  251. device: str = "cuda",
  252. dtype: Optional[torch.dtype] = None,
  253. target_embedding_padding: Optional[int] = None,
  254. embedding_modules: Optional[Dict[str, str]] = None,
  255. embedding_padding_modules: Optional[List[str]] = None,
  256. ) -> "LoRAModel":
  257. """Create a LoRAModel from a local checkpoint.
  258. Args:
  259. lora_dir: The local path that has lora data.
  260. expected_lora_modules: Name of modules that are expected to be
  261. replaced by lora.
  262. max_position_embeddings: Max position embedding length. Used to
  263. scaling the largest context length. If None, the lora model's
  264. context length is not scaled.
  265. lora_model_id: Lora model id. If not given, automatically set by
  266. a global counter.
  267. device: Device where the lora model is loaded.
  268. dtype: dtype of the lora model weights.
  269. Returns:
  270. Loaded LoRA Model.
  271. """
  272. lora_config_path = os.path.join(lora_dir, "adapter_config.json")
  273. lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
  274. lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
  275. new_embeddings_tensor_path = os.path.join(
  276. lora_dir, "new_embeddings.safetensors")
  277. new_embeddings_bin_file_path = os.path.join(lora_dir,
  278. "new_embeddings.bin")
  279. with open(lora_config_path) as f:
  280. config = json.load(f)
  281. target_modules = config["target_modules"]
  282. unexpected_modules = []
  283. for module in target_modules:
  284. # Compatible with more modules, such as:layers.11.self_attn.k_proj
  285. part_name = module.split(".")[-1]
  286. if part_name not in expected_lora_modules:
  287. unexpected_modules.append(module)
  288. # loaded lora's target modules must be a subset of expected_lora_modules
  289. if unexpected_modules:
  290. raise ValueError(
  291. f"While loading {lora_dir}, expected"
  292. f" target modules in {expected_lora_modules}"
  293. f" but received {unexpected_modules}."
  294. f" Please verify that the loaded LoRA module is correct")
  295. if os.path.isfile(lora_tensor_path):
  296. tensors = safetensors.torch.load_file(lora_tensor_path)
  297. elif os.path.isfile(lora_bin_file_path):
  298. tensors = torch.load(lora_bin_file_path)
  299. else:
  300. raise ValueError(f"{lora_dir} doesn't contain tensors")
  301. embeddings = None
  302. if os.path.isfile(new_embeddings_tensor_path):
  303. embeddings = safetensors.torch.load_file(
  304. new_embeddings_tensor_path)
  305. elif os.path.isfile(new_embeddings_bin_file_path):
  306. embeddings = torch.load(new_embeddings_bin_file_path)
  307. rank = config["r"]
  308. lora_alpha = config["lora_alpha"]
  309. context_length = config.get("context_length", None)
  310. scaling_factor = None
  311. if context_length:
  312. if max_position_embeddings is None:
  313. max_position_embeddings = context_length
  314. scaling_factor = float(
  315. math.ceil(context_length / max_position_embeddings))
  316. return cls.from_lora_tensors(
  317. lora_model_id=get_lora_id()
  318. if lora_model_id is None else lora_model_id,
  319. rank=rank,
  320. lora_alpha=lora_alpha,
  321. tensors=tensors,
  322. device=device,
  323. dtype=dtype,
  324. embeddings=embeddings,
  325. target_embedding_padding=target_embedding_padding,
  326. scaling_factor=scaling_factor,
  327. embedding_modules=embedding_modules,
  328. embedding_padding_modules=embedding_padding_modules,
  329. )
  330. class LoRAModelManager:
  331. """A manager that manages multiple LoRA-fine-tuned models."""
  332. def __init__(
  333. self,
  334. model: nn.Module,
  335. max_num_seqs: int,
  336. max_num_batched_tokens: int,
  337. vocab_size: int,
  338. lora_config: LoRAConfig,
  339. ):
  340. """Create a LoRAModelManager and adapter for a given model.
  341. Args:
  342. model: the model to be adapted.
  343. max_num_seqs: the maximum number of sequences model can run in a
  344. single batch.
  345. max_num_batched_tokens: the maximum number of tokens model can run
  346. in a single batch.
  347. vocab_size: the vocab size of the model.
  348. lora_config: the LoRA configuration.
  349. """
  350. self.lora_config = lora_config
  351. self.max_num_seqs = max_num_seqs
  352. assert self.capacity >= self.lora_slots
  353. self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
  354. self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
  355. self.vocab_size = vocab_size
  356. self.long_lora_context: Optional[LongContextLoRAContext] = None
  357. self.base_indices = torch.empty(self.max_num_batched_tokens,
  358. dtype=torch.long,
  359. device="cuda")
  360. self.sampler_indices = torch.empty(self.max_num_batched_tokens,
  361. dtype=torch.long,
  362. device="cuda")
  363. self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens,
  364. dtype=torch.long,
  365. device="cuda")
  366. self.embeddings_indices = torch.empty(2,
  367. self.max_num_batched_tokens,
  368. dtype=torch.long,
  369. device="cuda")
  370. self.long_lora_indices = torch.empty(self.max_num_batched_tokens,
  371. dtype=torch.long,
  372. device="cuda")
  373. # Scaling factor -> offset to the sin_cos_cache to it.
  374. # Used for long context lora.
  375. self.scaling_factor_to_offset: Dict[float, int] = {}
  376. # 4 is the number of indicies tensors defined above
  377. # base_indices, sampler_indices, sampler_indices_padded,
  378. # embeddings_indices
  379. self.indices_len: List[Optional[int]] = [None] * 4
  380. self.model: nn.Module = model
  381. if hasattr(self.model, "supported_lora_modules"):
  382. self.supported_lora_modules = copy.deepcopy(
  383. self.model.supported_lora_modules)
  384. if lora_config.long_lora_scaling_factors:
  385. # We need to replace rotary emb layer to do batch computation
  386. # for long lora.
  387. self.supported_lora_modules.append("rotary_emb")
  388. self.packed_modules_mapping = copy.deepcopy(
  389. self.model.packed_modules_mapping)
  390. self.packed_modules: Dict[str, List[str]] = {}
  391. self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
  392. self._registered_loras: Dict[int, LoRAModel] = {}
  393. # Dict instead of a Set for compatibility with LRUCache.
  394. self._active_loras: Dict[int, None] = {}
  395. self._last_mapping: Optional[LoRAMapping] = None
  396. self._create_lora_modules()
  397. self.model.lora_manager = self
  398. @property
  399. def capacity(self) -> int:
  400. return self.lora_config.max_cpu_loras
  401. @property
  402. def lora_slots(self) -> int:
  403. return self.lora_config.max_loras
  404. def __len__(self) -> int:
  405. return len(self._registered_loras)
  406. def activate_lora(
  407. self,
  408. lora_id: int,
  409. ) -> bool:
  410. """Move LoRA into a GPU buffer to be used in the forward pass."""
  411. if lora_id in self._active_loras:
  412. return False
  413. first_free_slot = next(
  414. ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id)
  415. if lora_id is None), None)
  416. if first_free_slot is None:
  417. raise ValueError("No free lora slots")
  418. index, _ = first_free_slot
  419. self._active_loras[lora_id] = None
  420. lora_model = self._registered_loras[lora_id]
  421. logger.debug("Activating LoRA. int id: %d, slot index: %d",
  422. lora_model.id, index)
  423. self.lora_index_to_id[index] = lora_model.id
  424. for module_name, module in self.modules.items():
  425. module_lora = lora_model.get_lora(module_name)
  426. if module_lora:
  427. module_lora.optimize()
  428. module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
  429. module_lora.embeddings_tensor)
  430. else:
  431. module.reset_lora(index)
  432. return True
  433. def _deactivate_lora(self, lora_id: int):
  434. try:
  435. index = self.lora_index_to_id.index(lora_id)
  436. self.lora_index_to_id[index] = None
  437. except ValueError:
  438. pass
  439. def deactivate_lora(self, lora_id: int) -> bool:
  440. """Remove a LoRA from a GPU buffer."""
  441. if lora_id in self._active_loras:
  442. self._deactivate_lora(lora_id)
  443. self._active_loras.pop(lora_id)
  444. return True
  445. return False
  446. def _set_long_lora_context(self, lora: LoRAModel):
  447. if self.long_lora_context is None:
  448. return
  449. if lora.scaling_factor is None:
  450. return
  451. if (lora.scaling_factor not in self.scaling_factor_to_offset):
  452. raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}"
  453. " has not been initialized.")
  454. offsets = self.scaling_factor_to_offset.get(lora.scaling_factor)
  455. if offsets:
  456. self.long_lora_context.offsets_by_lora_id[lora.id] = offsets
  457. def _add_lora(self, lora: LoRAModel):
  458. self._create_merged_loras_inplace(lora)
  459. self._registered_loras[lora.id] = lora
  460. self._set_long_lora_context(lora)
  461. def add_lora(self, lora: LoRAModel) -> bool:
  462. """Add a LoRAModel to the manager CPU cache."""
  463. logger.debug(
  464. "Adding lora. Model id: %d, "
  465. "int id: %d, "
  466. "scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
  467. if lora.id not in self._registered_loras:
  468. if len(self._registered_loras) >= self.capacity:
  469. raise RuntimeError("No free LoRA slots.")
  470. self._add_lora(lora)
  471. return True
  472. return False
  473. def remove_lora(self, lora_id: int) -> bool:
  474. """Remove a LoRAModel from the manager CPU cache."""
  475. # TODO: should we check active lora?
  476. self.deactivate_lora(lora_id)
  477. if self.long_lora_context:
  478. self.long_lora_context.offsets_by_lora_id.pop(lora_id, None)
  479. return bool(self._registered_loras.pop(lora_id, None))
  480. # TODO see if this can be vectorized
  481. def _set_lora_mapping(self, mapping: LoRAMapping) -> None:
  482. (base_indices, sampler_indices, sampler_indices_padded,
  483. embeddings_indices, long_lora_offsets_tensor,
  484. indices_len) = convert_mapping(mapping, self.lora_index_to_id,
  485. self.lora_slots + 1, self.vocab_size,
  486. self.lora_config.lora_extra_vocab_size,
  487. self.long_lora_context)
  488. self.base_indices[:base_indices.shape[0]].copy_(base_indices)
  489. self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
  490. self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
  491. sampler_indices_padded)
  492. self.embeddings_indices[:embeddings_indices.
  493. shape[0], :embeddings_indices.shape[1]].copy_(
  494. embeddings_indices)
  495. if long_lora_offsets_tensor is not None:
  496. self.long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
  497. long_lora_offsets_tensor)
  498. else:
  499. self.long_lora_indices.zero_()
  500. # Maintain the reference
  501. self.indices_len[:] = indices_len
  502. def set_lora_mapping(self, lora_mapping: LoRAMapping) -> None:
  503. if self._last_mapping != lora_mapping:
  504. self._set_lora_mapping(lora_mapping)
  505. self._last_mapping = lora_mapping
  506. def list_loras(self) -> Dict[int, LoRAModel]:
  507. """List all registered LoRAModels."""
  508. return dict(self._registered_loras)
  509. def get_lora(self, lora_id: int) -> Optional[LoRAModel]:
  510. return self._registered_loras.get(lora_id, None)
  511. def remove_all_loras(self):
  512. """Remove all LoRAModels from the manager."""
  513. self._registered_loras.clear()
  514. self.lora_index_to_id = [None] * self.lora_slots
  515. self._active_loras.clear()
  516. def _create_lora_modules(self):
  517. for module_name, module in self.model.named_modules(
  518. remove_duplicate=False):
  519. if not self._match_target_modules(module_name):
  520. continue
  521. parts = module_name.split(".")[-1]
  522. packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
  523. new_module = replace_submodule(
  524. self.model, module_name,
  525. from_layer(module, self.lora_slots, self.lora_config,
  526. packed_moduled_lst, self.model.config))
  527. # LinearScalingRotaryEmbeddingWithLora is used to handle
  528. # long context lora. Register relevant metadata.
  529. if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora):
  530. self.long_lora_context = LongContextLoRAContext(
  531. new_module.scaling_factors, new_module.rotary_dim)
  532. self.scaling_factor_to_offset = \
  533. new_module.scaling_factor_to_offset
  534. # (yard1): TODO make this more robust
  535. if "lm_head" in module_name:
  536. logits_processor_module = self.model.get_submodule(
  537. "logits_processor")
  538. new_module = replace_submodule(
  539. self.model, "logits_processor",
  540. from_layer_logits_processor(logits_processor_module,
  541. module, self.lora_slots,
  542. self.lora_config,
  543. self.model.config))
  544. self.register_module(module_name, new_module)
  545. self._register_packed_modules(module_name)
  546. new_module.set_mapping(self.base_indices, self.sampler_indices,
  547. self.sampler_indices_padded,
  548. self.embeddings_indices,
  549. self.long_lora_indices, self.indices_len)
  550. def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
  551. assert isinstance(module, BaseLayerWithLoRA)
  552. self.modules[module_name] = module
  553. def create_dummy_lora(
  554. self,
  555. lora_id: int,
  556. rank: int,
  557. scaling_factor: Optional[float],
  558. embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
  559. """Create zero-initialized LoRAModel for warmup."""
  560. model = LoRAModel(lora_id, rank, {}, scaling_factor)
  561. for module_name, module in self.model.named_modules():
  562. if not self._match_target_modules(module_name) or not isinstance(
  563. module, BaseLayerWithLoRA) or isinstance(
  564. module, LinearScalingRotaryEmbeddingWithLora):
  565. continue
  566. parts = module_name.split(".")
  567. if module_name not in self.packed_modules:
  568. assert embedding_modules is not None
  569. if parts[-1] in embedding_modules:
  570. input_dim = (module.base_layer.org_vocab_size +
  571. self.lora_config.lora_extra_vocab_size if
  572. hasattr(module.base_layer, "org_vocab_size")
  573. else module.base_layer.weight.shape[1])
  574. output_dim = module.base_layer.embedding_dim if hasattr(
  575. module.base_layer,
  576. "embedding_dim") else module.base_layer.weight.shape[0]
  577. embeddings_tensor_dim = (module.base_layer.embedding_dim if
  578. hasattr(module.base_layer,
  579. "embedding_dim") else
  580. module.base_layer.weight.shape[1])
  581. lora = LoRALayerWeights.create_dummy_lora_weights(
  582. module_name,
  583. input_dim,
  584. output_dim,
  585. rank,
  586. module.lora_a_stacked.dtype,
  587. "cpu",
  588. embeddings_tensor_dim=embeddings_tensor_dim)
  589. else:
  590. lora = LoRALayerWeights.create_dummy_lora_weights(
  591. module_name,
  592. module.lora_a_stacked.shape[-1],
  593. module.lora_b_stacked.shape[-2],
  594. rank,
  595. module.lora_a_stacked.dtype,
  596. "cpu",
  597. )
  598. lora.optimize()
  599. else:
  600. parts = module_name.split(".")
  601. replacements = self.packed_modules_mapping[parts[-1]]
  602. subloras: List[Optional["LoRALayerWeights"]] = []
  603. for i, r in enumerate(replacements):
  604. lora = LoRALayerWeights.create_dummy_lora_weights(
  605. module_name + "." + r,
  606. module.lora_a_stacked[i].shape[-1],
  607. module.lora_b_stacked[i].shape[-2],
  608. rank,
  609. module.lora_a_stacked[i].dtype,
  610. "cpu",
  611. )
  612. lora.optimize()
  613. subloras.append(lora)
  614. lora = PackedLoRALayerWeights.pack(subloras)
  615. model.loras[module_name] = lora
  616. return model
  617. def _match_target_modules(self, module_name: str):
  618. return any(
  619. re.match(
  620. r".*\.{target_module}$".format(target_module=target_module),
  621. module_name) or target_module == module_name
  622. for target_module in self.supported_lora_modules)
  623. def _register_packed_modules(self, module_full_name: str) -> None:
  624. parts = module_full_name.split(".")
  625. module_name = parts[-1]
  626. replacements = self.packed_modules_mapping.get(module_name, [])
  627. # When replacements is less than or equal to 1, it indicates that this
  628. # module is not a packed module.
  629. if len(replacements) <= 1:
  630. return
  631. prefix = ".".join(parts[:-1])
  632. self.packed_modules[module_full_name] = [
  633. prefix + "." + r if prefix else r for r in replacements
  634. ]
  635. def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
  636. for module_name, new_module_names in self.packed_modules.items():
  637. replacement_loras: List[Optional[LoRALayerWeights]] = []
  638. has_replacement = False
  639. for r in new_module_names:
  640. lora = lora_model.get_lora(r)
  641. replacement_loras.append(lora)
  642. if lora:
  643. has_replacement = True
  644. if not has_replacement:
  645. continue
  646. for i in range(len(replacement_loras)):
  647. if replacement_loras[i]:
  648. continue
  649. replacement_loras[i] = None
  650. lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
  651. replacement_loras)
  652. class LoRALRUCache(LRUCache[LoRAModel]):
  653. def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int],
  654. bool]):
  655. super().__init__(capacity)
  656. self.deactivate_lora_fn = deactivate_lora_fn
  657. def _on_remove(self, key: int, value: LoRAModel):
  658. logger.debug("Removing LoRA. int id: %d", key)
  659. self.deactivate_lora_fn(key)
  660. return super()._on_remove(key, value)
  661. class LRUCacheLoRAModelManager(LoRAModelManager):
  662. """A model manager that manages multiple LoRAs with LRU cache."""
  663. def __init__(
  664. self,
  665. model: nn.Module,
  666. max_num_seqs: int,
  667. max_num_batched_tokens: int,
  668. vocab_size: int,
  669. lora_config: LoRAConfig,
  670. ):
  671. super().__init__(model, max_num_seqs, max_num_batched_tokens,
  672. vocab_size, lora_config)
  673. self._registered_loras: LoRALRUCache = LoRALRUCache(
  674. self.capacity, self.deactivate_lora)
  675. self._active_loras: LoRALRUCache = LoRALRUCache(
  676. self.lora_slots, self._deactivate_lora)
  677. def list_loras(self) -> Dict[int, LoRAModel]:
  678. """List all registered LoRAModels."""
  679. return dict(self._registered_loras.cache)
  680. def add_lora(self, lora: LoRAModel) -> bool:
  681. """Add a LoRAModel to the manager."""
  682. logger.debug(
  683. "Adding lora. Model id: %d, "
  684. "int id: %d, "
  685. "scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
  686. if lora.id not in self._registered_loras:
  687. self._add_lora(lora)
  688. was_added = True
  689. else:
  690. # We always touch to update the LRU cache order
  691. self._registered_loras.touch(lora.id)
  692. was_added = False
  693. return was_added
  694. def activate_lora(
  695. self,
  696. lora_id: int,
  697. ) -> bool:
  698. if lora_id not in self._active_loras and len(
  699. self._active_loras) >= self.lora_slots:
  700. self._active_loras.remove_oldest()
  701. result = super().activate_lora(lora_id)
  702. # We always touch to update the LRU cache order
  703. self._active_loras.touch(lora_id)
  704. return result
  705. def remove_oldest_lora(self) -> bool:
  706. if len(self._registered_loras) > 0:
  707. self._registered_loras.remove_oldest()
  708. return True
  709. return False
  710. def create_lora_manager(
  711. model: nn.Module,
  712. max_num_seqs: int,
  713. max_num_batched_tokens: int,
  714. vocab_size: int,
  715. lora_config: LoRAConfig,
  716. lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
  717. **kwargs) -> LoRAModelManager:
  718. """Create a LoRA adapter for a given model."""
  719. if not hasattr(model, "supported_lora_modules"):
  720. raise ValueError(f"Model {type(model)} is not supported for LoRA.")
  721. lora_manager = lora_manager_cls(
  722. model=model,
  723. max_num_seqs=max_num_seqs,
  724. max_num_batched_tokens=max_num_batched_tokens,
  725. vocab_size=vocab_size,
  726. lora_config=lora_config,
  727. **kwargs)
  728. return lora_manager