models.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835
  1. import copy
  2. import json
  3. import math
  4. import os
  5. import re
  6. from dataclasses import dataclass, field
  7. from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
  8. import safetensors.torch
  9. import torch
  10. from loguru import logger
  11. from torch import nn
  12. from aphrodite.adapter_commons.models import (AdapterLRUCache, AdapterModel,
  13. AdapterModelManager)
  14. from aphrodite.adapter_commons.utils import (add_adapter, deactivate_adapter,
  15. get_adapter, list_adapters,
  16. remove_adapter,
  17. set_adapter_mapping)
  18. from aphrodite.common.config import LoRAConfig
  19. from aphrodite.common.utils import is_pin_memory_available
  20. from aphrodite.lora.layers import (BaseLayerWithLoRA,
  21. LinearScalingRotaryEmbeddingWithLora,
  22. LoRAMapping)
  23. from aphrodite.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
  24. from aphrodite.lora.utils import (from_layer, from_layer_logits_processor,
  25. parse_fine_tuned_lora_name,
  26. replace_submodule)
  27. from aphrodite.modeling.models.interfaces import SupportsLoRA
# Process-wide counter backing get_lora_id(). It holds the last id handed
# out; get_lora_id() increments it before returning, so auto-assigned ids
# start at 1.
_GLOBAL_LORA_ID = 0
  29. @dataclass
  30. class LongContextLoRAContext:
  31. """Context for lora adapters that support long context."""
  32. # The scaling factors to support long context lora fine tuned models.
  33. scaling_factors: List[float]
  34. # dimension to apply rotary embedding.
  35. rot_dim: int
  36. # offsets to the sin_cos_cache for each lora_id loaded.
  37. # This value is dynamically modified.
  38. offsets_by_lora_id: Dict[int, int] = field(default_factory=dict)
  39. def convert_mapping(
  40. mapping: LoRAMapping,
  41. lora_index_to_id: List[Optional[int]],
  42. max_loras: int,
  43. vocab_size: int,
  44. extra_vocab_size: int,
  45. long_lora_context: Optional[LongContextLoRAContext] = None,
  46. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
  47. Optional[torch.Tensor], List[int]]:
  48. """Converts LoRAMapping to index tensors.
  49. Args:
  50. mapping: LoRAMapping mapping rows in a batch to LoRA ids.
  51. lora_index_to_id: List mapping LoRA ids to LoRA indices.
  52. max_loras: Maximum number of LoRAs.
  53. vocab_size: Model vocab size.
  54. extra_vocab_size: Extra vocab size each LoRA can have.
  55. long_lora_context: Passed if there are long context lora in a batch.
  56. Returns:
  57. A tuple of tensors:
  58. base_indices: Tensor of shape [batch_size] mapping batch rows to
  59. LoRA indices.
  60. sampler_indices: Tensor of shape [batch_size] mapping requests to
  61. LoRA indices for sampler. For generation, this will be the
  62. same as base_indicies. For prefill, this will map requests
  63. to LoRA indices.
  64. sampler_indices_padded: Tensor of shape [batch_size] mapping
  65. requests to LoRA indices for sampler with padding.
  66. Same as sampler_indicies, but -1 is replaced with
  67. max_loras.
  68. embeddings_indices: Tensor of shape [2, batch_size] mapping
  69. requests to embedding indices. First row is for embeddings
  70. added by the LoRAs, second row is for the LoRA.lora_a
  71. embeddings.
  72. long_lora_indices: Tensor of shape [batch_size] mapping
  73. requests to RoPE offsets and rot dims for long LoRAs.
  74. None if long context lora doesn't exist.
  75. indices_len: List of lengths of the above tensors.
  76. Used to index into each tensor. It contains length for
  77. (base_indices, sampler_indices, sampler_indices_padded,
  78. embeddings_indices, long_lora_indices). If long_lora doesn't
  79. exist, it only contains first 4 entries.
  80. """
  81. index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
  82. embedding_indices = index_mapping_indices.copy()
  83. lora_indices = index_mapping_indices.copy()
  84. long_lora_offsets: Optional[torch.Tensor] = None
  85. if long_lora_context:
  86. long_lora_offsets = torch.zeros(len(index_mapping_indices),
  87. device="cuda",
  88. dtype=torch.long)
  89. prompt_mapping: List[int] = [
  90. lora_index_to_id.index(x) if x > 0 else -1
  91. for x in mapping.prompt_mapping
  92. ]
  93. lora_idx = None
  94. for i in range(len(index_mapping_indices)):
  95. # TODO index can be slow. optimize
  96. lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
  97. if index_mapping_indices[i] > 0 else -1)
  98. embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
  99. lora_indices[i] = lora_idx
  100. if long_lora_context:
  101. assert long_lora_offsets is not None
  102. lora_offset: int = long_lora_context.offsets_by_lora_id.get(
  103. index_mapping_indices[i], 0)
  104. long_lora_offsets[i] = lora_offset
  105. indices_list: List[Union[List[int], torch.Tensor]] = [
  106. index_mapping_indices, lora_indices, embedding_indices
  107. ]
  108. if long_lora_context:
  109. assert long_lora_offsets is not None
  110. indices_list.append(long_lora_offsets)
  111. indices = torch.tensor(indices_list, dtype=torch.long, device="cuda")
  112. prompt_mapping_tensor = torch.tensor(prompt_mapping,
  113. device="cuda",
  114. dtype=torch.long)
  115. embeddings_indices = torch.stack([
  116. indices[2] * extra_vocab_size,
  117. indices[2] * (vocab_size + extra_vocab_size)
  118. ])
  119. embeddings_indices[embeddings_indices == -1] = max_loras - 1
  120. base_indices = indices[1]
  121. sampler_indices = prompt_mapping_tensor
  122. sampler_indices_padded = sampler_indices.clone()
  123. sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
  124. sampler_indices_padded = (
  125. torch.arange(
  126. 0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
  127. (sampler_indices_padded * len(sampler_indices_padded)))
  128. long_lora_indices = None
  129. long_lora_indices_len: Optional[int] = None
  130. if long_lora_context:
  131. long_lora_indices = indices[3]
  132. long_lora_indices_len = long_lora_indices.shape[-1]
  133. # Contain length of indices tensors. Used to index into each tensor.
  134. indices_len = [
  135. base_indices.shape[-1], sampler_indices.shape[-1],
  136. sampler_indices_padded.shape[-1], embeddings_indices.shape[-1]
  137. ]
  138. if long_lora_indices_len is not None:
  139. indices_len.append(long_lora_indices_len)
  140. return (base_indices, sampler_indices, sampler_indices_padded,
  141. embeddings_indices, long_lora_indices, indices_len)
  142. def get_lora_id():
  143. global _GLOBAL_LORA_ID
  144. _GLOBAL_LORA_ID += 1
  145. return _GLOBAL_LORA_ID
  146. class LoRAModel(AdapterModel):
  147. """A LoRA fine-tuned model."""
  148. def __init__(
  149. self,
  150. lora_model_id: int,
  151. rank: int,
  152. loras: Dict[str, LoRALayerWeights],
  153. scaling_factor: Optional[float] = None,
  154. ) -> None:
  155. """
  156. Args:
  157. lora_model_id: The integer id for the lora model.
  158. rank: lora rank.
  159. loras: module name -> weights for lora-replaced layers.
  160. scaling_factor: Scaling factor to support long context lora model.
  161. None if the lora is not tuned for long context support.
  162. """
  163. self.id = lora_model_id
  164. # Scaling factor for long context lora model. None if it is not
  165. # fine tuned for the long context.
  166. self.scaling_factor = scaling_factor
  167. assert (lora_model_id >
  168. 0), f"a valid lora id should be greater than 0, got {self.id}"
  169. self.rank = rank
  170. self.loras: Dict[str, LoRALayerWeights] = loras
  171. def clone(self, lora_model_id: int) -> "LoRAModel":
  172. """Return a copy of the object with different ids.
  173. Will share the underlying tensors."""
  174. return self.__class__(
  175. lora_model_id,
  176. rank=self.rank,
  177. loras=self.loras.copy(),
  178. )
  179. @property
  180. def extra_vocab_size(self) -> int:
  181. return max(lora.extra_vocab_size
  182. for lora in self.loras.values()) if self.loras else 0
  183. def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
  184. """Get LoRA for a given module by name"""
  185. return self.loras.get(module_name, None)
  186. # (yard1): TODO see if we can derive target_embedding_padding automatically
  187. @classmethod
  188. def from_lora_tensors(
  189. cls,
  190. lora_model_id: int,
  191. rank: int,
  192. lora_alpha: int,
  193. tensors: Dict[str, torch.Tensor],
  194. device: str = "cuda",
  195. dtype: Optional[torch.dtype] = None,
  196. embeddings: Optional[Dict[str, torch.Tensor]] = None,
  197. target_embedding_padding: Optional[int] = None,
  198. scaling_factor: Optional[float] = None,
  199. embedding_modules: Optional[Dict[str, str]] = None,
  200. embedding_padding_modules: Optional[List[str]] = None,
  201. ) -> "LoRAModel":
  202. """Create a LoRAModel from a dictionary of tensors."""
  203. pin_memory = str(device) == "cpu" and is_pin_memory_available()
  204. loras: Dict[str, LoRALayerWeights] = {}
  205. for tensor_name, tensor in tensors.items():
  206. module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name)
  207. if module_name not in loras:
  208. lora_embeddings_tensor = None
  209. if embeddings:
  210. assert embedding_modules is not None
  211. embeddings_module = next(
  212. (k for k in embedding_modules if k in module_name),
  213. None)
  214. if embeddings_module:
  215. lora_embeddings_tensor = embeddings[
  216. embedding_modules[embeddings_module]].to(
  217. device=device, dtype=dtype)
  218. if pin_memory:
  219. lora_embeddings_tensor = (
  220. lora_embeddings_tensor.pin_memory())
  221. loras[module_name] = LoRALayerWeights(module_name, rank,
  222. lora_alpha, None, None,
  223. lora_embeddings_tensor)
  224. if is_lora_a:
  225. loras[module_name].lora_a = tensor.to(device=device,
  226. dtype=dtype).t()
  227. if pin_memory:
  228. loras[module_name].lora_a = loras[
  229. module_name].lora_a.pin_memory()
  230. else:
  231. loras[module_name].lora_b = tensor.to(device=device,
  232. dtype=dtype).t()
  233. assert embedding_padding_modules is not None
  234. if any(name in module_name
  235. for name in embedding_padding_modules
  236. ) and target_embedding_padding is not None:
  237. lora_b = loras[module_name].lora_b
  238. assert target_embedding_padding >= lora_b.shape[1]
  239. addition = target_embedding_padding - lora_b.shape[1]
  240. loras[module_name].lora_b = torch.nn.functional.pad(
  241. lora_b, (0, addition))
  242. if pin_memory:
  243. loras[module_name].lora_b = loras[
  244. module_name].lora_b.pin_memory()
  245. for lora in loras.values():
  246. lora.optimize()
  247. return cls(lora_model_id, rank, loras, scaling_factor=scaling_factor)
  248. @classmethod
  249. def from_local_checkpoint(
  250. cls,
  251. lora_dir: str,
  252. expected_lora_modules: List[str],
  253. *,
  254. max_position_embeddings: Optional[int] = None,
  255. lora_model_id: Optional[int] = None,
  256. device: str = "cuda",
  257. dtype: Optional[torch.dtype] = None,
  258. target_embedding_padding: Optional[int] = None,
  259. embedding_modules: Optional[Dict[str, str]] = None,
  260. embedding_padding_modules: Optional[List[str]] = None,
  261. ) -> "LoRAModel":
  262. """Create a LoRAModel from a local checkpoint.
  263. Args:
  264. lora_dir: The local path that has lora data.
  265. expected_lora_modules: Name of modules that are expected to be
  266. replaced by lora.
  267. max_position_embeddings: Max position embedding length. Used to
  268. scaling the largest context length. If None, the lora model's
  269. context length is not scaled.
  270. lora_model_id: Lora model id. If not given, automatically set by
  271. a global counter.
  272. device: Device where the lora model is loaded.
  273. dtype: dtype of the lora model weights.
  274. Returns:
  275. Loaded LoRA Model.
  276. """
  277. lora_config_path = os.path.join(lora_dir, "adapter_config.json")
  278. lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
  279. lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
  280. new_embeddings_tensor_path = os.path.join(
  281. lora_dir, "new_embeddings.safetensors")
  282. new_embeddings_bin_file_path = os.path.join(lora_dir,
  283. "new_embeddings.bin")
  284. with open(lora_config_path) as f:
  285. config = json.load(f)
  286. if os.path.isfile(lora_tensor_path):
  287. tensors: Dict[str, torch.Tensor] = {}
  288. # Find unexpected modules.
  289. # Use safetensor key as a source of truth to find expected modules.
  290. # in peft if you have target_modules A, B, C and C does not exist
  291. # in the model it won’t error and model will be trained with A, B
  292. # loraified. C won’t exist in the safetensor but it will exist in
  293. # the target_modules of the adapter_config.json.
  294. unexpected_modules = []
  295. with safetensors.safe_open(lora_tensor_path,
  296. framework="pt") as f: # type: ignore
  297. for lora_module in f.keys(): # noqa
  298. module_name, _ = parse_fine_tuned_lora_name(lora_module)
  299. part_name = module_name.split(".")[-1]
  300. if part_name not in expected_lora_modules:
  301. unexpected_modules.append(module_name)
  302. if unexpected_modules:
  303. raise ValueError(
  304. f"While loading {lora_dir}, expected"
  305. f" target modules in {expected_lora_modules}"
  306. f" but received {unexpected_modules}."
  307. f" Please verify that the loaded LoRA module is correct"
  308. )
  309. # Load tensors if there are only expected modules.
  310. for module in f.keys(): # noqa
  311. tensors[module] = f.get_tensor(module)
  312. elif os.path.isfile(lora_bin_file_path):
  313. # When a bin file is provided, we rely on config to find unexpected
  314. # modules.
  315. unexpected_modules = []
  316. target_modules = config["target_modules"]
  317. for module in target_modules:
  318. # Compatible with more modules,
  319. # such as:layers.11.self_attn.k_proj
  320. part_name = module.split(".")[-1]
  321. if part_name not in expected_lora_modules:
  322. unexpected_modules.append(module)
  323. # loaded lora's target modules must be a subset of
  324. # expected_lora_modules. It is not reliable. But there's no
  325. # other better mechanism.
  326. if unexpected_modules:
  327. print(unexpected_modules, "modules")
  328. raise ValueError(
  329. f"While loading {lora_dir}, expected"
  330. f" target modules in {expected_lora_modules}"
  331. f" but received {unexpected_modules}."
  332. f" Please verify that the loaded LoRA module is correct")
  333. tensors = torch.load(lora_bin_file_path)
  334. else:
  335. raise ValueError(f"{lora_dir} doesn't contain tensors")
  336. embeddings = None
  337. if os.path.isfile(new_embeddings_tensor_path):
  338. embeddings = safetensors.torch.load_file(
  339. new_embeddings_tensor_path)
  340. elif os.path.isfile(new_embeddings_bin_file_path):
  341. embeddings = torch.load(new_embeddings_bin_file_path)
  342. rank = config["r"]
  343. lora_alpha = config["lora_alpha"]
  344. context_length = config.get("context_length", None)
  345. scaling_factor = None
  346. if context_length:
  347. if max_position_embeddings is None:
  348. max_position_embeddings = context_length
  349. scaling_factor = float(
  350. math.ceil(context_length / max_position_embeddings))
  351. return cls.from_lora_tensors(
  352. lora_model_id=get_lora_id()
  353. if lora_model_id is None else lora_model_id,
  354. rank=rank,
  355. lora_alpha=lora_alpha,
  356. tensors=tensors,
  357. device=device,
  358. dtype=dtype,
  359. embeddings=embeddings,
  360. target_embedding_padding=target_embedding_padding,
  361. scaling_factor=scaling_factor,
  362. embedding_modules=embedding_modules,
  363. embedding_padding_modules=embedding_padding_modules,
  364. )
  365. class LoRAModelManager(AdapterModelManager):
  366. """A manager that manages multiple LoRA-fine-tuned models."""
  367. def __init__(
  368. self,
  369. model: SupportsLoRA,
  370. max_num_seqs: int,
  371. max_num_batched_tokens: int,
  372. vocab_size: int,
  373. lora_config: LoRAConfig,
  374. ):
  375. """Create a LoRAModelManager and adapter for a given model.
  376. Args:
  377. model: the model to be adapted.
  378. max_num_seqs: the maximum number of sequences model can run in a
  379. single batch.
  380. max_num_batched_tokens: the maximum number of tokens model can run
  381. in a single batch.
  382. vocab_size: the vocab size of the model.
  383. lora_config: the LoRA configuration.
  384. """
  385. self.lora_config = lora_config
  386. self.max_num_seqs = max_num_seqs
  387. assert self.capacity >= self.lora_slots
  388. self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
  389. self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
  390. self.vocab_size = vocab_size
  391. self.long_lora_context: Optional[LongContextLoRAContext] = None
  392. self.base_indices = torch.empty(self.max_num_batched_tokens,
  393. dtype=torch.long,
  394. device="cuda")
  395. self.sampler_indices = torch.empty(self.max_num_batched_tokens,
  396. dtype=torch.long,
  397. device="cuda")
  398. self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens,
  399. dtype=torch.long,
  400. device="cuda")
  401. self.embeddings_indices = torch.empty(2,
  402. self.max_num_batched_tokens,
  403. dtype=torch.long,
  404. device="cuda")
  405. self.long_lora_indices = torch.empty(self.max_num_batched_tokens,
  406. dtype=torch.long,
  407. device="cuda")
  408. # Scaling factor -> offset to the sin_cos_cache to it.
  409. # Used for long context lora.
  410. self.scaling_factor_to_offset: Dict[float, int] = {}
  411. # 4 is the number of indicies tensors defined above
  412. # base_indices, sampler_indices, sampler_indices_padded,
  413. # embeddings_indices
  414. self.indices_len: List[Optional[int]] = [None] * 4
  415. super().__init__(model)
  416. if hasattr(self.model, "supported_lora_modules"):
  417. self.supported_lora_modules = copy.deepcopy(
  418. self.model.supported_lora_modules)
  419. if lora_config.long_lora_scaling_factors:
  420. # We need to replace rotary emb layer to do batch computation
  421. # for long lora.
  422. self.supported_lora_modules.append("rotary_emb")
  423. self.packed_modules_mapping = copy.deepcopy(
  424. self.model.packed_modules_mapping)
  425. self.packed_modules: Dict[str, List[str]] = {}
  426. self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
  427. # Dict instead of a Set for compatibility with LRUCache.
  428. self._last_mapping: Optional[LoRAMapping] = None
  429. self._create_lora_modules()
  430. self.model.lora_manager = self
  431. self.adapter_type = 'LoRa'
  432. @property
  433. def capacity(self) -> int:
  434. return self.lora_config.max_cpu_loras
  435. @property
  436. def lora_slots(self) -> int:
  437. return self.lora_config.max_loras
  438. @property
  439. def adapter_slots(self) -> int:
  440. return self.lora_slots
  441. def activate_adapter(
  442. self,
  443. lora_id: int,
  444. ) -> bool:
  445. """Move LoRA into a GPU buffer to be used in the forward pass."""
  446. if lora_id in self._active_adapters:
  447. return False
  448. first_free_slot = next(
  449. ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id)
  450. if lora_id is None), None)
  451. if first_free_slot is None:
  452. raise ValueError("No free lora slots")
  453. index, _ = first_free_slot
  454. self._active_adapters[lora_id] = None
  455. lora_model = self._registered_adapters[lora_id]
  456. logger.debug(f"Activating LoRA. int id: {lora_model.id}, "
  457. f"slot index: {index}")
  458. self.lora_index_to_id[index] = lora_model.id
  459. for module_name, module in self.modules.items():
  460. module_lora = lora_model.get_lora(module_name)
  461. if module_lora:
  462. module_lora.optimize()
  463. module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
  464. module_lora.embeddings_tensor)
  465. else:
  466. module.reset_lora(index)
  467. return True
  468. def _deactivate_adapter(self, lora_id: int):
  469. try:
  470. index = self.lora_index_to_id.index(lora_id)
  471. self.lora_index_to_id[index] = None
  472. except ValueError:
  473. pass
  474. def _set_long_lora_context(self, lora: LoRAModel):
  475. if self.long_lora_context is None:
  476. return
  477. if lora.scaling_factor is None:
  478. return
  479. if (lora.scaling_factor not in self.scaling_factor_to_offset):
  480. raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}"
  481. " has not been initialized.")
  482. offsets = self.scaling_factor_to_offset.get(lora.scaling_factor)
  483. if offsets:
  484. self.long_lora_context.offsets_by_lora_id[lora.id] = offsets
  485. def _add_adapter(self, lora: LoRAModel):
  486. self._create_merged_loras_inplace(lora)
  487. self._registered_adapters[lora.id] = lora
  488. self._set_long_lora_context(lora)
  489. def pin_adapter(self, lora_id: int) -> bool:
  490. """Pin a LoRAModel in the manager cache."""
  491. raise NotImplementedError(
  492. "Pinning is not supported in LoRAModelManager."
  493. "Use LRUCacheLoRAModelManager for pinning") # type: ignore
  494. # TODO see if this can be vectorized
  495. def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
  496. (base_indices, sampler_indices, sampler_indices_padded,
  497. embeddings_indices, long_lora_offsets_tensor,
  498. indices_len) = convert_mapping(mapping, self.lora_index_to_id,
  499. self.lora_slots + 1, self.vocab_size,
  500. self.lora_config.lora_extra_vocab_size,
  501. self.long_lora_context)
  502. self.base_indices[:base_indices.shape[0]].copy_(base_indices)
  503. self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
  504. self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
  505. sampler_indices_padded)
  506. self.embeddings_indices[:embeddings_indices.
  507. shape[0], :embeddings_indices.shape[1]].copy_(
  508. embeddings_indices)
  509. if long_lora_offsets_tensor is not None:
  510. self.long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
  511. long_lora_offsets_tensor)
  512. else:
  513. self.long_lora_indices.zero_()
  514. # Maintain the reference
  515. self.indices_len[:] = indices_len
  516. def remove_all_adapters(self):
  517. """Remove all LoRAModels from the manager."""
  518. self._registered_adapters.clear()
  519. self.lora_index_to_id = [None] * self.lora_slots
  520. self._active_adapters.clear()
  521. def _create_lora_modules(self):
  522. for module_name, module in self.model.named_modules(
  523. remove_duplicate=False):
  524. if not self._match_target_modules(module_name):
  525. continue
  526. parts = module_name.split(".")[-1]
  527. packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
  528. new_module = replace_submodule(
  529. self.model, module_name,
  530. from_layer(module, self.lora_slots, self.lora_config,
  531. packed_moduled_lst, self.model.config))
  532. # LinearScalingRotaryEmbeddingWithLora is used to handle
  533. # long context lora. Register relevant metadata.
  534. if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora):
  535. self.long_lora_context = LongContextLoRAContext(
  536. new_module.scaling_factors, new_module.rotary_dim)
  537. self.scaling_factor_to_offset = \
  538. new_module.scaling_factor_to_offset
  539. # (yard1): TODO make this more robust
  540. if "lm_head" in module_name:
  541. logits_processor_module = self.model.get_submodule(
  542. "logits_processor")
  543. new_module = replace_submodule(
  544. self.model, "logits_processor",
  545. from_layer_logits_processor(logits_processor_module,
  546. module, self.lora_slots,
  547. self.lora_config,
  548. self.model.config))
  549. self.register_module(module_name, new_module)
  550. self._register_packed_modules(module_name)
  551. new_module.set_mapping(self.base_indices, self.sampler_indices,
  552. self.sampler_indices_padded,
  553. self.embeddings_indices,
  554. self.long_lora_indices, self.indices_len)
  555. def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
  556. assert isinstance(module, BaseLayerWithLoRA)
  557. self.modules[module_name] = module
  558. def create_dummy_lora(
  559. self,
  560. lora_id: int,
  561. rank: int,
  562. scaling_factor: Optional[float],
  563. embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
  564. """Create zero-initialized LoRAModel for warmup."""
  565. model = LoRAModel(lora_id, rank, {}, scaling_factor)
  566. for module_name, module in self.model.named_modules():
  567. if not self._match_target_modules(module_name) or not isinstance(
  568. module, BaseLayerWithLoRA) or isinstance(
  569. module, LinearScalingRotaryEmbeddingWithLora):
  570. continue
  571. parts = module_name.split(".")
  572. if module_name not in self.packed_modules:
  573. assert embedding_modules is not None
  574. if parts[-1] in embedding_modules:
  575. input_dim = (module.base_layer.org_vocab_size +
  576. self.lora_config.lora_extra_vocab_size if
  577. hasattr(module.base_layer, "org_vocab_size")
  578. else module.base_layer.weight.shape[1])
  579. output_dim = module.base_layer.embedding_dim if hasattr(
  580. module.base_layer,
  581. "embedding_dim") else module.base_layer.weight.shape[0]
  582. embeddings_tensor_dim = (module.base_layer.embedding_dim if
  583. hasattr(module.base_layer,
  584. "embedding_dim") else
  585. module.base_layer.weight.shape[1])
  586. lora = LoRALayerWeights.create_dummy_lora_weights(
  587. module_name,
  588. input_dim,
  589. output_dim,
  590. rank,
  591. module.lora_a_stacked.dtype,
  592. "cpu",
  593. embeddings_tensor_dim=embeddings_tensor_dim)
  594. else:
  595. lora = LoRALayerWeights.create_dummy_lora_weights(
  596. module_name,
  597. module.lora_a_stacked.shape[-1],
  598. module.lora_b_stacked.shape[-2],
  599. rank,
  600. module.lora_a_stacked.dtype,
  601. "cpu",
  602. )
  603. lora.optimize()
  604. else:
  605. parts = module_name.split(".")
  606. replacements = self.packed_modules_mapping[parts[-1]]
  607. subloras: List[Optional["LoRALayerWeights"]] = []
  608. for i, r in enumerate(replacements):
  609. lora = LoRALayerWeights.create_dummy_lora_weights(
  610. module_name + "." + r,
  611. module.lora_a_stacked[i].shape[-1],
  612. module.lora_b_stacked[i].shape[-2],
  613. rank,
  614. module.lora_a_stacked[i].dtype,
  615. "cpu",
  616. )
  617. lora.optimize()
  618. subloras.append(lora)
  619. lora = PackedLoRALayerWeights.pack(subloras)
  620. model.loras[module_name] = lora
  621. return model
  622. def _match_target_modules(self, module_name: str):
  623. return any(
  624. re.match(
  625. r".*\.{target_module}$".format(target_module=target_module),
  626. module_name) or target_module == module_name
  627. for target_module in self.supported_lora_modules)
  628. def _register_packed_modules(self, module_full_name: str) -> None:
  629. parts = module_full_name.split(".")
  630. module_name = parts[-1]
  631. replacements = self.packed_modules_mapping.get(module_name, [])
  632. # When replacements is less than or equal to 1, it indicates that this
  633. # module is not a packed module.
  634. if len(replacements) <= 1:
  635. return
  636. prefix = ".".join(parts[:-1])
  637. self.packed_modules[module_full_name] = [
  638. prefix + "." + r if prefix else r for r in replacements
  639. ]
  640. def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
  641. for module_name, new_module_names in self.packed_modules.items():
  642. replacement_loras: List[Optional[LoRALayerWeights]] = []
  643. has_replacement = False
  644. for r in new_module_names:
  645. lora = lora_model.get_lora(r)
  646. replacement_loras.append(lora)
  647. if lora:
  648. has_replacement = True
  649. if not has_replacement:
  650. continue
  651. for i in range(len(replacement_loras)):
  652. if replacement_loras[i]:
  653. continue
  654. replacement_loras[i] = None
  655. lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
  656. replacement_loras)
  657. def deactivate_adapter(self, adapter_id: int) -> bool:
  658. return deactivate_adapter(adapter_id, self._active_adapters,
  659. self._deactivate_adapter)
  660. def add_adapter(self, adapter: LoRAModel) -> bool:
  661. logger.debug(f"Adding lora. Model id: {adapter.id}, "
  662. f"int id: {adapter.id}, "
  663. f"scaling factor: {adapter.scaling_factor}")
  664. return add_adapter(adapter, self._registered_adapters, self.capacity,
  665. self._add_adapter)
  666. def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
  667. self._last_mapping = set_adapter_mapping(mapping, self._last_mapping,
  668. self._set_adapter_mapping)
  669. def remove_adapter(self, adapter_id: int) -> bool:
  670. return remove_adapter(adapter_id, self._registered_adapters,
  671. self.deactivate_adapter)
  672. def list_adapters(self) -> Dict[int, Any]:
  673. return list_adapters(self._registered_adapters)
  674. def get_adapter(self, adapter_id: int) -> Optional[Any]:
  675. return get_adapter(adapter_id, self._registered_adapters)
  676. class LoRALRUCache(AdapterLRUCache[LoRAModel]):
  677. def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int],
  678. bool]):
  679. super().__init__(capacity, deactivate_lora_fn)
  680. class LRUCacheLoRAModelManager(LoRAModelManager):
  681. """A model manager that manages multiple LoRAs with LRU cache."""
  682. def __init__(
  683. self,
  684. model: nn.Module,
  685. max_num_seqs: int,
  686. max_num_batched_tokens: int,
  687. vocab_size: int,
  688. lora_config: LoRAConfig,
  689. ):
  690. super().__init__(model, max_num_seqs, max_num_batched_tokens,
  691. vocab_size, lora_config)
  692. self._registered_adapters: LoRALRUCache = LoRALRUCache(
  693. self.capacity, self.deactivate_adapter)
  694. self._active_adapters: LoRALRUCache = LoRALRUCache(
  695. self.lora_slots, self._deactivate_adapter)
  696. def list_adapters(self) -> Dict[int, LoRAModel]:
  697. """List all registered LoRAModels."""
  698. return dict(self._registered_adapters.cache)
  699. def add_adapter(self, lora: LoRAModel) -> bool:
  700. """Add a LoRAModel to the manager."""
  701. logger.debug(f"Adding lora. Model id: {lora.id}, "
  702. f"int id: {lora.id}, "
  703. f"scaling factor: {lora.scaling_factor}")
  704. if lora.id not in self._registered_adapters:
  705. self._add_adapter(lora)
  706. was_added = True
  707. else:
  708. # We always touch to update the LRU cache order
  709. self._registered_adapters.touch(lora.id)
  710. was_added = False
  711. return was_added
  712. def activate_adapter(
  713. self,
  714. lora_id: int,
  715. ) -> bool:
  716. if lora_id not in self._active_adapters and len(
  717. self._active_adapters) >= self.lora_slots:
  718. self._active_adapters.remove_oldest()
  719. result = super().activate_adapter(lora_id)
  720. # We always touch to update the LRU cache order
  721. self._active_adapters.touch(lora_id)
  722. return result
  723. def remove_oldest_adapter(self) -> bool:
  724. if len(self._registered_adapters) > 0:
  725. self._registered_adapters.remove_oldest()
  726. return True
  727. return False
  728. def pin_adapter(self, lora_id: int) -> bool:
  729. """Pin a LoRAModel in the manager cache."""
  730. self._pin_lora_in_cpu_cache(lora_id)
  731. self._pin_lora_in_gpu_cache(lora_id)
  732. return True
  733. def _pin_lora_in_cpu_cache(self, lora_id: int):
  734. try:
  735. self._registered_adapters.pin(lora_id)
  736. except ValueError as err:
  737. raise ValueError("Pinning failed. "
  738. f"LoRA {lora_id} is not registered.") from err
  739. def _pin_lora_in_gpu_cache(self, lora_id: int):
  740. if lora_id not in self._active_adapters:
  741. # move lora to gpu if not already active
  742. self.activate_adapter(lora_id)
  743. self._active_adapters.pin(lora_id)
  744. def create_lora_manager(
  745. model: nn.Module,
  746. max_num_seqs: int,
  747. max_num_batched_tokens: int,
  748. vocab_size: int,
  749. lora_config: LoRAConfig,
  750. lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
  751. **kwargs) -> LoRAModelManager:
  752. """Create a LoRA adapter for a given model."""
  753. if not hasattr(model, "supported_lora_modules"):
  754. raise ValueError(f"Model {type(model)} is not supported for LoRA.")
  755. lora_manager = lora_manager_cls(
  756. model=model,
  757. max_num_seqs=max_num_seqs,
  758. max_num_batched_tokens=max_num_batched_tokens,
  759. vocab_size=vocab_size,
  760. lora_config=lora_config,
  761. **kwargs)
  762. return lora_manager