1
0

layers.py 51 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458
  1. # pylint: disable=unused-argument
  2. import math
  3. from dataclasses import dataclass
  4. from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. from transformers import PretrainedConfig
  9. from aphrodite.adapter_commons.layers import AdapterMapping
  10. from aphrodite.common.config import LoRAConfig
  11. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  12. get_tensor_model_parallel_world_size,
  13. split_tensor_along_last_dim,
  14. tensor_model_parallel_all_gather,
  15. tensor_model_parallel_all_reduce,
  16. tensor_model_parallel_gather)
  17. from aphrodite.distributed.utils import divide
  18. from aphrodite.lora.punica import PunicaWrapper
  19. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  20. MergedColumnParallelLinear,
  21. QKVParallelLinear,
  22. ReplicatedLinear,
  23. RowParallelLinear)
  24. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  25. from aphrodite.modeling.layers.rotary_embedding import (
  26. LinearScalingRotaryEmbedding, RotaryEmbedding)
  27. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  28. ParallelLMHead, VocabParallelEmbedding)
  29. if TYPE_CHECKING:
  30. pass
  31. def _get_lora_device(base_layer: nn.Module) -> torch.device:
  32. # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
  33. """Returns the device for where to place the LoRA tensors."""
  34. # unquantizedLinear
  35. if hasattr(base_layer, "weight"):
  36. return base_layer.weight.device
  37. # GPTQ/AWQ/SqueezeLLM
  38. elif hasattr(base_layer, "qweight"):
  39. return base_layer.qweight.device
  40. # marlin
  41. elif hasattr(base_layer, "B"):
  42. return base_layer.B.device
  43. else:
  44. raise ValueError(f"Unsupported base layer: {base_layer}")
  45. def _not_fully_sharded_can_replace(can_replace):
  46. """
  47. decorator which adds the condition of not using fully sharded loras
  48. intended to wrap can_replace_layer()
  49. """
  50. def dec(*args, **kwargs):
  51. decorate = kwargs.pop("decorate") if "decorate" in kwargs else True
  52. condition = (not kwargs["lora_config"].fully_sharded_loras
  53. if decorate else True)
  54. return can_replace(*args, **kwargs) and condition
  55. return dec
  56. class TensorPropertiesMixin:
  57. @property
  58. def dtype(self):
  59. return self._dtype
  60. @dtype.setter
  61. def dtype(self, value):
  62. self._dtype = value
  63. @property
  64. def device(self):
  65. return self._device
  66. @device.setter
  67. def device(self, value):
  68. self._device = value
  69. @dataclass
  70. class LoRAMapping(AdapterMapping):
  71. is_prefill: bool = False
  72. class BaseLayerWithLoRA(nn.Module):
  73. def slice_lora_a(
  74. self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
  75. ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
  76. """Slice lora a if splitting for tensor parallelism."""
  77. ...
  78. def slice_lora_b(
  79. self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
  80. ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
  81. """Slice lora b if splitting with tensor parallelism."""
  82. ...
  83. def create_lora_weights(
  84. self,
  85. max_loras: int,
  86. lora_config: LoRAConfig,
  87. model_config: Optional[PretrainedConfig] = None,
  88. ) -> None:
  89. """Initializes lora matrices."""
  90. ...
  91. def reset_lora(self, index: int):
  92. """Resets the lora weights at index back to 0."""
  93. ...
  94. def set_lora(
  95. self,
  96. index: int,
  97. lora_a: torch.Tensor,
  98. lora_b: torch.Tensor,
  99. embeddings_tensor: Optional[torch.Tensor],
  100. ):
  101. """Overwrites lora tensors at index."""
  102. ...
  103. def set_mapping(
  104. self,
  105. punica_wrapper: PunicaWrapper,
  106. ):
  107. self.punica_wrapper: PunicaWrapper = punica_wrapper
  108. @classmethod
  109. def can_replace_layer(
  110. cls,
  111. source_layer: nn.Module,
  112. lora_config: LoRAConfig,
  113. packed_modules_list: List,
  114. model_config: Optional[PretrainedConfig],
  115. ) -> bool:
  116. """Returns True if the layer can be replaced by this LoRA layer."""
  117. raise NotImplementedError
  118. class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA, TensorPropertiesMixin):
  119. def __init__(self, base_layer: VocabParallelEmbedding) -> None:
  120. super().__init__()
  121. self.base_layer = base_layer
  122. self.dtype = self.base_layer.weight.dtype
  123. self.device = self.base_layer.weight.device
  124. self.embeddings_slice: Optional[Tuple[int, int]]
  125. self.embeddings_weights: Optional[torch.Tensor]
  126. def create_lora_weights(
  127. self,
  128. max_loras: int,
  129. lora_config: LoRAConfig,
  130. model_config: Optional[PretrainedConfig] = None) -> None:
  131. if self.base_layer.num_added_embeddings_per_partition > 0:
  132. # We can start adding lora weights
  133. self.embeddings_weights = self.base_layer.weight.data[
  134. self.base_layer.num_org_embeddings_per_partition:self.
  135. base_layer.num_org_embeddings_per_partition +
  136. self.base_layer.num_added_embeddings_per_partition]
  137. self.embeddings_slice = (
  138. self.base_layer.shard_indices.added_vocab_start_index -
  139. self.base_layer.org_vocab_size,
  140. self.base_layer.shard_indices.added_vocab_end_index -
  141. self.base_layer.org_vocab_size)
  142. self.base_layer.weight.data[
  143. self.base_layer.num_org_embeddings_per_partition:].fill_(0)
  144. else:
  145. self.embeddings_slice = None
  146. self.embeddings_weights = None
  147. self.embeddings_tensors = torch.zeros((
  148. max_loras,
  149. lora_config.lora_extra_vocab_size,
  150. self.base_layer.embedding_dim,
  151. ),
  152. dtype=self.dtype,
  153. device=self.device)
  154. self.lora_a_stacked = torch.zeros((
  155. max_loras,
  156. self.base_layer.org_vocab_size + lora_config.lora_extra_vocab_size,
  157. lora_config.max_lora_rank,
  158. ),
  159. dtype=self.dtype,
  160. device=self.device)
  161. self.lora_b_stacked = torch.zeros(
  162. (
  163. max_loras,
  164. 1,
  165. self.base_layer.embedding_dim,
  166. lora_config.max_lora_rank,
  167. ),
  168. dtype=lora_config.lora_dtype,
  169. device=self.device,
  170. )
  171. self.lora_a_stacked_2d = self.lora_a_stacked.view(
  172. self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
  173. self.lora_a_stacked.shape[2],
  174. )
  175. def reset_lora(self, index: int):
  176. self.lora_a_stacked[index] = 0
  177. self.lora_b_stacked[index] = 0
  178. self.embeddings_tensors[index] = 0
  179. def set_lora(
  180. self,
  181. index: int,
  182. lora_a: torch.Tensor,
  183. lora_b: torch.Tensor,
  184. embeddings_tensor: Optional[torch.Tensor],
  185. ):
  186. self.reset_lora(index)
  187. self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
  188. lora_a, non_blocking=True)
  189. self.lora_b_stacked[index,
  190. 0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
  191. lora_b.T, non_blocking=True)
  192. if embeddings_tensor is not None:
  193. self.embeddings_tensors[
  194. index, :embeddings_tensor.shape[0], :embeddings_tensor.
  195. shape[1], ].copy_(embeddings_tensor, non_blocking=True)
  196. if self.embeddings_slice is not None:
  197. # TODO(yard1): Optimize this copy, we don't need to copy
  198. # everything, just the modified part
  199. embeddings = self.embeddings_tensors.view(
  200. self.embeddings_tensors.shape[0] *
  201. self.embeddings_tensors.shape[1],
  202. self.embeddings_tensors.shape[2],
  203. )[self.embeddings_slice[0]:self.embeddings_slice[1]]
  204. assert self.embeddings_weights is not None
  205. self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)
  206. def forward(self, x: torch.Tensor) -> torch.Tensor:
  207. added_tokens_mask = x > self.base_layer.org_vocab_size - 1
  208. embeddings_indices = self.punica_wrapper.embeddings_indices
  209. indices = embeddings_indices[1].view_as(x)
  210. full_lora_a_embeddings = F.embedding(
  211. x + indices,
  212. self.lora_a_stacked_2d,
  213. )
  214. indices = embeddings_indices[0].view_as(x)
  215. full_output = self.base_layer.forward(
  216. x.add_(indices * added_tokens_mask))
  217. full_output_org = full_output
  218. if full_output.ndim == 3:
  219. full_output = full_output.view(
  220. full_output.shape[0] * full_output.shape[1], -1)
  221. if full_lora_a_embeddings.ndim == 3:
  222. full_lora_a_embeddings = full_lora_a_embeddings.view(
  223. full_lora_a_embeddings.shape[0] *
  224. full_lora_a_embeddings.shape[1],
  225. -1,
  226. )
  227. # Embedding layer only need expand op
  228. self.punica_wrapper.add_expand(full_output,
  229. full_lora_a_embeddings,
  230. self.lora_b_stacked,
  231. add_input=True)
  232. return full_output.view_as(full_output_org)
  233. @classmethod
  234. def can_replace_layer(
  235. cls,
  236. source_layer: nn.Module,
  237. lora_config: LoRAConfig,
  238. packed_modules_list: List,
  239. model_config: Optional[PretrainedConfig],
  240. ) -> bool:
  241. # do not use A*B-style LoRA, try to use modules_to_save
  242. if lora_config.enable_lora_modules_to_save:
  243. return False
  244. return type(source_layer) is VocabParallelEmbedding
  245. class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
  246. def __init__(self, base_layer: ReplicatedLinear) -> None:
  247. super().__init__()
  248. self.base_layer = base_layer
  249. self.input_size = self.base_layer.input_size
  250. self.output_size = self.base_layer.output_size
  251. self.device = _get_lora_device(self.base_layer)
  252. def create_lora_weights(
  253. self,
  254. max_loras: int,
  255. lora_config: LoRAConfig,
  256. model_config: Optional[PretrainedConfig] = None,
  257. ) -> None:
  258. self.lora_config = lora_config
  259. lora_a_output_size = lora_config.max_lora_rank
  260. self.lora_a_stacked = torch.zeros(
  261. max_loras,
  262. 1,
  263. lora_a_output_size,
  264. self.input_size,
  265. dtype=lora_config.lora_dtype,
  266. device=self.device,
  267. )
  268. self.lora_b_stacked = torch.zeros(
  269. max_loras,
  270. 1,
  271. self.output_size,
  272. lora_config.max_lora_rank,
  273. dtype=lora_config.lora_dtype,
  274. device=self.device,
  275. )
  276. def reset_lora(self, index: int):
  277. self.lora_a_stacked[index] = 0
  278. self.lora_b_stacked[index] = 0
  279. def set_lora(
  280. self,
  281. index: int,
  282. lora_a: torch.Tensor,
  283. lora_b: torch.Tensor,
  284. embeddings_tensor: Optional[torch.Tensor],
  285. ):
  286. self.reset_lora(index)
  287. self.lora_a_stacked[index,
  288. 0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
  289. lora_a.T, non_blocking=True)
  290. self.lora_b_stacked[index,
  291. 0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
  292. lora_b.T, non_blocking=True)
  293. def apply(self, x: torch.Tensor,
  294. bias: Optional[torch.Tensor]) -> torch.Tensor:
  295. output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
  296. self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
  297. self.lora_b_stacked, 1.0)
  298. return output
  299. def forward(self, input_):
  300. """Forward of ReplicatedLinearWithLoRA
  301. Args:
  302. input_: Tensor whose last dimension is `input_size`.
  303. Returns:
  304. - output
  305. - bias
  306. """
  307. bias = (self.base_layer.bias
  308. if not self.base_layer.skip_bias_add else None)
  309. # Matrix multiply.
  310. output = self.apply(input_, bias)
  311. output_bias = (self.base_layer.bias
  312. if self.base_layer.skip_bias_add else None)
  313. return output, output_bias
  314. @classmethod
  315. @_not_fully_sharded_can_replace
  316. def can_replace_layer(
  317. cls,
  318. source_layer: nn.Module,
  319. lora_config: LoRAConfig,
  320. packed_modules_list: List,
  321. model_config: Optional[PretrainedConfig],
  322. ) -> bool:
  323. return type(source_layer) is ReplicatedLinear
  324. class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
  325. """
  326. LoRA on top of ColumnParallelLinear layer.
  327. LoRA B is sliced for tensor parallelism.
  328. """
  329. def __init__(self, base_layer: ColumnParallelLinear) -> None:
  330. super().__init__()
  331. self.base_layer = base_layer
  332. self.tp_size = get_tensor_model_parallel_world_size()
  333. self.input_size = self.base_layer.input_size
  334. self.output_size = self.base_layer.output_size_per_partition
  335. self.device = _get_lora_device(self.base_layer)
  336. def create_lora_weights(
  337. self,
  338. max_loras: int,
  339. lora_config: LoRAConfig,
  340. model_config: Optional[PretrainedConfig] = None,
  341. ) -> None:
  342. self.lora_config = lora_config
  343. self.tp_size = get_tensor_model_parallel_world_size()
  344. lora_a_output_size_per_partition = (
  345. lora_config.max_lora_rank if not lora_config.fully_sharded_loras
  346. else divide(lora_config.max_lora_rank, self.tp_size))
  347. self.lora_a_stacked = torch.zeros(
  348. max_loras,
  349. 1,
  350. lora_a_output_size_per_partition,
  351. self.input_size,
  352. dtype=lora_config.lora_dtype,
  353. device=self.device,
  354. )
  355. self.lora_b_stacked = torch.zeros(
  356. max_loras,
  357. 1,
  358. self.output_size,
  359. lora_config.max_lora_rank,
  360. dtype=lora_config.lora_dtype,
  361. device=self.device,
  362. )
  363. self.output_dim = self.lora_b_stacked.shape[2]
  364. def reset_lora(self, index: int):
  365. self.lora_a_stacked[index] = 0
  366. self.lora_b_stacked[index] = 0
  367. def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
  368. return lora_a
  369. def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
  370. tensor_model_parallel_rank = get_tensor_model_parallel_rank()
  371. shard_size = self.output_dim
  372. start_idx = tensor_model_parallel_rank * shard_size
  373. end_idx = (tensor_model_parallel_rank + 1) * shard_size
  374. lora_b = lora_b[:, start_idx:end_idx]
  375. return lora_b
  376. def set_lora(
  377. self,
  378. index: int,
  379. lora_a: torch.Tensor,
  380. lora_b: torch.Tensor,
  381. embeddings_tensor: Optional[torch.Tensor],
  382. ):
  383. self.reset_lora(index)
  384. if self.tp_size > 1:
  385. lora_a = self.slice_lora_a(lora_a)
  386. lora_b = self.slice_lora_b(lora_b)
  387. self.lora_a_stacked[index,
  388. 0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
  389. lora_a.T, non_blocking=True)
  390. self.lora_b_stacked[index,
  391. 0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
  392. lora_b.T, non_blocking=True)
  393. def apply(self, x: torch.Tensor,
  394. bias: Optional[torch.Tensor]) -> torch.Tensor:
  395. output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
  396. self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
  397. self.lora_b_stacked, 1.0)
  398. return output
  399. def forward(self, input_):
  400. """Forward of ColumnParallelLinear
  401. Args:
  402. input_: Tensor whose last dimension is `input_size`.
  403. Returns:
  404. - output
  405. - bias
  406. """
  407. bias = (self.base_layer.bias
  408. if not self.base_layer.skip_bias_add else None)
  409. # Matrix multiply.
  410. output_parallel = self.apply(input_, bias)
  411. if self.base_layer.gather_output:
  412. # All-gather across the partitions.
  413. output = tensor_model_parallel_all_gather(output_parallel)
  414. else:
  415. output = output_parallel
  416. output_bias = (self.base_layer.bias
  417. if self.base_layer.skip_bias_add else None)
  418. return output, output_bias
  419. @classmethod
  420. @_not_fully_sharded_can_replace
  421. def can_replace_layer(
  422. cls,
  423. source_layer: nn.Module,
  424. lora_config: LoRAConfig,
  425. packed_modules_list: List,
  426. model_config: Optional[PretrainedConfig],
  427. ) -> bool:
  428. return type(source_layer) is ColumnParallelLinear or (
  429. type(source_layer) is MergedColumnParallelLinear
  430. and len(packed_modules_list) == 1)
  431. class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
  432. """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
  433. packed together (eg. gate_proj + up_proj -> gate_up_proj).
  434. This means we have 2 LoRAs, each applied to one half of the layer.
  435. Both slices must have the same size.
  436. """
  437. def __init__(self, base_layer: MergedColumnParallelLinear) -> None:
  438. super().__init__(base_layer)
  439. def create_lora_weights(
  440. self,
  441. max_loras: int,
  442. lora_config: LoRAConfig,
  443. model_config: Optional[PretrainedConfig] = None,
  444. ) -> None:
  445. self.lora_config = lora_config
  446. n_slices = 2
  447. if not (len(self.base_layer.output_sizes) == n_slices
  448. and self.base_layer.output_sizes[0]
  449. == self.base_layer.output_sizes[1]):
  450. raise ValueError(
  451. "LoRAColumnParallelLinear2Slice requires 2 slices with "
  452. "the same size.")
  453. self.tp_size = get_tensor_model_parallel_world_size()
  454. self.tp_rank = get_tensor_model_parallel_rank()
  455. lora_a_output_size_per_partition = (
  456. lora_config.max_lora_rank if not lora_config.fully_sharded_loras
  457. else divide(lora_config.max_lora_rank, self.tp_size))
  458. self.lora_a_stacked = tuple(
  459. torch.zeros(
  460. max_loras,
  461. 1,
  462. lora_a_output_size_per_partition,
  463. self.input_size,
  464. dtype=lora_config.lora_dtype,
  465. device=self.device,
  466. ) for _ in range(n_slices))
  467. self.lora_b_stacked = tuple(
  468. torch.zeros(
  469. max_loras,
  470. 1,
  471. self.output_size // 2,
  472. lora_config.max_lora_rank,
  473. dtype=lora_config.lora_dtype,
  474. device=self.device,
  475. ) for _ in range(n_slices))
  476. self.output_dim = self.lora_b_stacked[0].shape[2]
  477. def reset_lora(self, index: int):
  478. self.lora_a_stacked[0][index] = 0
  479. self.lora_a_stacked[1][index] = 0
  480. self.lora_b_stacked[0][index] = 0
  481. self.lora_b_stacked[1][index] = 0
  482. def slice_lora_a(
  483. self, lora_a: List[Union[torch.Tensor, None]]
  484. ) -> List[Union[torch.Tensor, None]]:
  485. return lora_a
  486. def slice_lora_b(
  487. self, lora_b: List[Union[torch.Tensor, None]]
  488. ) -> List[Union[torch.Tensor, None]]:
  489. if lora_b[0] is None or lora_b[1] is None:
  490. return lora_b
  491. shard_size = self.output_dim
  492. start_idx = self.tp_rank * shard_size
  493. end_idx = (self.tp_rank + 1) * shard_size
  494. lora_b = [
  495. lora_b[0][:, start_idx:end_idx],
  496. lora_b[1][:, start_idx:end_idx],
  497. ]
  498. return lora_b
  499. def set_lora(
  500. self,
  501. index: int,
  502. lora_a: torch.Tensor,
  503. lora_b: torch.Tensor,
  504. embeddings_tensor: Optional[torch.Tensor],
  505. ):
  506. self.reset_lora(index)
  507. if self.tp_size > 1:
  508. lora_a = self.slice_lora_a(lora_a)
  509. lora_b = self.slice_lora_b(lora_b)
  510. if lora_a[0] is not None:
  511. self.lora_a_stacked[0][
  512. index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
  513. lora_a[0].T, non_blocking=True)
  514. self.lora_b_stacked[0][
  515. index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
  516. lora_b[0].T, non_blocking=True)
  517. if lora_a[1] is not None:
  518. self.lora_a_stacked[1][
  519. index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
  520. lora_a[1].T, non_blocking=True)
  521. self.lora_b_stacked[1][
  522. index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
  523. lora_b[1].T, non_blocking=True)
  524. def apply(self, x: torch.Tensor,
  525. bias: Optional[torch.Tensor]) -> torch.Tensor:
  526. output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
  527. self.punica_wrapper.add_lora_packed_nslice(
  528. output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0,
  529. (self.output_dim, self.output_dim))
  530. return output
  531. @classmethod
  532. @_not_fully_sharded_can_replace
  533. def can_replace_layer(
  534. cls,
  535. source_layer: nn.Module,
  536. lora_config: LoRAConfig,
  537. packed_modules_list: List,
  538. model_config: Optional[PretrainedConfig],
  539. ) -> bool:
  540. return (type(source_layer) is MergedColumnParallelLinear
  541. and len(packed_modules_list) == 2)
  542. class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
  543. """
  544. ColumnParallelLinear layer that is specifically designed for
  545. qkv_proj. Certain models, such as chtglm3 and baichuan-7b,
  546. only contains a single LoRA within their qkv_proj layer.
  547. During inference with Tensor Parallel, the weights of lora_b
  548. must be accurately partitioned according to the respective ranks.
  549. Q slice may have different shape than K and V slices (which both have
  550. the same shape).
  551. """
  552. def __init__(self, base_layer: QKVParallelLinear) -> None:
  553. super().__init__(base_layer)
  554. self.tp_size = get_tensor_model_parallel_world_size()
  555. self.q_proj_total_size = (self.base_layer.total_num_heads *
  556. self.base_layer.head_size)
  557. self.q_proj_shard_size = (self.base_layer.num_heads *
  558. self.base_layer.head_size)
  559. self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
  560. self.base_layer.head_size)
  561. self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
  562. self.base_layer.head_size)
  563. def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
  564. tp_rank = get_tensor_model_parallel_rank()
  565. self.q_shard_id = tp_rank
  566. self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
  567. lora_b_q = lora_b[:, self.q_proj_shard_size *
  568. self.q_shard_id:self.q_proj_shard_size *
  569. (self.q_shard_id + 1)]
  570. k_offset = self.q_proj_total_size
  571. lora_b_k = lora_b[:, k_offset +
  572. self.kv_proj_shard_size * self.kv_shard_id:k_offset +
  573. self.kv_proj_shard_size * (self.kv_shard_id + 1)]
  574. v_offset = k_offset + self.kv_proj_total_size
  575. lora_b_v = lora_b[:, v_offset +
  576. self.kv_proj_shard_size * self.kv_shard_id:v_offset +
  577. self.kv_proj_shard_size * (self.kv_shard_id + 1)]
  578. lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
  579. return lora_b
  580. def set_lora(
  581. self,
  582. index: int,
  583. lora_a: torch.Tensor,
  584. lora_b: torch.Tensor,
  585. embeddings_tensor: Optional[torch.Tensor],
  586. ):
  587. self.reset_lora(index)
  588. if self.tp_size > 1:
  589. lora_a = self.slice_lora_a(lora_a)
  590. lora_b = self.slice_lora_b(lora_b)
  591. self.lora_a_stacked[index,
  592. 0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
  593. lora_a.T, non_blocking=True)
  594. self.lora_b_stacked[index,
  595. 0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
  596. lora_b.T, non_blocking=True)
  597. @classmethod
  598. @_not_fully_sharded_can_replace
  599. def can_replace_layer(cls, source_layer: nn.Module,
  600. lora_config: LoRAConfig, packed_modules_list: List,
  601. model_config: Optional[PretrainedConfig]) -> bool:
  602. return type(source_layer) is QKVParallelLinear and len(
  603. packed_modules_list) == 1
  604. class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
  605. """ColumnParallelLinear layer that is composed of 3 sublayers (slices)
  606. packed together in qkv proj fashion
  607. (q_proj + k_proj + v_proj -> qkv_proj).
  608. This means we have 3 LoRAs, each applied to one slice of the layer.
  609. Q slice may have different shape than K and V slices (which both have
  610. the same shape).
  611. """
  612. def __init__(self, base_layer: QKVParallelLinear) -> None:
  613. super().__init__(base_layer)
  614. def create_lora_weights(
  615. self,
  616. max_loras: int,
  617. lora_config: LoRAConfig,
  618. model_config: Optional[PretrainedConfig] = None,
  619. ) -> None:
  620. self.lora_config = lora_config
  621. self.tp_size = get_tensor_model_parallel_world_size()
  622. self.tp_rank = get_tensor_model_parallel_rank()
  623. self.q_proj_shard_size = (self.base_layer.num_heads *
  624. self.base_layer.head_size)
  625. self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
  626. self.base_layer.head_size)
  627. self.q_shard_id = self.tp_rank
  628. self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
  629. lora_a_output_size_per_partition = (
  630. lora_config.max_lora_rank if not lora_config.fully_sharded_loras
  631. else divide(lora_config.max_lora_rank, self.tp_size))
  632. # q, k, v
  633. self.lora_a_stacked = (
  634. torch.zeros(
  635. max_loras,
  636. 1,
  637. lora_a_output_size_per_partition,
  638. self.input_size,
  639. dtype=lora_config.lora_dtype,
  640. device=self.device,
  641. ),
  642. torch.zeros(
  643. max_loras,
  644. 1,
  645. lora_a_output_size_per_partition,
  646. self.input_size,
  647. dtype=lora_config.lora_dtype,
  648. device=self.device,
  649. ),
  650. torch.zeros(
  651. max_loras,
  652. 1,
  653. lora_a_output_size_per_partition,
  654. self.input_size,
  655. dtype=lora_config.lora_dtype,
  656. device=self.device,
  657. ),
  658. )
  659. self.lora_b_stacked = (
  660. torch.zeros(
  661. max_loras,
  662. 1,
  663. self.q_proj_shard_size,
  664. lora_config.max_lora_rank,
  665. dtype=lora_config.lora_dtype,
  666. device=self.device,
  667. ),
  668. torch.zeros(
  669. max_loras,
  670. 1,
  671. self.kv_proj_shard_size,
  672. lora_config.max_lora_rank,
  673. dtype=lora_config.lora_dtype,
  674. device=self.device,
  675. ),
  676. torch.zeros(
  677. max_loras,
  678. 1,
  679. self.kv_proj_shard_size,
  680. lora_config.max_lora_rank,
  681. dtype=lora_config.lora_dtype,
  682. device=self.device,
  683. ),
  684. )
  685. self.output_slices = (
  686. self.q_proj_shard_size,
  687. self.kv_proj_shard_size,
  688. self.kv_proj_shard_size,
  689. )
  690. self.packed_indices: Optional[torch.Tensor] = None
  691. self.standard_indices: Optional[torch.Tensor] = None
  692. # lazily initialized.
  693. self.indices: torch.Tensor
  694. self.indices_len: List[int]
  695. def reset_lora(self, index: int):
  696. self.lora_a_stacked[0][index] = 0
  697. self.lora_b_stacked[0][index] = 0
  698. self.lora_a_stacked[1][index] = 0
  699. self.lora_b_stacked[1][index] = 0
  700. self.lora_a_stacked[2][index] = 0
  701. self.lora_b_stacked[2][index] = 0
  702. def slice_lora_a(
  703. self, lora_a: List[Union[torch.Tensor, None]]
  704. ) -> List[Union[torch.Tensor, None]]:
  705. return lora_a
  706. def slice_lora_b(
  707. self, lora_b: List[Union[torch.Tensor, None]]
  708. ) -> List[Union[torch.Tensor, None]]:
  709. lora_b_q, lora_b_k, lora_b_v = None, None, None
  710. if lora_b[0] is not None:
  711. lora_b_q = lora_b[0][:, self.q_proj_shard_size *
  712. self.q_shard_id:self.q_proj_shard_size *
  713. (self.q_shard_id + 1), ]
  714. if lora_b[1] is not None:
  715. lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
  716. self.kv_shard_id:self.kv_proj_shard_size *
  717. (self.kv_shard_id + 1), ]
  718. if lora_b[2] is not None:
  719. lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
  720. self.kv_shard_id:self.kv_proj_shard_size *
  721. (self.kv_shard_id + 1), ]
  722. lora_b = [lora_b_q, lora_b_k, lora_b_v]
  723. return lora_b
  def set_lora(
      self,
      index: int,
      lora_a: torch.Tensor,
      lora_b: torch.Tensor,
      embeddings_tensor: Optional[torch.Tensor],
  ):
      """Load one adapter's q/k/v LoRA weights into slot ``index``.

      Args:
          index: Adapter slot to overwrite; it is zeroed via ``reset_lora``
              first.
          lora_a: Sequence of (q, k, v) A matrices, each ``None`` or a 2-D
              tensor that is copied transposed into its slot.
          lora_b: Sequence of (q, k, v) B matrices, each ``None`` or a 2-D
              tensor that is copied transposed into its slot.
          embeddings_tensor: Not used by this layer.

      With tensor parallelism (``tp_size > 1``) both sequences are first
      passed through ``slice_lora_a`` / ``slice_lora_b``.
      """
      self.reset_lora(index)
      if self.tp_size > 1:
          lora_a = self.slice_lora_a(lora_a)
          lora_b = self.slice_lora_b(lora_b)

      # B matrices are copied before A matrices; projections go q, k, v.
      for stacked, weights in ((self.lora_b_stacked, lora_b),
                               (self.lora_a_stacked, lora_a)):
          for proj in range(3):
              weight = weights[proj]
              if weight is None:
                  continue
              # Transposed copy into the leading corner of the slot; any
              # remaining area keeps the zeros written by reset_lora.
              target = stacked[proj][index, 0, :weight.shape[1],
                                     :weight.shape[0]]
              target.copy_(weight.T, non_blocking=True)
  def apply(self, x: torch.Tensor,
            bias: Optional[torch.Tensor]) -> torch.Tensor:
      """Run the base QKV projection, then add the packed q/k/v LoRA terms.

      The punica wrapper receives the base output, the input, both stacked
      LoRA buffers, a scale of 1.0 and ``output_slices``. Its return value is
      ignored, so it presumably accumulates into ``output`` in place.
      """
      base = self.base_layer
      output = base.quant_method.apply(base, x, bias)
      lora_a, lora_b = self.lora_a_stacked, self.lora_b_stacked
      self.punica_wrapper.add_lora_packed_nslice(output, x, lora_a, lora_b,
                                                 1.0, self.output_slices)
      return output
  @classmethod
  @_not_fully_sharded_can_replace
  def can_replace_layer(
      cls,
      source_layer: nn.Module,
      lora_config: LoRAConfig,
      packed_modules_list: List,
      model_config: Optional[PretrainedConfig],
  ) -> bool:
      """Accept only an exact ``QKVParallelLinear`` packed from three modules.

      Subclasses of ``QKVParallelLinear`` are rejected because the check is
      an exact type comparison.
      """
      if type(source_layer) is not QKVParallelLinear:
          return False
      return len(packed_modules_list) == 3
  class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
      """LoRA wrapper for a ``RowParallelLinear`` layer.

      The base layer works on a per-rank slice of the input
      (``input_size_per_partition``), so each rank keeps only the matching
      rows of every LoRA A matrix (see ``slice_lora_a``). LoRA B keeps the
      full ``output_size`` rows unless ``lora_config.fully_sharded_loras`` is
      set, in which case ``divide(output_size, tp_size)`` rows are allocated.
      """

      def __init__(self, base_layer: RowParallelLinear) -> None:
          super().__init__()
          self.base_layer = base_layer
          # Width of the input shard handled by this rank.
          self.input_size = base_layer.input_size_per_partition
          self.output_size = base_layer.output_size
          self.device = _get_lora_device(base_layer)

      def create_lora_weights(
          self,
          max_loras: int,
          lora_config: LoRAConfig,
          model_config: Optional[PretrainedConfig] = None,
      ) -> None:
          """Allocate zero-filled LoRA buffers for ``max_loras`` adapters.

          Shapes:
              lora_a_stacked: (max_loras, 1, max_lora_rank, input_size)
              lora_b_stacked: (max_loras, 1, out_rows, max_lora_rank), where
                  out_rows is ``output_size``, or
                  ``divide(output_size, tp_size)`` with fully-sharded LoRAs.
          """
          self.lora_config = lora_config
          self.tp_rank = get_tensor_model_parallel_rank()
          max_rank = lora_config.max_lora_rank
          self.lora_a_stacked = torch.zeros(
              (max_loras, 1, max_rank, self.input_size),
              dtype=lora_config.lora_dtype,
              device=self.device,
          )
          tp_size = get_tensor_model_parallel_world_size()
          if lora_config.fully_sharded_loras:
              out_rows = divide(self.output_size, tp_size)
          else:
              out_rows = self.output_size
          self.lora_b_stacked = torch.zeros(
              (max_loras, 1, out_rows, max_rank),
              dtype=lora_config.lora_dtype,
              device=self.device,
          )

      def reset_lora(self, index: int):
          """Zero adapter slot ``index`` in the A buffer, then the B buffer."""
          for buffer in (self.lora_a_stacked, self.lora_b_stacked):
              buffer[index] = 0

      def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
          """Return the rows of ``lora_a`` that belong to this rank's input
          shard (rows ``[rank * input_size, (rank + 1) * input_size)``)."""
          rank = get_tensor_model_parallel_rank()
          begin = rank * self.input_size
          return lora_a[begin:begin + self.input_size, :]

      def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
          """Return ``lora_b`` unchanged; B is not split per rank here."""
          return lora_b

      def set_lora(
          self,
          index: int,
          lora_a: torch.Tensor,
          lora_b: torch.Tensor,
          embeddings_tensor: Optional[torch.Tensor],
      ):
          """Copy one adapter's A and B matrices (transposed) into ``index``.

          The slot is zeroed first. When the base layer has ``tp_size > 1``,
          both matrices go through ``slice_lora_a`` / ``slice_lora_b``
          before copying. ``embeddings_tensor`` is ignored.
          """
          self.reset_lora(index)
          if self.base_layer.tp_size > 1:
              lora_a = self.slice_lora_a(lora_a)
              lora_b = self.slice_lora_b(lora_b)
          for buffer, matrix in ((self.lora_a_stacked, lora_a),
                                 (self.lora_b_stacked, lora_b)):
              n_rows, n_cols = matrix.shape[0], matrix.shape[1]
              buffer[index, 0, :n_cols, :n_rows].copy_(matrix.T,
                                                       non_blocking=True)

      def apply(self, x: torch.Tensor) -> torch.Tensor:
          """Base row-parallel matmul on the local shard plus the LoRA term
          added through ``punica_wrapper.add_lora`` with scale 1.0."""
          base_layer = self.base_layer
          result = base_layer.quant_method.apply(base_layer, x)
          self.punica_wrapper.add_lora(result, x, self.lora_a_stacked,
                                       self.lora_b_stacked, 1.0)
          return result

      def forward(self, input_):
          """Forward of RowParallelLinear

          Args:
              input_: tensor whose last dimension is `input_size`. If
                  `input_is_parallel` is set, then the last dimension
                  is `input_size // tp_size`.

          Returns:
              - output
              - bias
          """
          base = self.base_layer
          # Set up backprop all-reduce.
          if base.input_is_parallel:
              input_parallel = input_
          else:
              # TODO: simplify code below
              tp_rank = get_tensor_model_parallel_rank()
              pieces = split_tensor_along_last_dim(
                  input_, num_partitions=base.tp_size)
              input_parallel = pieces[tp_rank].contiguous()

          # Matrix multiply.
          output_parallel = self.apply(input_parallel)
          if base.reduce_results and base.tp_size > 1:
              reduced = tensor_model_parallel_all_reduce(output_parallel)
          else:
              reduced = output_parallel

          # With skip_bias_add the bias is handed back separately instead of
          # being added here.
          if base.skip_bias_add:
              return reduced, base.bias
          if base.bias is not None:
              return reduced + base.bias, None
          return reduced, None

      @property
      def weight(self):
          """The base layer's ``weight``, or its ``qweight`` when it has no
          ``weight`` attribute."""
          base_layer = self.base_layer
          if hasattr(base_layer, "weight"):
              return base_layer.weight
          return base_layer.qweight

      @classmethod
      @_not_fully_sharded_can_replace
      def can_replace_layer(
          cls,
          source_layer: nn.Module,
          lora_config: LoRAConfig,
          packed_modules_list: List,
          model_config: Optional[PretrainedConfig],
      ) -> bool:
          """Accept only an exact ``RowParallelLinear`` (not subclasses)."""
          return type(source_layer) is RowParallelLinear
  class LogitsProcessorWithLoRA(BaseLayerWithLoRA, TensorPropertiesMixin):
      """
      LoRA wrapper for LogitsProcessor, with extra logic to handle the
      application of the LoRA adapter and added LoRA vocabulary.

      Args:
          base_layer: LogitsProcessor layer
          hidden_size: hidden size of the model
          dtype: data type of the model
          device: device of the model
          sharded_to_full_mapping: index mapping from sharded vocab to full
              vocab received from base_layer.get_sharded_to_full_mapping().
              If None, no reindexing will be done.
      """

      def __init__(self, base_layer: LogitsProcessor, hidden_size: int,
                   dtype: torch.dtype, device: torch.device,
                   sharded_to_full_mapping: Optional[List[int]]) -> None:
          super().__init__()
          self.base_layer = base_layer
          self.hidden_size = hidden_size
          self.dtype = dtype
          self.device = device
          self.tp_size = get_tensor_model_parallel_world_size()
          self.tp_rank = get_tensor_model_parallel_rank()
          self.sharded_to_full_mapping = sharded_to_full_mapping

      # Read-only passthroughs to the wrapped LogitsProcessor's settings;
      # ``forward`` runs the base class's forward with this wrapper as
      # ``self``, so these presumably feed that code path.

      @property
      def logits_as_input(self):
          return self.base_layer.logits_as_input

      @property
      def vocab_size(self):
          return self.base_layer.vocab_size

      @property
      def scale(self):
          return self.base_layer.scale

      @property
      def soft_cap(self):
          return self.base_layer.soft_cap

      @property
      def use_gather(self):
          return self.base_layer.use_gather

      @property
      def org_vocab_size(self):
          return self.base_layer.org_vocab_size

      @property
      def include_gpu_probs_tensor(self):
          return self.base_layer.include_gpu_probs_tensor

      @property
      def should_modify_greedy_probs_inplace(self):
          return self.base_layer.should_modify_greedy_probs_inplace

      def create_lora_weights(
          self,
          max_loras: int,
          lora_config: LoRAConfig,
          model_config: Optional[PretrainedConfig] = None,
      ) -> None:
          """Allocate the per-adapter LoRA buffers.

          Creates:
              lora_a_stacked: zeros, (max_loras, 1, max_lora_rank,
                  hidden_size).
              lora_b_stacked: zeros, (max_loras, 1, padded_vocab,
                  max_lora_rank), where padded_vocab is ``vocab_size``
                  rounded up to a multiple of ``lora_vocab_padding_size``.
              embeddings_tensors: (max_loras, lora_extra_vocab_size,
                  hidden_size) filled with -inf, in the model dtype.
              sharded_to_full_mapping_gpu: ``sharded_to_full_mapping`` as a
                  long tensor on ``self.device``, or None.

          Raises:
              ValueError: if the base layer's vocab size exceeds 257024.
          """
          vocab_size = self.base_layer.vocab_size
          # TODO: Verify if this condition can be further relaxed
          # Only an upper bound is enforced; smaller vocabularies are accepted.
          if vocab_size > 257024:
              raise ValueError("When using LoRA, vocab size must be "
                               "<= 257024, "
                               f"but got {vocab_size}.")
          self.lora_a_stacked = torch.zeros(
              (
                  max_loras,
                  1,
                  lora_config.max_lora_rank,
                  self.hidden_size,
              ),
              dtype=lora_config.lora_dtype,
              device=self.device,
          )
          self.lora_b_stacked = torch.zeros(
              (
                  max_loras,
                  1,
                  # Pad for kernel compatibility
                  math.ceil(vocab_size / lora_config.lora_vocab_padding_size)
                  * lora_config.lora_vocab_padding_size,
                  lora_config.max_lora_rank,
              ),
              dtype=lora_config.lora_dtype,
              device=self.device,
          )
          # -inf marks extra-vocab rows that no adapter has filled in.
          self.embeddings_tensors = torch.full(
              (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
              fill_value=float("-inf"),
              dtype=self.dtype,
              device=self.device,
          )
          if self.sharded_to_full_mapping is not None:
              self.sharded_to_full_mapping_gpu = torch.tensor(
                  self.sharded_to_full_mapping,
                  device=self.device,
                  dtype=torch.long)
          else:
              self.sharded_to_full_mapping_gpu = None

      def reset_lora(self, index: int):
          """Clear adapter slot ``index``: zero A/B and set embeddings to -inf."""
          self.lora_a_stacked[index] = 0
          self.lora_b_stacked[index] = 0
          self.embeddings_tensors[index] = float("-inf")

      def set_lora(
          self,
          index: int,
          lora_a: torch.Tensor,
          lora_b: torch.Tensor,
          embeddings_tensor: Optional[torch.Tensor],
      ):
          """Load one adapter into slot ``index``.

          A and B are copied transposed into the leading corner of their
          slots. ``embeddings_tensor`` (extra-vocab rows), when given, is
          written into the leading rows/columns of the slot's embeddings.
          """
          self.reset_lora(index)
          self.lora_a_stacked[index,
                              0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                  lora_a.T, non_blocking=True)
          self.lora_b_stacked[index,
                              0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                  lora_b.T, non_blocking=True)
          if embeddings_tensor is not None:
              self.embeddings_tensors[
                  index, :embeddings_tensor.shape[0],
                  :embeddings_tensor.shape[1]] = embeddings_tensor

      def _get_logits(
          self,
          hidden_states: torch.Tensor,
          lm_head: VocabParallelEmbedding,
          embedding_bias: Optional[torch.Tensor] = None,
      ) -> Optional[torch.Tensor]:
          """Compute logits including LoRA extra-vocab and LoRA A/B terms.

          Returns None when ``tensor_model_parallel_gather`` returns None.
          """
          # Get the logits for the next tokens.
          logits = lm_head.linear_method.apply(lm_head, hidden_states)
          if embedding_bias is not None:
              logits += embedding_bias
          logits = tensor_model_parallel_gather(logits)
          if logits is None:
              return None

          if self.sharded_to_full_mapping_gpu is not None:
              # Reindex full logits tensor to ensure 1:1 mapping between
              # index and token_id
              # Example for:
              #   org_vocab_size = 4
              #   added_vocab_size = 2
              #   pad_to_size = 8
              #   tp_size = 2
              # indices:  [0, 1, 2,  3, 4, 5, 6,  7]
              # token_id: [0, 1, 4, -1, 2, 3, 5, -1]
              # Therefore, the mapping is expected to be:
              # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex,
              # we get:
              # indices:  [0, 1, 2, 3, 4, 5,  6,  7]
              # token_id: [0, 1, 2, 3, 4, 5, -1, -1]
              logits = logits[:, self.sharded_to_full_mapping_gpu]

          # (max_loras + 1, lora_extra_vocab_size, num_tokens); the extra
          # last slice is all -inf.
          lora_logits = torch.empty(
              self.embeddings_tensors.shape[0] + 1,
              self.embeddings_tensors.shape[1],
              hidden_states.shape[0],
              dtype=self.embeddings_tensors.dtype,
              device=self.embeddings_tensors.device,
          )
          torch.matmul(self.embeddings_tensors,
                       hidden_states.T,
                       out=lora_logits[:-1])
          lora_logits[-1] = float("-inf")
          lora_logits = lora_logits.mT
          indices_padded = self.punica_wrapper.sampler_indices_padded
          # Flatten the (slot, token) axes and pick one row per entry of
          # sampler_indices_padded (presumably each token's adapter row).
          lora_logits = (lora_logits.reshape(
              lora_logits.shape[0] * lora_logits.shape[1],
              lora_logits.shape[2],
          ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"),
                                                        posinf=float("inf"),
                                                        neginf=float("-inf")))
          # Extra-vocab logits occupy the columns right after the original
          # vocabulary.
          logits[:,
                 self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
                 lora_logits.shape[1]] = lora_logits

          # LogitsProcessorWithLoRA always using bgmv
          self.punica_wrapper.add_lora_logits(logits, hidden_states,
                                              self.lora_a_stacked,
                                              self.lora_b_stacked, 1.0)

          # Remove paddings in vocab (if any).
          logits = logits[:, :self.base_layer.vocab_size]
          return logits

      def forward(self, *args, **kwargs):
          # Run the base class's forward with this wrapper as ``self`` so
          # overridden methods (e.g. ``_get_logits``) are dispatched here.
          return type(self.base_layer).forward(self, *args, **kwargs)

      @classmethod
      def can_replace_layer(
          cls,
          source_layer: nn.Module,
          lora_config: LoRAConfig,
          packed_modules_list: List,
          model_config: Optional[PretrainedConfig],
      ) -> bool:
          # Special handling for the LogitsProcessor.
          return False
  class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
      """Rotary embedding wrapper that serves several linear RoPE scaling
      factors at once for multiple LoRA adapters.

      ``create_lora_weights`` rebuilds the wrapped layer as a
      ``LinearScalingRotaryEmbedding`` that holds every scaling factor: the
      base layer's own factor plus ``lora_config.long_lora_scaling_factors``.
      ``forward`` passes ``punica_wrapper.long_lora_indices`` as ``offsets``,
      which presumably select each request's scaling factor.
      """

      def __init__(self, base_layer: RotaryEmbedding) -> None:
          super().__init__()
          self.base_layer = base_layer

      @property
      def scaling_factors(self):
          return self.base_layer.scaling_factors

      @property
      def rotary_dim(self):
          return self.base_layer.rotary_dim

      def create_lora_weights(
          self,
          max_loras: int,
          lora_config: LoRAConfig,
          model_config: Optional[PretrainedConfig] = None,
      ) -> None:
          """Replace the base layer with a multi-factor linear-scaling RoPE.

          The factor set is the base layer's ``scaling_factor`` (1.0 for a
          plain ``RotaryEmbedding``) plus the configured long-LoRA factors,
          de-duplicated and sorted ascending.
          """
          extra_factors = lora_config.long_lora_scaling_factors or ()
          base_scaling_factor = (self.base_layer.scaling_factor if isinstance(
              self.base_layer, LinearScalingRotaryEmbedding) else 1.0)
          scaling_factors = sorted({base_scaling_factor, *extra_factors})
          self.base_layer = LinearScalingRotaryEmbedding(
              self.base_layer.head_size,
              self.base_layer.rotary_dim,
              self.base_layer.max_position_embeddings,
              self.base_layer.base,
              self.base_layer.is_neox_style,
              scaling_factors,
              self.base_layer.dtype,
          )

      def reset_lora(self, index: int):
          # No per-adapter weights to clear.
          ...

      def set_lora(
          self,
          index: int,
          lora_a: torch.Tensor,
          lora_b: torch.Tensor,
          embeddings_tensor: Optional[torch.Tensor],
      ):
          # No per-adapter weights to load.
          ...

      def forward(
          self,
          positions: torch.Tensor,
          query: torch.Tensor,
          key: torch.Tensor,
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          return self.base_layer(
              positions,
              query,
              key,
              offsets=self.punica_wrapper.long_lora_indices,
          )

      @property
      def scaling_factor_to_offset(self) -> Dict[float, int]:
          return self.base_layer.scaling_factor_to_offset

      @classmethod
      def can_replace_layer(
          cls,
          source_layer: nn.Module,
          lora_config: LoRAConfig,
          packed_modules_list: List,
          model_config: Optional[PretrainedConfig],
      ) -> bool:
          """Returns True if the layer can be replaced by this LoRA layer."""
          return (type(source_layer) is LinearScalingRotaryEmbedding
                  or type(source_layer) is RotaryEmbedding)

      def extra_repr(self) -> str:
          return self.base_layer.extra_repr()
  class ModulesToSaveWrapper(BaseLayerWithLoRA, TensorPropertiesMixin):
      """
      LoRA wrapper for lm_head / embed_tokens, inspired by
      ModulesToSaveWrapper from peft.

      Instead of low-rank A/B factors, every adapter slot stores a full
      replacement weight matrix in ``_lora_tensors``. Slots without adapter
      weights hold a copy of the base layer's weight. Properties forward
      ``weight``, ``bias``, ``org_vocab_size`` and ``embedding_dim`` to the
      base layer, so clients can use the wrapper like the base module.

      Args:
          base_layer: layer to replace by the wrapper:
              VocabParallelEmbedding (for embed_tokens)
              or ParallelLMHead (for lm_head)
      """

      implemented_layers = ['lm_head', 'embed_tokens']

      def __init__(
              self, base_layer: Union[VocabParallelEmbedding,
                                      ParallelLMHead]) -> None:
          super().__init__()
          self.base_layer = base_layer
          self.device = _get_lora_device(self.base_layer)
          self.tp_size = get_tensor_model_parallel_world_size()
          self.tp_rank = get_tensor_model_parallel_rank()

      @property
      def padded_vocab_size(self):
          # number of embeddings with paddings and with max_lora_extra_vocab_size
          return self.base_layer.num_embeddings_padded

      @property
      def org_vocab_size(self):
          return self.base_layer.org_vocab_size

      @property
      def embedding_dim(self):
          return self.base_layer.embedding_dim

      @property
      def bias(self):
          return self.base_layer.bias

      @property
      def linear_method(self):
          """Object whose ``apply`` computes logits: the base layer's method
          when ``punica_wrapper.no_lora`` is set, otherwise this wrapper."""
          if self.punica_wrapper.no_lora:
              return self.base_layer.linear_method
          return self

      @property
      def weight(self):
          return self.base_layer.weight

      def apply(self, lm_head: 'ModulesToSaveWrapper',
                hidden_states: torch.Tensor,
                bias: Optional[torch.Tensor] = None) -> torch.Tensor:
          """Compute lm_head logits from the per-adapter weight copies.

          This wrapper stands in for the base quant method via
          ``linear_method``, so ``bias`` defaults to None. That allows the
          two-argument form ``linear_method.apply(layer, hidden_states)``,
          which ``LogitsProcessorWithLoRA._get_logits`` uses. ``lm_head`` is
          accepted for interface compatibility and not read.
          """
          assert isinstance(self.base_layer, ParallelLMHead)
          logits = self.punica_wrapper.bgmv_sample(hidden_states,
                                                   self._lora_tensors,
                                                   self.base_layer.weight)
          if bias is not None:
              logits += bias
          return logits

      def embedding(self, embed_tokens: 'ModulesToSaveWrapper',
                    masked_input: torch.LongTensor):
          """Look up ``masked_input`` in the per-adapter embedding copies."""
          assert isinstance(self.base_layer, VocabParallelEmbedding)
          embeddings = self.punica_wrapper.bgmv_embedding(
              masked_input, self._lora_tensors, self.base_layer.weight)
          return embeddings

      def create_lora_weights(
          self,
          max_loras: int,
          lora_config: LoRAConfig,
          model_config: Optional[PretrainedConfig] = None,
      ) -> None:
          """Allocate ``max_loras`` full-size weight slots, each initialised
          to a copy of the base layer's weight.

          NOTE(review): ``self.dtype`` records ``lora_config.lora_dtype``, but
          the buffer itself uses the base weight's dtype.
          """
          self.dtype = lora_config.lora_dtype
          # lora_tensors - lm_head tensors in case of ParallelLMHead base
          # or embed_tokens tensors in case of VocabParallelEmbedding
          self._lora_tensors = torch.zeros(
              (max_loras, self.padded_vocab_size, self.base_layer.embedding_dim),
              dtype=self.base_layer.weight.dtype,
              device=self.device,
          )
          for index in range(max_loras):
              self.reset_lora(index)

      def reset_lora(self, index: int):
          """Restore slot ``index`` to a copy of the base layer's weight."""
          weights = self.base_layer.weight
          self._lora_tensors[index, :weights.shape[0], :weights.shape[1]].copy_(
              weights, non_blocking=True)

      def set_lora(
          self,
          index: int,
          lora_a: Optional[torch.Tensor],
          lora_b: torch.Tensor,
          embeddings_tensor: Optional[torch.Tensor],
      ):
          """Load a full replacement weight (passed as ``lora_b``) into slot
          ``index``; rows/columns beyond its shape keep the base weight.

          ``lora_a`` and ``embeddings_tensor`` must be None for this wrapper.
          """
          assert lora_a is None
          assert embeddings_tensor is None
          self.reset_lora(index)
          self._lora_tensors[index, :lora_b.shape[0], :lora_b.shape[1]].copy_(
              lora_b, non_blocking=True)

      def forward(self, *args, **kwargs):
          # Run the base class's forward with this wrapper as ``self`` so
          # the overridden properties/methods above are used.
          return type(self.base_layer).forward(self, *args, **kwargs)

      @classmethod
      def can_replace_layer(
          cls,
          source_layer: nn.Module,
          lora_config: LoRAConfig,
          packed_modules_list: List,
          model_config: Optional[PretrainedConfig],
      ) -> bool:
          """Replace exact ParallelLMHead / VocabParallelEmbedding layers, and
          only when ``enable_lora_modules_to_save`` is set."""
          if not lora_config.enable_lora_modules_to_save:
              return False
          return type(source_layer) in (ParallelLMHead, VocabParallelEmbedding)