# pylint: disable=unused-argument
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

from aphrodite.adapter_commons.layers import AdapterMapping
from aphrodite.common.config import LoRAConfig
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   split_tensor_along_last_dim,
                                   tensor_model_parallel_all_gather,
                                   tensor_model_parallel_all_reduce,
                                   tensor_model_parallel_gather)
from aphrodite.distributed.utils import divide
from aphrodite.lora.punica import PunicaWrapper
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              ReplicatedLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import (
    LinearScalingRotaryEmbedding, RotaryEmbedding)
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)

if TYPE_CHECKING:
    pass

def _get_lora_device(base_layer: nn.Module) -> torch.device:
    # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
    """Returns the device for where to place the LoRA tensors."""
    # unquantized Linear
    if hasattr(base_layer, "weight"):
        return base_layer.weight.device
    # GPTQ/AWQ/SqueezeLLM
    elif hasattr(base_layer, "qweight"):
        return base_layer.qweight.device
    # marlin
    elif hasattr(base_layer, "B"):
        return base_layer.B.device
    else:
        raise ValueError(f"Unsupported base layer: {base_layer}")

def _not_fully_sharded_can_replace(can_replace):
    """
    Decorator that adds the condition of not using fully sharded LoRAs.
    Intended to wrap can_replace_layer().
    """

    def dec(*args, **kwargs):
        decorate = kwargs.pop("decorate") if "decorate" in kwargs else True
        condition = (not kwargs["lora_config"].fully_sharded_loras
                     if decorate else True)
        return can_replace(*args, **kwargs) and condition

    return dec
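
# NOTE (illustrative, assumptions of this comment only): the decorator above
# is stacked under @classmethod on the can_replace_layer() classmethods below,
# e.g.:
#
#     @classmethod
#     @_not_fully_sharded_can_replace
#     def can_replace_layer(cls, source_layer, lora_config,
#                           packed_modules_list, model_config) -> bool:
#         return type(source_layer) is ColumnParallelLinear
#
# With the default decorate=True, the wrapped check additionally requires
# lora_config.fully_sharded_loras to be False; callers that implement the
# fully sharded variants can pass decorate=False to skip that extra condition.
# Because the wrapper reads kwargs["lora_config"], the check must be invoked
# with keyword arguments.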

@dataclass
class LoRAMapping(AdapterMapping):
    is_prefill: bool = False


class BaseLayerWithLoRA(nn.Module):

    def slice_lora_a(
        self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
        """Slice lora a if splitting for tensor parallelism."""
        ...

    def slice_lora_b(
        self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
        """Slice lora b if splitting with tensor parallelism."""
        ...

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Initializes lora matrices."""
        ...

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Overwrites lora tensors at index."""
        ...

    def set_mapping(
        self,
        punica_wrapper: PunicaWrapper,
    ):
        self.punica_wrapper: PunicaWrapper = punica_wrapper

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        raise NotImplementedError
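
# NOTE (illustrative, assumptions of this comment only): the concrete
# subclasses below all follow the same storage convention for a LoRA pair
# (A, B) of rank r applied to a base layer with input_size/output_size:
#   lora_a_stacked: (max_loras, 1, r, input_size)   -- A stored transposed
#   lora_b_stacked: (max_loras, 1, output_size, r)  -- B stored transposed
# set_lora() receives A as (input_size, r) and B as (r, output_size), so the
# adapted output is roughly base(x) + (x @ A) @ B * scale, computed by the
# PunicaWrapper batched kernels (add_lora / add_expand / add_lora_packed_nslice
# / add_lora_logits) using the per-token adapter indices registered through
# set_mapping().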

class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.embeddings_slice: Optional[Tuple[int, int]]
        self.embeddings_weights: Optional[torch.Tensor]

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        if self.base_layer.num_added_embeddings_per_partition > 0:
            # We can start adding lora weights
            self.embeddings_weights = self.base_layer.weight.data[
                self.base_layer.num_org_embeddings_per_partition:self.
                base_layer.num_org_embeddings_per_partition +
                self.base_layer.num_added_embeddings_per_partition]
            self.embeddings_slice = (
                self.base_layer.shard_indices.added_vocab_start_index -
                self.base_layer.org_vocab_size,
                self.base_layer.shard_indices.added_vocab_end_index -
                self.base_layer.org_vocab_size)
            self.base_layer.weight.data[
                self.base_layer.num_org_embeddings_per_partition:].fill_(0)
        else:
            self.embeddings_slice = None
            self.embeddings_weights = None
        self.embeddings_tensors = torch.zeros(
            (
                max_loras,
                lora_config.lora_extra_vocab_size,
                self.base_layer.embedding_dim,
            ),
            dtype=self.base_layer.weight.dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                self.base_layer.org_vocab_size +
                lora_config.lora_extra_vocab_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.embedding_dim,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked_2d = self.lora_a_stacked.view(
            self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
            self.lora_a_stacked.shape[2],
        )

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
            lora_a, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1]].copy_(embeddings_tensor, non_blocking=True)
            if self.embeddings_slice is not None:
                # TODO(yard1): Optimize this copy, we don't need to copy
                # everything, just the modified part
                embeddings = self.embeddings_tensors.view(
                    self.embeddings_tensors.shape[0] *
                    self.embeddings_tensors.shape[1],
                    self.embeddings_tensors.shape[2],
                )[self.embeddings_slice[0]:self.embeddings_slice[1]]
                assert self.embeddings_weights is not None
                self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        added_tokens_mask = x > self.base_layer.org_vocab_size - 1
        embeddings_indices = self.punica_wrapper.embeddings_indices
        indices = embeddings_indices[1].view_as(x)
        full_lora_a_embeddings = F.embedding(
            x + indices,
            self.lora_a_stacked_2d,
        )
        indices = embeddings_indices[0].view_as(x)
        full_output = self.base_layer.forward(
            x.add_(indices * added_tokens_mask))
        full_output_org = full_output
        if full_output.ndim == 3:
            full_output = full_output.view(
                full_output.shape[0] * full_output.shape[1], -1)
        if full_lora_a_embeddings.ndim == 3:
            full_lora_a_embeddings = full_lora_a_embeddings.view(
                full_lora_a_embeddings.shape[0] *
                full_lora_a_embeddings.shape[1],
                -1,
            )
        # The embedding layer only needs the expand op.
        self.punica_wrapper.add_expand(full_output,
                                       full_lora_a_embeddings,
                                       self.lora_b_stacked,
                                       add_input=True)
        return full_output.view_as(full_output_org)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return type(source_layer) is VocabParallelEmbedding
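
# NOTE (illustrative example, numbers are assumptions of this comment only):
# for org_vocab_size=32000, lora_extra_vocab_size=256, embedding_dim=4096,
# max_lora_rank=16 and max_loras=4, the buffers above have shapes
#   embeddings_tensors: (4, 256, 4096)    extra-vocab rows per adapter
#   lora_a_stacked:     (4, 32256, 16)    one A per adapter, flattened to
#                                         lora_a_stacked_2d of (129024, 16)
#   lora_b_stacked:     (4, 1, 4096, 16)
# In forward(), tokens with id >= org_vocab_size are shifted via the punica
# embeddings_indices so they index into the owning adapter's rows of
# lora_a_stacked_2d, and the LoRA contribution is then added to the base
# embedding output with a single expand op over lora_b_stacked.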

class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: ReplicatedLinear) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.input_size = self.base_layer.input_size
        self.output_size = self.base_layer.output_size
        self.device = _get_lora_device(self.base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        self.lora_config = lora_config
        lora_a_output_size = lora_config.max_lora_rank
        self.lora_a_stacked = torch.zeros(
            max_loras,
            1,
            lora_a_output_size,
            self.input_size,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            max_loras,
            1,
            self.output_size,
            lora_config.max_lora_rank,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
                                     self.lora_b_stacked, 1.0)
        return output

    def forward(self, input_):
        """Forward of ReplicatedLinearWithLoRA

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)
        # Matrix multiply.
        output = self.apply(input_, bias)
        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return type(source_layer) is ReplicatedLinear
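
# NOTE (illustrative, assumptions of this comment only): for a single adapter
# and an unquantized base layer, apply() above is numerically equivalent to
# the plain PyTorch
#
#     base = F.linear(x, base_layer.weight, bias)
#     out = base + x @ lora_a_stacked[i, 0].T @ lora_b_stacked[i, 0].T
#
# where lora_a_stacked[i, 0] is (rank, input_size) and lora_b_stacked[i, 0]
# is (output_size, rank). The Punica kernel performs this shrink/expand for
# many adapters and many tokens in one batched call.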

class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
    """
    LoRA on top of ColumnParallelLinear layer.
    LoRA B is sliced for tensor parallelism.
    """

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size = self.base_layer.input_size
        self.output_size = self.base_layer.output_size_per_partition
        self.device = _get_lora_device(self.base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        self.lora_config = lora_config
        self.tp_size = get_tensor_model_parallel_world_size()
        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))
        self.lora_a_stacked = torch.zeros(
            max_loras,
            1,
            lora_a_output_size_per_partition,
            self.input_size,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            max_loras,
            1,
            self.output_size,
            lora_config.max_lora_rank,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.output_dim = self.lora_b_stacked.shape[2]

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        shard_size = self.output_dim
        start_idx = tensor_model_parallel_rank * shard_size
        end_idx = (tensor_model_parallel_rank + 1) * shard_size
        lora_b = lora_b[:, start_idx:end_idx]
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
                                     self.lora_b_stacked, 1.0)
        return output

    def forward(self, input_):
        """Forward of ColumnParallelLinear

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)
        # Matrix multiply.
        output_parallel = self.apply(input_, bias)
        if self.base_layer.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return type(source_layer) is ColumnParallelLinear or (
            type(source_layer) is MergedColumnParallelLinear
            and len(packed_modules_list) == 1)
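
# NOTE (illustrative example, numbers are assumptions of this comment only):
# with output_size_per_partition=2048 and tp_size=4, slice_lora_b() above has
# each rank keep columns [tp_rank * 2048, (tp_rank + 1) * 2048) of the full
# (r, 8192) lora_b (r = LoRA rank), matching the column shard of the base
# weight that rank owns; lora_a is kept whole because the input activations
# are replicated across ranks in the column-parallel case.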

class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
    packed together (e.g. gate_proj + up_proj -> gate_up_proj).
    This means we have 2 LoRAs, each applied to one half of the layer.
    Both slices must have the same size.
    """

    def __init__(self, base_layer: MergedColumnParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        self.lora_config = lora_config
        n_slices = 2
        if not (len(self.base_layer.output_sizes) == n_slices
                and self.base_layer.output_sizes[0]
                == self.base_layer.output_sizes[1]):
            raise ValueError(
                "LoRAColumnParallelLinear2Slice requires 2 slices with "
                "the same size.")
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))
        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                self.output_size // 2,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(n_slices))
        self.output_dim = self.lora_b_stacked[0].shape[2]

    def reset_lora(self, index: int):
        self.lora_a_stacked[0][index] = 0
        self.lora_a_stacked[1][index] = 0
        self.lora_b_stacked[0][index] = 0
        self.lora_b_stacked[1][index] = 0

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        return lora_a

    def slice_lora_b(
        self, lora_b: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        if lora_b[0] is None or lora_b[1] is None:
            return lora_b
        shard_size = self.output_dim
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        lora_b = [
            lora_b[0][:, start_idx:end_idx],
            lora_b[1][:, start_idx:end_idx],
        ]
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
        if lora_a[0] is not None:
            self.lora_a_stacked[0][
                index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
                    lora_a[0].T, non_blocking=True)
            self.lora_b_stacked[0][
                index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
                    lora_b[0].T, non_blocking=True)
        if lora_a[1] is not None:
            self.lora_a_stacked[1][
                index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
                    lora_a[1].T, non_blocking=True)
            self.lora_b_stacked[1][
                index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
                    lora_b[1].T, non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        self.punica_wrapper.add_lora_packed_nslice(
            output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0,
            (self.output_dim, self.output_dim))
        return output

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return (type(source_layer) is MergedColumnParallelLinear
                and len(packed_modules_list) == 2)
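
# NOTE (illustrative, assumptions of this comment only): if the per-partition
# output_size of the merged layer is 2 * N (so N = self.output_dim), the
# add_lora_packed_nslice() call above adds the first slice's LoRA result to
# output columns [0, N) and the second slice's to [N, 2 * N); the two
# independent (A, B) pairs never mix even though the base GEMM is fused into
# a single gate_up_proj-style matmul.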

class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """
    ColumnParallelLinear layer that is specifically designed for
    qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
    only contain a single LoRA within their qkv_proj layer.
    During inference with Tensor Parallel, the weights of lora_b
    must be accurately partitioned according to the respective ranks.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        self.tp_size = get_tensor_model_parallel_world_size()
        self.q_proj_total_size = (self.base_layer.total_num_heads *
                                  self.base_layer.head_size)
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
                                   self.base_layer.head_size)

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        tp_rank = get_tensor_model_parallel_rank()
        self.q_shard_id = tp_rank
        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
        lora_b_q = lora_b[:, self.q_proj_shard_size *
                          self.q_shard_id:self.q_proj_shard_size *
                          (self.q_shard_id + 1)]
        k_offset = self.q_proj_total_size
        lora_b_k = lora_b[:, k_offset +
                          self.kv_proj_shard_size * self.kv_shard_id:k_offset +
                          self.kv_proj_shard_size * (self.kv_shard_id + 1)]
        v_offset = k_offset + self.kv_proj_total_size
        lora_b_v = lora_b[:, v_offset +
                          self.kv_proj_shard_size * self.kv_shard_id:v_offset +
                          self.kv_proj_shard_size * (self.kv_shard_id + 1)]
        lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is QKVParallelLinear and len(
            packed_modules_list) == 1
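
# NOTE (illustrative example, numbers are assumptions of this comment only):
# for total_num_heads=32, total_num_kv_heads=8, head_size=128 and tp_size=4
# (so num_heads=8 and num_kv_heads=2 per rank, num_kv_head_replicas=1):
#   q_proj_total_size  = 32 * 128 = 4096   q_proj_shard_size  = 8 * 128 = 1024
#   kv_proj_total_size =  8 * 128 = 1024   kv_proj_shard_size = 2 * 128 = 256
# The packed lora_b lays its columns out as [Q | K | V] over the *total*
# sizes, so slice_lora_b() on rank 1 takes Q columns [1024, 2048), K columns
# [4096 + 256, 4096 + 512) and V columns [5120 + 256, 5120 + 512), then
# concatenates them to match this rank's shard of the fused qkv_proj weight.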

class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 3 sublayers (slices)
    packed together in qkv proj fashion
    (q_proj + k_proj + v_proj -> qkv_proj).
    This means we have 3 LoRAs, each applied to one slice of the layer.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        self.lora_config = lora_config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.q_shard_id = self.tp_rank
        self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas

        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))
        # q, k, v
        self.lora_a_stacked = (
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )
        self.lora_b_stacked = (
            torch.zeros(
                max_loras,
                1,
                self.q_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
            torch.zeros(
                max_loras,
                1,
                self.kv_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
            torch.zeros(
                max_loras,
                1,
                self.kv_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )
        self.output_slices = (
            self.q_proj_shard_size,
            self.kv_proj_shard_size,
            self.kv_proj_shard_size,
        )
        self.packed_indices: Optional[torch.Tensor] = None
        self.standard_indices: Optional[torch.Tensor] = None
        # lazily initialized.
        self.indices: torch.Tensor
        self.indices_len: List[int]

    def reset_lora(self, index: int):
        self.lora_a_stacked[0][index] = 0
        self.lora_b_stacked[0][index] = 0
        self.lora_a_stacked[1][index] = 0
        self.lora_b_stacked[1][index] = 0
        self.lora_a_stacked[2][index] = 0
        self.lora_b_stacked[2][index] = 0

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        return lora_a

    def slice_lora_b(
        self, lora_b: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        lora_b_q, lora_b_k, lora_b_v = None, None, None
        if lora_b[0] is not None:
            lora_b_q = lora_b[0][:, self.q_proj_shard_size *
                                 self.q_shard_id:self.q_proj_shard_size *
                                 (self.q_shard_id + 1)]
        if lora_b[1] is not None:
            lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
                                 self.kv_shard_id:self.kv_proj_shard_size *
                                 (self.kv_shard_id + 1)]
        if lora_b[2] is not None:
            lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
                                 self.kv_shard_id:self.kv_proj_shard_size *
                                 (self.kv_shard_id + 1)]
        lora_b = [lora_b_q, lora_b_k, lora_b_v]
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
        if lora_b[0] is not None:
            lora_b_q = lora_b[0]
            self.lora_b_stacked[0][
                index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
                    lora_b_q.T, non_blocking=True)
        if lora_b[1] is not None:
            lora_b_k = lora_b[1]
            self.lora_b_stacked[1][
                index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
                    lora_b_k.T, non_blocking=True)
        if lora_b[2] is not None:
            lora_b_v = lora_b[2]
            self.lora_b_stacked[2][
                index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
                    lora_b_v.T, non_blocking=True)
        if lora_a[0] is not None:
            self.lora_a_stacked[0][
                index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
                    lora_a[0].T, non_blocking=True)
        if lora_a[1] is not None:
            self.lora_a_stacked[1][
                index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
                    lora_a[1].T, non_blocking=True)
        if lora_a[2] is not None:
            self.lora_a_stacked[2][
                index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
                    lora_a[2].T, non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        self.punica_wrapper.add_lora_packed_nslice(output, x,
                                                   self.lora_a_stacked,
                                                   self.lora_b_stacked, 1.0,
                                                   self.output_slices)
        return output

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return (type(source_layer) is QKVParallelLinear
                and len(packed_modules_list) == 3)

class RowParallelLinearWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: RowParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.input_size = self.base_layer.input_size_per_partition
        self.output_size = self.base_layer.output_size
        self.device = _get_lora_device(self.base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        self.lora_config = lora_config
        self.tp_rank = get_tensor_model_parallel_rank()
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.input_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        tp_size = get_tensor_model_parallel_world_size()
        lora_b_output_size_per_partition = (
            self.output_size if not lora_config.fully_sharded_loras else
            divide(self.output_size, tp_size))
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_b_output_size_per_partition,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        shard_size = self.input_size
        start_idx = tensor_model_parallel_rank * shard_size
        end_idx = (tensor_model_parallel_rank + 1) * shard_size
        lora_a = lora_a[start_idx:end_idx, :]
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.base_layer.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def apply(self, x: torch.Tensor) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x)
        self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
                                     self.lora_b_stacked, 1.0)
        return output

    def forward(self, input_):
        """Forward of RowParallelLinear

        Args:
            input_: tensor whose last dimension is `input_size`. If
                `input_is_parallel` is set, then the last dimension
                is `input_size // tp_size`.

        Returns:
            - output
            - bias
        """
        # Set up backprop all-reduce.
        if self.base_layer.input_is_parallel:
            input_parallel = input_
        else:
            # TODO: simplify code below
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.base_layer.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()
        # Matrix multiply.
        output_parallel = self.apply(input_parallel)
        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.base_layer.skip_bias_add:
            output = (output_ + self.base_layer.bias
                      if self.base_layer.bias is not None else output_)
            output_bias = None
        else:
            output = output_
            output_bias = self.base_layer.bias
        return output, output_bias

    @property
    def weight(self):
        return (self.base_layer.weight if hasattr(self.base_layer, "weight")
                else self.base_layer.qweight)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return type(source_layer) is RowParallelLinear
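
# NOTE (illustrative, assumptions of this comment only): in contrast to the
# column-parallel case, RowParallelLinear shards the *input* dimension, so
# slice_lora_a() above keeps rows [tp_rank * input_size_per_partition,
# (tp_rank + 1) * input_size_per_partition) of lora_a while lora_b stays
# whole (unless fully sharded LoRA is enabled). Each rank adds its partial
# (x_shard @ A_shard) @ B to its partial base output inside apply(), and the
# same all-reduce that merges the base layer's partial results also sums the
# LoRA contributions.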

class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
    """
    LoRA wrapper for LogitsProcessor, with extra logic to handle the
    application of the LoRA adapter and added LoRA vocabulary.

    Args:
        base_layer: LogitsProcessor layer
        hidden_size: hidden size of the model
        dtype: data type of the model
        device: device of the model
        sharded_to_full_mapping: index mapping from sharded vocab to full vocab
            received from base_layer.get_sharded_to_full_mapping(). If None,
            no reindexing will be done.
    """

    def __init__(self, base_layer: LogitsProcessor, hidden_size: int,
                 dtype: torch.dtype, device: torch.device,
                 sharded_to_full_mapping: Optional[List[int]]) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.sharded_to_full_mapping = sharded_to_full_mapping

    @property
    def logits_as_input(self):
        return self.base_layer.logits_as_input

    @property
    def vocab_size(self):
        return self.base_layer.vocab_size

    @property
    def scale(self):
        return self.base_layer.scale

    @property
    def soft_cap(self):
        return self.base_layer.soft_cap

    @property
    def use_gather(self):
        return self.base_layer.use_gather

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        return self.base_layer.include_gpu_probs_tensor

    @property
    def should_modify_greedy_probs_inplace(self):
        return self.base_layer.should_modify_greedy_probs_inplace

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        # TODO: Verify if this condition can be further relaxed
        if 32000 < self.base_layer.vocab_size > 257024:
            raise ValueError("When using LoRA, vocab size must be "
                             "32000 <= vocab_size <= 257024, "
                             f"but got {self.base_layer.vocab_size}.")
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                # Pad for kernel compatibility
                math.ceil(self.base_layer.vocab_size /
                          lora_config.lora_vocab_padding_size) *
                lora_config.lora_vocab_padding_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.embeddings_tensors = torch.full(
            (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
            fill_value=float("-inf"),
            dtype=self.dtype,
            device=self.device,
        )
        if self.sharded_to_full_mapping is not None:
            self.sharded_to_full_mapping_gpu = torch.tensor(
                self.sharded_to_full_mapping,
                device=self.device,
                dtype=torch.long)
        else:
            self.sharded_to_full_mapping_gpu = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = float("-inf")

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1]] = embeddings_tensor

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: VocabParallelEmbedding,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        # Get the logits for the next tokens.
        logits = lm_head.linear_method.apply(lm_head, hidden_states)
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        if logits is None:
            return None

        if self.sharded_to_full_mapping_gpu is not None:
            # Reindex full logits tensor to ensure 1:1 mapping between
            # index and token_id.
            # Example for:
            #   org_vocab_size = 4
            #   added_vocab_size = 2
            #   pad_to_size = 8
            #   tp_size = 2
            # indices:  [0, 1, 2,  3, 4, 5, 6,  7]
            # token_id: [0, 1, 4, -1, 2, 3, 5, -1]
            # Therefore, the mapping is expected to be
            # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex, we get:
            # indices:  [0, 1, 2, 3, 4, 5,  6,  7]
            # token_id: [0, 1, 2, 3, 4, 5, -1, -1]
            logits = logits[:, self.sharded_to_full_mapping_gpu]

        lora_logits = torch.empty(
            self.embeddings_tensors.shape[0] + 1,
            self.embeddings_tensors.shape[1],
            hidden_states.shape[0],
            dtype=self.embeddings_tensors.dtype,
            device=self.embeddings_tensors.device,
        )
        torch.matmul(self.embeddings_tensors,
                     hidden_states.T,
                     out=lora_logits[:-1])
        lora_logits[-1] = float("-inf")
        lora_logits = lora_logits.mT
        indices_padded = self.punica_wrapper.sampler_indices_padded
        lora_logits = (lora_logits.reshape(
            lora_logits.shape[0] * lora_logits.shape[1],
            lora_logits.shape[2],
        ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"),
                                                      posinf=float("inf"),
                                                      neginf=float("-inf")))
        logits[:,
               self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
               lora_logits.shape[1]] = lora_logits

        # LogitsProcessorWithLoRA always uses bgmv.
        self.punica_wrapper.add_lora_logits(logits, hidden_states,
                                            self.lora_a_stacked,
                                            self.lora_b_stacked, 1.0)

        # Remove paddings in vocab (if any).
        logits = logits[:, :self.base_layer.vocab_size]
        return logits

    def forward(self, *args, **kwargs):
        return type(self.base_layer).forward(self, *args, **kwargs)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # Special handling for the LogitsProcessor.
        return False

class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
    """Implements RoPE-scaled embeddings with linear scaling for
    multiple LoRA adapters with a specialized kernel.

    Replaces LinearScalingRotaryEmbedding with
    MultiLinearScalingRotaryEmbedding, which can handle multiple LoRA
    adapters in a specialized kernel.
    """

    def __init__(self, base_layer: RotaryEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer

    @property
    def scaling_factors(self):
        return self.base_layer.scaling_factors

    @property
    def rotary_dim(self):
        return self.base_layer.rotary_dim

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        scaling_factors = (list(lora_config.long_lora_scaling_factors)
                           if lora_config.long_lora_scaling_factors else [])
        base_scaling_factor = (self.base_layer.scaling_factor if isinstance(
            self.base_layer, LinearScalingRotaryEmbedding) else 1.0)
        scaling_factors = sorted(
            list(set([base_scaling_factor] + scaling_factors)))
        self.base_layer = LinearScalingRotaryEmbedding(
            self.base_layer.head_size,
            self.base_layer.rotary_dim,
            self.base_layer.max_position_embeddings,
            self.base_layer.base,
            self.base_layer.is_neox_style,
            scaling_factors,
            self.base_layer.dtype,
        )

    def reset_lora(self, index: int):
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        ...

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.base_layer(
            positions,
            query,
            key,
            offsets=self.punica_wrapper.long_lora_indices,
        )

    @property
    def scaling_factor_to_offset(self) -> Dict[float, int]:
        return self.base_layer.scaling_factor_to_offset

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        return (type(source_layer) is LinearScalingRotaryEmbedding
                or type(source_layer) is RotaryEmbedding)

    def extra_repr(self) -> str:
        return self.base_layer.extra_repr()
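
# NOTE (illustrative sketch, not how any particular caller is implemented;
# `module`, `lora_config`, `model_config`, `max_loras`, `punica_wrapper`,
# `a` and `b` are placeholders assumed by this comment):
#
#     lora_cls = ColumnParallelLinearWithLoRA
#     if lora_cls.can_replace_layer(source_layer=module,
#                                   lora_config=lora_config,
#                                   packed_modules_list=[],
#                                   model_config=model_config):
#         wrapper = lora_cls(module)
#         wrapper.create_lora_weights(max_loras, lora_config, model_config)
#         wrapper.set_mapping(punica_wrapper)
#         wrapper.set_lora(index=0, lora_a=a, lora_b=b, embeddings_tensor=None)
#
# Adapter slots are reused over time by calling reset_lora(index) /
# set_lora(index, ...) as adapters are swapped in and out by the LoRA model
# manager.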