# models.py

import copy
import json
import math
import os
import re
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple, Type, Union

import safetensors.torch
import torch
from loguru import logger
from torch import nn

from aphrodite.common.config import LoRAConfig
from aphrodite.common.utils import LRUCache, is_pin_memory_available
from aphrodite.lora.layers import (BaseLayerWithLoRA,
                                   LinearScalingRotaryEmbeddingWithLora,
                                   LoRAMapping)
from aphrodite.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from aphrodite.lora.utils import (from_layer, from_layer_logits_processor,
                                  parse_fine_tuned_lora_name,
                                  replace_submodule)
from aphrodite.modeling.models.interfaces import SupportsLoRA

_GLOBAL_LORA_ID = 0


@dataclass
class LongContextLoRAContext:
    """Context for LoRA adapters that support long context."""
    # The scaling factors to support long context LoRA fine-tuned models.
    scaling_factors: List[float]
    # Dimension to apply rotary embedding.
    rot_dim: int
    # Offsets into the sin_cos_cache for each lora_id loaded.
    # This value is dynamically modified.
    offsets_by_lora_id: Dict[int, int] = field(default_factory=dict)


def convert_mapping(
    mapping: LoRAMapping,
    lora_index_to_id: List[Optional[int]],
    max_loras: int,
    vocab_size: int,
    extra_vocab_size: int,
    long_lora_context: Optional[LongContextLoRAContext] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
           Optional[torch.Tensor], List[int]]:
    """Converts LoRAMapping to index tensors.

    Args:
        mapping: LoRAMapping mapping rows in a batch to LoRA ids.
        lora_index_to_id: List mapping LoRA ids to LoRA indices.
        max_loras: Maximum number of LoRAs.
        vocab_size: Model vocab size.
        extra_vocab_size: Extra vocab size each LoRA can have.
        long_lora_context: Passed if there are long context LoRAs in a batch.

    Returns:
        A tuple of tensors:
            base_indices: Tensor of shape [batch_size] mapping batch rows to
                LoRA indices.
            sampler_indices: Tensor of shape [batch_size] mapping requests to
                LoRA indices for the sampler. For generation, this will be
                the same as base_indices. For prefill, this will map requests
                to LoRA indices.
            sampler_indices_padded: Tensor of shape [batch_size] mapping
                requests to LoRA indices for the sampler with padding.
                Same as sampler_indices, but -1 is replaced with max_loras.
            embeddings_indices: Tensor of shape [2, batch_size] mapping
                requests to embedding indices. The first row is for embeddings
                added by the LoRAs, the second row is for the LoRA.lora_a
                embeddings.
            long_lora_indices: Tensor of shape [batch_size] mapping
                requests to RoPE offsets and rot dims for long LoRAs.
                None if long context LoRAs don't exist.
            indices_len: List of lengths of the above tensors. Used to index
                into each tensor. It contains lengths for (base_indices,
                sampler_indices, sampler_indices_padded, embeddings_indices,
                long_lora_indices). If long LoRAs don't exist, it only
                contains the first 4 entries.
    """
    index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
    embedding_indices = index_mapping_indices.copy()
    lora_indices = index_mapping_indices.copy()
    long_lora_offsets: Optional[torch.Tensor] = None
    if long_lora_context:
        long_lora_offsets = torch.zeros(len(index_mapping_indices),
                                        device="cuda",
                                        dtype=torch.long)
    prompt_mapping: List[int] = [
        lora_index_to_id.index(x) if x > 0 else -1
        for x in mapping.prompt_mapping
    ]
    lora_idx = None
    for i in range(len(index_mapping_indices)):
        # TODO: index can be slow. Optimize.
        lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
                    if index_mapping_indices[i] > 0 else -1)
        embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
        lora_indices[i] = lora_idx
        if long_lora_context:
            assert long_lora_offsets is not None
            lora_offset: int = long_lora_context.offsets_by_lora_id.get(
                index_mapping_indices[i], 0)
            long_lora_offsets[i] = lora_offset
        # SANG-TODO
        # index_mapping_indices[i] = i

    indices_list: List[Union[List[int], torch.Tensor]] = [
        index_mapping_indices, lora_indices, embedding_indices
    ]
    if long_lora_context:
        assert long_lora_offsets is not None
        indices_list.append(long_lora_offsets)
    indices = torch.tensor(indices_list, dtype=torch.long, device="cuda")
    prompt_mapping_tensor = torch.tensor(prompt_mapping,
                                         device="cuda",
                                         dtype=torch.long)
    embeddings_indices = torch.stack([
        indices[2] * extra_vocab_size,
        indices[2] * (vocab_size + extra_vocab_size)
    ])
    embeddings_indices[embeddings_indices == -1] = max_loras - 1
    base_indices = indices[1]
    sampler_indices = prompt_mapping_tensor
    sampler_indices_padded = sampler_indices.clone()
    sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
    sampler_indices_padded = (
        torch.arange(
            0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
        (sampler_indices_padded * len(sampler_indices_padded)))
    long_lora_indices = None
    long_lora_indices_len: Optional[int] = None
    if long_lora_context:
        long_lora_indices = indices[3]
        long_lora_indices_len = long_lora_indices.shape[-1]
    # Contains the length of each indices tensor. Used to index into each
    # tensor.
    indices_len = [
        base_indices.shape[-1], sampler_indices.shape[-1],
        sampler_indices_padded.shape[-1], embeddings_indices.shape[-1]
    ]
    if long_lora_indices_len is not None:
        indices_len.append(long_lora_indices_len)
    return (base_indices, sampler_indices, sampler_indices_padded,
            embeddings_indices, long_lora_indices, indices_len)
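

# Illustrative sketch (not part of the original module): a minimal, hedged
# example of how convert_mapping might be invoked. It assumes a CUDA device
# is available and that LoRAMapping can be constructed from per-token and
# per-request LoRA id tuples via the keyword arguments below; both are
# assumptions made for illustration only.
def _example_convert_mapping():  # pragma: no cover
    # Two requests: the first uses the LoRA with id 1, the second runs
    # without a LoRA (id 0 means "no LoRA").
    mapping = LoRAMapping(index_mapping=(1, 1, 1, 0, 0),
                          prompt_mapping=(1, 0))
    # Slot table with LoRA id 1 loaded in slot 0 and one free slot.
    lora_index_to_id: List[Optional[int]] = [1, None]
    return convert_mapping(mapping,
                           lora_index_to_id,
                           max_loras=2,
                           vocab_size=32000,
                           extra_vocab_size=256)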


def get_lora_id():
    global _GLOBAL_LORA_ID
    _GLOBAL_LORA_ID += 1
    return _GLOBAL_LORA_ID


class LoRAModel:
    """A LoRA fine-tuned model."""

    def __init__(
        self,
        lora_model_id: int,
        rank: int,
        loras: Dict[str, LoRALayerWeights],
        scaling_factor: Optional[float] = None,
    ) -> None:
        """
        Args:
            lora_model_id: The integer id for the lora model.
            rank: lora rank.
            loras: module name -> weights for lora-replaced layers.
            scaling_factor: Scaling factor to support long context lora model.
                None if the lora is not tuned for long context support.
        """
        self.id = lora_model_id
        # Scaling factor for long context lora model. None if it is not
        # fine tuned for the long context.
        self.scaling_factor = scaling_factor
        assert (lora_model_id >
                0), f"a valid lora id should be greater than 0, got {self.id}"
        self.rank = rank
        self.loras: Dict[str, LoRALayerWeights] = loras

    def clone(self, lora_model_id: int) -> "LoRAModel":
        """Return a copy of the object with different ids.
        Will share the underlying tensors."""
        return self.__class__(
            lora_model_id,
            rank=self.rank,
            loras=self.loras.copy(),
        )

    @property
    def extra_vocab_size(self) -> int:
        return max(lora.extra_vocab_size
                   for lora in self.loras.values()) if self.loras else 0

    def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
        """Get LoRA for a given module by name."""
        return self.loras.get(module_name, None)

    # (yard1): TODO see if we can derive target_embedding_padding automatically
    @classmethod
    def from_lora_tensors(
        cls,
        lora_model_id: int,
        rank: int,
        lora_alpha: int,
        tensors: Dict[str, torch.Tensor],
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        embeddings: Optional[Dict[str, torch.Tensor]] = None,
        target_embedding_padding: Optional[int] = None,
        scaling_factor: Optional[float] = None,
        embedding_modules: Optional[Dict[str, str]] = None,
        embedding_padding_modules: Optional[List[str]] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors."""
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        loras: Dict[str, LoRALayerWeights] = {}
        for tensor_name, tensor in tensors.items():
            module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name)
            if module_name not in loras:
                lora_embeddings_tensor = None
                if embeddings:
                    assert embedding_modules is not None
                    embeddings_module = next(
                        (k for k in embedding_modules if k in module_name),
                        None)
                    if embeddings_module:
                        lora_embeddings_tensor = embeddings[
                            embedding_modules[embeddings_module]].to(
                                device=device, dtype=dtype)
                        if pin_memory:
                            lora_embeddings_tensor = (
                                lora_embeddings_tensor.pin_memory())
                loras[module_name] = LoRALayerWeights(module_name, rank,
                                                      lora_alpha, None, None,
                                                      lora_embeddings_tensor)
            if is_lora_a:
                loras[module_name].lora_a = tensor.to(device=device,
                                                      dtype=dtype).t()
                if pin_memory:
                    loras[module_name].lora_a = loras[
                        module_name].lora_a.pin_memory()
            else:
                loras[module_name].lora_b = tensor.to(device=device,
                                                      dtype=dtype).t()
                assert embedding_padding_modules is not None
                if any(name in module_name
                       for name in embedding_padding_modules
                       ) and target_embedding_padding is not None:
                    lora_b = loras[module_name].lora_b
                    assert target_embedding_padding >= lora_b.shape[1]
                    addition = target_embedding_padding - lora_b.shape[1]
                    loras[module_name].lora_b = torch.nn.functional.pad(
                        lora_b, (0, addition))
                if pin_memory:
                    loras[module_name].lora_b = loras[
                        module_name].lora_b.pin_memory()
        for lora in loras.values():
            lora.optimize()
        return cls(lora_model_id, rank, loras, scaling_factor=scaling_factor)

    @classmethod
    def from_local_checkpoint(
        cls,
        lora_dir: str,
        expected_lora_modules: List[str],
        *,
        max_position_embeddings: Optional[int] = None,
        lora_model_id: Optional[int] = None,
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        target_embedding_padding: Optional[int] = None,
        embedding_modules: Optional[Dict[str, str]] = None,
        embedding_padding_modules: Optional[List[str]] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint.

        Args:
            lora_dir: The local path that has lora data.
            expected_lora_modules: Names of modules that are expected to be
                replaced by lora.
            max_position_embeddings: Max position embedding length. Used to
                scale the largest context length. If None, the lora model's
                context length is not scaled.
            lora_model_id: Lora model id. If not given, automatically set by
                a global counter.
            device: Device where the lora model is loaded.
            dtype: dtype of the lora model weights.

        Returns:
            Loaded LoRA Model.
        """
        lora_config_path = os.path.join(lora_dir, "adapter_config.json")
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
        new_embeddings_tensor_path = os.path.join(
            lora_dir, "new_embeddings.safetensors")
        new_embeddings_bin_file_path = os.path.join(lora_dir,
                                                    "new_embeddings.bin")
        with open(lora_config_path) as f:
            config = json.load(f)
        if os.path.isfile(lora_tensor_path):
            tensors: Dict[str, torch.Tensor] = {}
            # Find unexpected modules.
            # Use the safetensors keys as the source of truth to find expected
            # modules: in peft, if you have target_modules A, B, C and C does
            # not exist in the model, it won't error and the model will be
            # trained with A and B loraified. C won't exist in the safetensors
            # file, but it will exist in the target_modules of
            # adapter_config.json.
            unexpected_modules = []
            with safetensors.safe_open(lora_tensor_path,
                                       framework="pt") as f:  # type: ignore
                for lora_module in f.keys():  # noqa
                    module_name, _ = parse_fine_tuned_lora_name(lora_module)
                    part_name = module_name.split(".")[-1]
                    if part_name not in expected_lora_modules:
                        unexpected_modules.append(module_name)
                if unexpected_modules:
                    raise ValueError(
                        f"While loading {lora_dir}, expected"
                        f" target modules in {expected_lora_modules}"
                        f" but received {unexpected_modules}."
                        f" Please verify that the loaded LoRA module is correct"
                    )
                # Load tensors if there are only expected modules.
                for module in f.keys():  # noqa
                    tensors[module] = f.get_tensor(module)
        elif os.path.isfile(lora_bin_file_path):
            # When a bin file is provided, we rely on the config to find
            # unexpected modules.
            unexpected_modules = []
            target_modules = config["target_modules"]
            for module in target_modules:
                # Compatible with more modules,
                # such as: layers.11.self_attn.k_proj
                part_name = module.split(".")[-1]
                if part_name not in expected_lora_modules:
                    unexpected_modules.append(module)
            # The loaded lora's target modules must be a subset of
            # expected_lora_modules. This is not reliable, but there's no
            # better mechanism.
            if unexpected_modules:
                print(unexpected_modules, "modules")
                raise ValueError(
                    f"While loading {lora_dir}, expected"
                    f" target modules in {expected_lora_modules}"
                    f" but received {unexpected_modules}."
                    f" Please verify that the loaded LoRA module is correct")
            tensors = torch.load(lora_bin_file_path)
        else:
            raise ValueError(f"{lora_dir} doesn't contain tensors")

        embeddings = None
        if os.path.isfile(new_embeddings_tensor_path):
            embeddings = safetensors.torch.load_file(
                new_embeddings_tensor_path)
        elif os.path.isfile(new_embeddings_bin_file_path):
            embeddings = torch.load(new_embeddings_bin_file_path)

        rank = config["r"]
        lora_alpha = config["lora_alpha"]
        context_length = config.get("context_length", None)
        scaling_factor = None
        if context_length:
            if max_position_embeddings is None:
                max_position_embeddings = context_length
            scaling_factor = float(
                math.ceil(context_length / max_position_embeddings))
        return cls.from_lora_tensors(
            lora_model_id=get_lora_id()
            if lora_model_id is None else lora_model_id,
            rank=rank,
            lora_alpha=lora_alpha,
            tensors=tensors,
            device=device,
            dtype=dtype,
            embeddings=embeddings,
            target_embedding_padding=target_embedding_padding,
            scaling_factor=scaling_factor,
            embedding_modules=embedding_modules,
            embedding_padding_modules=embedding_padding_modules,
        )
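

# Illustrative sketch (not part of the original module): a hedged example of
# how a PEFT-style adapter directory might be loaded with
# LoRAModel.from_local_checkpoint. The directory path and the expected module
# names below are hypothetical and depend on the base model being adapted.
def _example_load_lora_from_checkpoint() -> "LoRAModel":  # pragma: no cover
    return LoRAModel.from_local_checkpoint(
        "/path/to/adapter",  # hypothetical adapter directory
        expected_lora_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        device="cpu",
        dtype=torch.float16,
        embedding_modules={},
        embedding_padding_modules=[],
    )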


class LoRAModelManager:
    """A manager that manages multiple LoRA-fine-tuned models."""

    def __init__(
        self,
        model: SupportsLoRA,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences the model can run
                in a single batch.
            max_num_batched_tokens: the maximum number of tokens the model
                can run in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
        """
        self.lora_config = lora_config
        self.max_num_seqs = max_num_seqs
        assert self.capacity >= self.lora_slots
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.long_lora_context: Optional[LongContextLoRAContext] = None
        self.base_indices = torch.empty(self.max_num_batched_tokens,
                                        dtype=torch.long,
                                        device="cuda")
        self.sampler_indices = torch.empty(self.max_num_batched_tokens,
                                           dtype=torch.long,
                                           device="cuda")
        self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens,
                                                  dtype=torch.long,
                                                  device="cuda")
        self.embeddings_indices = torch.empty(2,
                                              self.max_num_batched_tokens,
                                              dtype=torch.long,
                                              device="cuda")
        self.long_lora_indices = torch.empty(self.max_num_batched_tokens,
                                             dtype=torch.long,
                                             device="cuda")
        # Scaling factor -> offset into the sin_cos_cache.
        # Used for long context lora.
        self.scaling_factor_to_offset: Dict[float, int] = {}
        # 4 is the number of indices tensors defined above:
        # base_indices, sampler_indices, sampler_indices_padded,
        # embeddings_indices
        self.indices_len: List[Optional[int]] = [None] * 4
        self.model = model
        if hasattr(self.model, "supported_lora_modules"):
            self.supported_lora_modules = copy.deepcopy(
                self.model.supported_lora_modules)
            if lora_config.long_lora_scaling_factors:
                # We need to replace the rotary emb layer to do batch
                # computation for long lora.
                self.supported_lora_modules.append("rotary_emb")
            self.packed_modules_mapping = copy.deepcopy(
                self.model.packed_modules_mapping)
        self.packed_modules: Dict[str, List[str]] = {}
        self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
        self._registered_loras: Dict[int, LoRAModel] = {}
        # Dict instead of a Set for compatibility with LRUCache.
        self._active_loras: Dict[int, None] = {}
        self._last_mapping: Optional[LoRAMapping] = None
        self._create_lora_modules()

    @property
    def capacity(self) -> int:
        return self.lora_config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        return self.lora_config.max_loras

    def __len__(self) -> int:
        return len(self._registered_loras)

    def activate_lora(
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass."""
        if lora_id in self._active_loras:
            return False
        first_free_slot = next(
            ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id)
             if lora_id is None), None)
        if first_free_slot is None:
            raise ValueError("No free lora slots")
        index, _ = first_free_slot
        self._active_loras[lora_id] = None
        lora_model = self._registered_loras[lora_id]
        logger.debug(f"Activating LoRA. int id: {lora_model.id}, "
                     f"slot index: {index}")
        self.lora_index_to_id[index] = lora_model.id
        for module_name, module in self.modules.items():
            module_lora = lora_model.get_lora(module_name)
            if module_lora:
                module_lora.optimize()
                module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
                                module_lora.embeddings_tensor)
            else:
                module.reset_lora(index)
        return True

    def _deactivate_lora(self, lora_id: int):
        try:
            index = self.lora_index_to_id.index(lora_id)
            self.lora_index_to_id[index] = None
        except ValueError:
            pass

    def deactivate_lora(self, lora_id: int) -> bool:
        """Remove a LoRA from a GPU buffer."""
        if lora_id in self._active_loras:
            self._deactivate_lora(lora_id)
            self._active_loras.pop(lora_id)
            return True
        return False

    def _set_long_lora_context(self, lora: LoRAModel):
        if self.long_lora_context is None:
            return
        if lora.scaling_factor is None:
            return
        if (lora.scaling_factor not in self.scaling_factor_to_offset):
            raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}"
                             " has not been initialized.")
        offsets = self.scaling_factor_to_offset.get(lora.scaling_factor)
        if offsets:
            self.long_lora_context.offsets_by_lora_id[lora.id] = offsets

    def _add_lora(self, lora: LoRAModel):
        self._create_merged_loras_inplace(lora)
        self._registered_loras[lora.id] = lora
        self._set_long_lora_context(lora)

    def add_lora(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager CPU cache."""
        logger.debug(f"Adding lora. Model id: {lora.id}, "
                     f"int id: {lora.id}, "
                     f"scaling factor: {lora.scaling_factor}")
        if lora.id not in self._registered_loras:
            if len(self._registered_loras) >= self.capacity:
                raise RuntimeError("No free LoRA slots.")
            self._add_lora(lora)
            return True
        return False

    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRAModel from the manager CPU cache."""
        # TODO: should we check active lora?
        self.deactivate_lora(lora_id)
        if self.long_lora_context:
            self.long_lora_context.offsets_by_lora_id.pop(lora_id, None)
        return bool(self._registered_loras.pop(lora_id, None))

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        raise NotImplementedError(
            "Pinning is not supported in LoRAModelManager. "
            "Use LRUCacheLoRAModelManager for pinning.")  # type: ignore

    # TODO see if this can be vectorized
    def _set_lora_mapping(self, mapping: LoRAMapping) -> None:
        (base_indices, sampler_indices, sampler_indices_padded,
         embeddings_indices, long_lora_offsets_tensor,
         indices_len) = convert_mapping(mapping, self.lora_index_to_id,
                                        self.lora_slots + 1, self.vocab_size,
                                        self.lora_config.lora_extra_vocab_size,
                                        self.long_lora_context)
        self.base_indices[:base_indices.shape[0]].copy_(base_indices)
        self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
        self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
            sampler_indices_padded)
        self.embeddings_indices[:embeddings_indices.
                                shape[0], :embeddings_indices.shape[1]].copy_(
                                    embeddings_indices)
        if long_lora_offsets_tensor is not None:
            self.long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
                long_lora_offsets_tensor)
        else:
            self.long_lora_indices.zero_()
        # Maintain the reference
        self.indices_len[:] = indices_len

    def set_lora_mapping(self, lora_mapping: LoRAMapping) -> None:
        if self._last_mapping != lora_mapping:
            self._set_lora_mapping(lora_mapping)
        self._last_mapping = lora_mapping

    def list_loras(self) -> Dict[int, LoRAModel]:
        """List all registered LoRAModels."""
        return dict(self._registered_loras)

    def get_lora(self, lora_id: int) -> Optional[LoRAModel]:
        return self._registered_loras.get(lora_id, None)

    def remove_all_loras(self):
        """Remove all LoRAModels from the manager."""
        self._registered_loras.clear()
        self.lora_index_to_id = [None] * self.lora_slots
        self._active_loras.clear()

    def _create_lora_modules(self):
        for module_name, module in self.model.named_modules(
                remove_duplicate=False):
            if not self._match_target_modules(module_name):
                continue
            parts = module_name.split(".")[-1]
            packed_modules_lst = self.packed_modules_mapping.get(parts, [])
            new_module = replace_submodule(
                self.model, module_name,
                from_layer(module, self.lora_slots, self.lora_config,
                           packed_modules_lst, self.model.config))
            # LinearScalingRotaryEmbeddingWithLora is used to handle
            # long context lora. Register relevant metadata.
            if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora):
                self.long_lora_context = LongContextLoRAContext(
                    new_module.scaling_factors, new_module.rotary_dim)
                self.scaling_factor_to_offset = \
                    new_module.scaling_factor_to_offset
            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                logits_processor_module = self.model.get_submodule(
                    "logits_processor")
                new_module = replace_submodule(
                    self.model, "logits_processor",
                    from_layer_logits_processor(logits_processor_module,
                                                module, self.lora_slots,
                                                self.lora_config,
                                                self.model.config))
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            new_module.set_mapping(self.base_indices, self.sampler_indices,
                                   self.sampler_indices_padded,
                                   self.embeddings_indices,
                                   self.long_lora_indices, self.indices_len)

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        assert isinstance(module, BaseLayerWithLoRA)
        self.modules[module_name] = module

    def create_dummy_lora(
            self,
            lora_id: int,
            rank: int,
            scaling_factor: Optional[float],
            embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup."""
        model = LoRAModel(lora_id, rank, {}, scaling_factor)
        for module_name, module in self.model.named_modules():
            if not self._match_target_modules(module_name) or not isinstance(
                    module, BaseLayerWithLoRA) or isinstance(
                        module, LinearScalingRotaryEmbeddingWithLora):
                continue
            parts = module_name.split(".")
            if module_name not in self.packed_modules:
                assert embedding_modules is not None
                if parts[-1] in embedding_modules:
                    input_dim = (module.base_layer.org_vocab_size +
                                 self.lora_config.lora_extra_vocab_size if
                                 hasattr(module.base_layer, "org_vocab_size")
                                 else module.base_layer.weight.shape[1])
                    output_dim = module.base_layer.embedding_dim if hasattr(
                        module.base_layer,
                        "embedding_dim") else module.base_layer.weight.shape[0]
                    embeddings_tensor_dim = (module.base_layer.embedding_dim if
                                             hasattr(module.base_layer,
                                                     "embedding_dim") else
                                             module.base_layer.weight.shape[1])
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        input_dim,
                        output_dim,
                        rank,
                        module.lora_a_stacked.dtype,
                        "cpu",
                        embeddings_tensor_dim=embeddings_tensor_dim)
                else:
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.lora_a_stacked.shape[-1],
                        module.lora_b_stacked.shape[-2],
                        rank,
                        module.lora_a_stacked.dtype,
                        "cpu",
                    )
                lora.optimize()
            else:
                parts = module_name.split(".")
                replacements = self.packed_modules_mapping[parts[-1]]
                subloras: List[Optional["LoRALayerWeights"]] = []
                for i, r in enumerate(replacements):
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
                    )
                    lora.optimize()
                    subloras.append(lora)
                lora = PackedLoRALayerWeights.pack(subloras)
            model.loras[module_name] = lora
        return model

    def _match_target_modules(self, module_name: str):
        return any(
            re.match(
                r".*\.{target_module}$".format(target_module=target_module),
                module_name) or target_module == module_name
            for target_module in self.supported_lora_modules)

    def _register_packed_modules(self, module_full_name: str) -> None:
        parts = module_full_name.split(".")
        module_name = parts[-1]
        replacements = self.packed_modules_mapping.get(module_name, [])
        # When replacements has at most one entry, this module is not a
        # packed module.
        if len(replacements) <= 1:
            return
        prefix = ".".join(parts[:-1])
        self.packed_modules[module_full_name] = [
            prefix + "." + r if prefix else r for r in replacements
        ]

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        for module_name, new_module_names in self.packed_modules.items():
            replacement_loras: List[Optional[LoRALayerWeights]] = []
            has_replacement = False
            for r in new_module_names:
                lora = lora_model.get_lora(r)
                replacement_loras.append(lora)
                if lora:
                    has_replacement = True
            if not has_replacement:
                continue
            for i in range(len(replacement_loras)):
                if replacement_loras[i]:
                    continue
                replacement_loras[i] = None
            lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                replacement_loras)
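

# Illustrative sketch (not part of the original module): a hedged example of
# the per-step flow, assuming `manager` already has the relevant LoRAs
# activated and that LoRAMapping is built from per-token and per-request
# LoRA ids as in convert_mapping above; the ids used here are placeholders.
def _example_prepare_step(manager: LoRAModelManager) -> None:  # pragma: no cover
    mapping = LoRAMapping(index_mapping=(1, 1, 2, 0),
                          prompt_mapping=(1, 2, 0))
    # Refreshes the GPU index buffers consumed by the LoRA layers.
    manager.set_lora_mapping(mapping)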


class LoRALRUCache(LRUCache[LoRAModel]):

    def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int],
                                                                   bool]):
        super().__init__(capacity)
        self.deactivate_lora_fn = deactivate_lora_fn

    def _on_remove(self, key: int, value: LoRAModel):
        logger.debug(f"Removing LoRA. int id: {key}")
        self.deactivate_lora_fn(key)
        return super()._on_remove(key, value)


class LRUCacheLoRAModelManager(LoRAModelManager):
    """A model manager that manages multiple LoRAs with LRU cache."""

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
    ):
        super().__init__(model, max_num_seqs, max_num_batched_tokens,
                         vocab_size, lora_config)
        self._registered_loras: LoRALRUCache = LoRALRUCache(
            self.capacity, self.deactivate_lora)
        self._active_loras: LoRALRUCache = LoRALRUCache(
            self.lora_slots, self._deactivate_lora)

    def list_loras(self) -> Dict[int, LoRAModel]:
        """List all registered LoRAModels."""
        return dict(self._registered_loras.cache)

    def add_lora(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager."""
        logger.debug(f"Adding lora. Model id: {lora.id}, "
                     f"int id: {lora.id}, "
                     f"scaling factor: {lora.scaling_factor}")
        if lora.id not in self._registered_loras:
            self._add_lora(lora)
            was_added = True
        else:
            # We always touch to update the LRU cache order
            self._registered_loras.touch(lora.id)
            was_added = False
        return was_added

    def activate_lora(
        self,
        lora_id: int,
    ) -> bool:
        if lora_id not in self._active_loras and len(
                self._active_loras) >= self.lora_slots:
            self._active_loras.remove_oldest()
        result = super().activate_lora(lora_id)
        # We always touch to update the LRU cache order
        self._active_loras.touch(lora_id)
        return result

    def remove_oldest_lora(self) -> bool:
        if len(self._registered_loras) > 0:
            self._registered_loras.remove_oldest()
            return True
        return False

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        self._pin_lora_in_cpu_cache(lora_id)
        self._pin_lora_in_gpu_cache(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int):
        try:
            self._registered_loras.pin(lora_id)
        except ValueError as err:
            raise ValueError("Pinning failed. "
                             f"LoRA {lora_id} is not registered.") from err

    def _pin_lora_in_gpu_cache(self, lora_id: int):
        if lora_id not in self._active_loras:
            # Move the lora to the GPU if it is not already active.
            self.activate_lora(lora_id)
        self._active_loras.pin(lora_id)


def create_lora_manager(
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
        **kwargs) -> LoRAModelManager:
    """Create a LoRA adapter for a given model."""
    if not hasattr(model, "supported_lora_modules"):
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    lora_manager = lora_manager_cls(
        model=model,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        vocab_size=vocab_size,
        lora_config=lora_config,
        **kwargs)
    return lora_manager
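

# Illustrative sketch (not part of the original module): a hedged example of
# wiring a manager to an already-loaded model. `model` is assumed to be a
# LoRA-capable Aphrodite model (i.e. it exposes supported_lora_modules and
# packed_modules_mapping); the numeric limits below are placeholder values.
def _example_create_lora_manager(
        model: nn.Module,
        lora_config: LoRAConfig) -> LoRAModelManager:  # pragma: no cover
    manager = create_lora_manager(
        model,
        max_num_seqs=16,
        max_num_batched_tokens=4096,
        vocab_size=32000,
        lora_config=lora_config,
        lora_manager_cls=LRUCacheLoRAModelManager,
    )
    # Register a (hypothetical) adapter on the CPU cache, then move it into
    # a GPU slot so it can be used in the forward pass.
    lora = _example_load_lora_from_checkpoint()
    manager.add_lora(lora)
    manager.activate_lora(lora.id)
    return manager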