# layers.py
# pylint: disable=unused-argument
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

from aphrodite.adapter_commons.layers import AdapterMapping
from aphrodite.common.config import LoRAConfig
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   split_tensor_along_last_dim,
                                   tensor_model_parallel_all_gather,
                                   tensor_model_parallel_all_reduce,
                                   tensor_model_parallel_gather)
from aphrodite.distributed.utils import divide
from aphrodite.lora.punica import PunicaWrapper
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              ReplicatedLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import (
    LinearScalingRotaryEmbedding, RotaryEmbedding)
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)

if TYPE_CHECKING:
    pass


def _get_lora_device(base_layer: nn.Module) -> torch.device:
    # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
    """Returns the device on which to place the LoRA tensors."""
    # UnquantizedLinear
    if hasattr(base_layer, "weight"):
        return base_layer.weight.device
    # GPTQ/AWQ/SqueezeLLM
    elif hasattr(base_layer, "qweight"):
        return base_layer.qweight.device
    # marlin
    elif hasattr(base_layer, "B"):
        return base_layer.B.device
    else:
        raise ValueError(f"Unsupported base layer: {base_layer}")


def _not_fully_sharded_can_replace(can_replace):
    """
    Decorator that adds the condition that fully sharded LoRAs are not in use.
    Intended to wrap can_replace_layer().
    """

    def dec(*args, **kwargs):
        decorate = kwargs.pop("decorate") if "decorate" in kwargs else True
        condition = (not kwargs["lora_config"].fully_sharded_loras
                     if decorate else True)
        return can_replace(*args, **kwargs) and condition

    return dec
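

# Editorial sketch (not part of the original module): how the decorator above
# is typically consumed.  ``_example_can_replace`` is a hypothetical check used
# only for illustration.  Because ``dec`` looks up ``kwargs["lora_config"]``,
# the wrapped function must receive ``lora_config`` as a keyword argument; the
# wrapped check additionally requires ``lora_config.fully_sharded_loras`` to be
# False unless the caller passes ``decorate=False``.
@_not_fully_sharded_can_replace
def _example_can_replace(source_layer: nn.Module, *,
                         lora_config: LoRAConfig) -> bool:
    # Base condition of this sketch: only ReplicatedLinear layers qualify.
    return isinstance(source_layer, ReplicatedLinear)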


@dataclass
class LoRAMapping(AdapterMapping):
    is_prefill: bool = False


class BaseLayerWithLoRA(nn.Module):

    def slice_lora_a(
        self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
        """Slice lora_a if splitting for tensor parallelism."""
        ...

    def slice_lora_b(
        self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
        """Slice lora_b if splitting with tensor parallelism."""
        ...

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Initializes lora matrices."""
        ...

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Overwrites lora tensors at index."""
        ...

    def set_mapping(
        self,
        punica_wrapper: PunicaWrapper,
    ):
        self.punica_wrapper: PunicaWrapper = punica_wrapper

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        raise NotImplementedError


class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.embeddings_slice: Optional[Tuple[int, int]]
        self.embeddings_weights: Optional[torch.Tensor]

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:

        if self.base_layer.num_added_embeddings_per_partition > 0:
            # We can start adding lora weights
            self.embeddings_weights = self.base_layer.weight.data[
                self.base_layer.num_org_embeddings_per_partition:self.
                base_layer.num_org_embeddings_per_partition +
                self.base_layer.num_added_embeddings_per_partition]
            self.embeddings_slice = (
                self.base_layer.shard_indices.added_vocab_start_index -
                self.base_layer.org_vocab_size,
                self.base_layer.shard_indices.added_vocab_end_index -
                self.base_layer.org_vocab_size)
            self.base_layer.weight.data[
                self.base_layer.num_org_embeddings_per_partition:].fill_(0)
        else:
            self.embeddings_slice = None
            self.embeddings_weights = None

        self.embeddings_tensors = torch.zeros(
            (
                max_loras,
                lora_config.lora_extra_vocab_size,
                self.base_layer.embedding_dim,
            ),
            dtype=self.base_layer.weight.dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                self.base_layer.org_vocab_size +
                lora_config.lora_extra_vocab_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.embedding_dim,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked_2d = self.lora_a_stacked.view(
            self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
            self.lora_a_stacked.shape[2],
        )

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
            lora_a, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1], ].copy_(embeddings_tensor, non_blocking=True)
            if self.embeddings_slice is not None:
                # TODO(yard1): Optimize this copy, we don't need to copy
                # everything, just the modified part
                embeddings = self.embeddings_tensors.view(
                    self.embeddings_tensors.shape[0] *
                    self.embeddings_tensors.shape[1],
                    self.embeddings_tensors.shape[2],
                )[self.embeddings_slice[0]:self.embeddings_slice[1]]
                assert self.embeddings_weights is not None
                self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        added_tokens_mask = x > self.base_layer.org_vocab_size - 1
        embeddings_indices = self.punica_wrapper.embeddings_indices
        indices = embeddings_indices[1].view_as(x)
        full_lora_a_embeddings = F.embedding(
            x + indices,
            self.lora_a_stacked_2d,
        )
        indices = embeddings_indices[0].view_as(x)
        full_output = self.base_layer.forward(
            x.add_(indices * added_tokens_mask))

        full_output_org = full_output
        if full_output.ndim == 3:
            full_output = full_output.view(
                full_output.shape[0] * full_output.shape[1], -1)
        if full_lora_a_embeddings.ndim == 3:
            full_lora_a_embeddings = full_lora_a_embeddings.view(
                full_lora_a_embeddings.shape[0] *
                full_lora_a_embeddings.shape[1],
                -1,
            )
        # The embedding layer only needs the expand op.
        self.punica_wrapper.add_expand(full_output,
                                       full_lora_a_embeddings,
                                       self.lora_b_stacked,
                                       add_input=True)
        return full_output.view_as(full_output_org)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return type(source_layer) is VocabParallelEmbedding
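

# Editorial sketch (not part of the original module): why forward() above adds
# ``embeddings_indices`` to the token ids before the F.embedding lookup on
# ``lora_a_stacked_2d``.  The per-adapter LoRA-A tables are flattened along
# dim 0, so adapter slot ``s`` owns rows ``[s * vocab, (s + 1) * vocab)`` of
# the 2-D view; offsetting a token id by ``s * vocab`` selects that adapter's
# row.  ``_example_lora_a_row`` is a hypothetical helper for illustration only.
def _example_lora_a_row(token_id: int, lora_slot: int, full_vocab: int) -> int:
    """Row of the flattened (max_loras * full_vocab, rank) LoRA-A table."""
    return lora_slot * full_vocab + token_id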


class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: ReplicatedLinear) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.input_size = self.base_layer.input_size
        self.output_size = self.base_layer.output_size
        self.device = _get_lora_device(self.base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        self.lora_config = lora_config
        lora_a_output_size = lora_config.max_lora_rank
        self.lora_a_stacked = torch.zeros(
            max_loras,
            1,
            lora_a_output_size,
            self.input_size,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            max_loras,
            1,
            self.output_size,
            lora_config.max_lora_rank,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
                                     self.lora_b_stacked, 1.0)
        return output

    def forward(self, input_):
        """Forward of ReplicatedLinearWithLoRA

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)

        # Matrix multiply.
        output = self.apply(input_, bias)

        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return type(source_layer) is ReplicatedLinear
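

# Editorial sketch (not part of the original module): the dense math that
# ``punica_wrapper.add_lora`` is assumed to fuse for a single adapter slot,
# written with plain matmuls for clarity.  ``_example_add_lora_dense`` is a
# hypothetical reference helper, not the kernel invoked in apply() above.
def _example_add_lora_dense(output: torch.Tensor, x: torch.Tensor,
                            lora_a: torch.Tensor, lora_b: torch.Tensor,
                            scale: float = 1.0) -> torch.Tensor:
    """Return output + scale * (x @ lora_a.T) @ lora_b.T for one adapter.

    ``lora_a`` has shape (rank, input_size) and ``lora_b`` has shape
    (output_size, rank), matching one slot of the stacked buffers above.
    """
    return output + scale * (x @ lora_a.transpose(0, 1)) @ lora_b.transpose(
        0, 1)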


class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
    """
    LoRA on top of ColumnParallelLinear layer.

    LoRA B is sliced for tensor parallelism.
    """

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size = self.base_layer.input_size
        self.output_size = self.base_layer.output_size_per_partition
        self.device = _get_lora_device(self.base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        self.lora_config = lora_config
        self.tp_size = get_tensor_model_parallel_world_size()
        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))
        self.lora_a_stacked = torch.zeros(
            max_loras,
            1,
            lora_a_output_size_per_partition,
            self.input_size,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            max_loras,
            1,
            self.output_size,
            lora_config.max_lora_rank,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.output_dim = self.lora_b_stacked.shape[2]

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        shard_size = self.output_dim
        start_idx = tensor_model_parallel_rank * shard_size
        end_idx = (tensor_model_parallel_rank + 1) * shard_size
        lora_b = lora_b[:, start_idx:end_idx]
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
                                     self.lora_b_stacked, 1.0)
        return output

    def forward(self, input_):
        """Forward of ColumnParallelLinear

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)

        # Matrix multiply.
        output_parallel = self.apply(input_, bias)
        if self.base_layer.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return type(source_layer) is ColumnParallelLinear or (
            type(source_layer) is MergedColumnParallelLinear
            and len(packed_modules_list) == 1)
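

# Editorial sketch (not part of the original module): slice_lora_b above
# shards LoRA B along the output (column) dimension for tensor parallelism
# while LoRA A stays whole.  ``_example_column_shard`` is a hypothetical
# helper for illustration only.
def _example_column_shard(lora_b: torch.Tensor, tp_rank: int,
                          shard_size: int) -> torch.Tensor:
    """Return the block of output columns owned by ``tp_rank``."""
    start_idx = tp_rank * shard_size
    return lora_b[:, start_idx:start_idx + shard_size]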


class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
    packed together (e.g. gate_proj + up_proj -> gate_up_proj).

    This means we have 2 LoRAs, each applied to one half of the layer.

    Both slices must have the same size.
    """

    def __init__(self, base_layer: MergedColumnParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        self.lora_config = lora_config
        n_slices = 2
        if not (len(self.base_layer.output_sizes) == n_slices
                and self.base_layer.output_sizes[0]
                == self.base_layer.output_sizes[1]):
            raise ValueError(
                "LoRAColumnParallelLinear2Slice requires 2 slices with "
                "the same size.")
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                self.output_size // 2,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(n_slices))
        self.output_dim = self.lora_b_stacked[0].shape[2]

    def reset_lora(self, index: int):
        self.lora_a_stacked[0][index] = 0
        self.lora_a_stacked[1][index] = 0
        self.lora_b_stacked[0][index] = 0
        self.lora_b_stacked[1][index] = 0

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        return lora_a

    def slice_lora_b(
        self, lora_b: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        # NOTE: lora_b contains 2 subloras, and each sublora could be None.
        shard_size = self.output_dim
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        lora_b = [
            lora_b[0][:, start_idx:end_idx] if lora_b[0] is not None else None,
            lora_b[1][:, start_idx:end_idx] if lora_b[1] is not None else None,
        ]
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)

        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        if lora_a[0] is not None:
            self.lora_a_stacked[0][
                index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
                    lora_a[0].T, non_blocking=True)
            self.lora_b_stacked[0][
                index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
                    lora_b[0].T, non_blocking=True)
        if lora_a[1] is not None:
            self.lora_a_stacked[1][
                index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
                    lora_a[1].T, non_blocking=True)
            self.lora_b_stacked[1][
                index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
                    lora_b[1].T, non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        self.punica_wrapper.add_lora_packed_nslice(
            output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0,
            (self.output_dim, self.output_dim))
        return output

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return (type(source_layer) is MergedColumnParallelLinear
                and len(packed_modules_list) == 2)
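

# Editorial sketch (not part of the original module): the per-slice update
# that ``add_lora_packed_nslice`` is assumed to perform for the merged
# gate/up projection, written with dense matmuls.
# ``_example_apply_two_slices`` is a hypothetical reference helper; each half
# of the packed output gets its own (lora_a, lora_b) pair.
def _example_apply_two_slices(output: torch.Tensor, x: torch.Tensor,
                              lora_a: Tuple[torch.Tensor, torch.Tensor],
                              lora_b: Tuple[torch.Tensor, torch.Tensor],
                              slice_size: int) -> torch.Tensor:
    for i in range(2):
        # lora_a[i]: (rank, input_size), lora_b[i]: (slice_size, rank)
        delta = (x @ lora_a[i].transpose(0, 1)) @ lora_b[i].transpose(0, 1)
        output[:, i * slice_size:(i + 1) * slice_size] += delta
    return output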


class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """
    ColumnParallelLinear layer that is specifically designed for
    qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
    only contain a single LoRA within their qkv_proj layer.

    During inference with Tensor Parallel, the weights of lora_b
    must be accurately partitioned according to the respective ranks.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        self.tp_size = get_tensor_model_parallel_world_size()
        self.q_proj_total_size = (self.base_layer.total_num_heads *
                                  self.base_layer.head_size)
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
                                   self.base_layer.head_size)

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        tp_rank = get_tensor_model_parallel_rank()
        self.q_shard_id = tp_rank
        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
        lora_b_q = lora_b[:, self.q_proj_shard_size *
                          self.q_shard_id:self.q_proj_shard_size *
                          (self.q_shard_id + 1)]
        k_offset = self.q_proj_total_size
        lora_b_k = lora_b[:, k_offset +
                          self.kv_proj_shard_size * self.kv_shard_id:k_offset +
                          self.kv_proj_shard_size * (self.kv_shard_id + 1)]
        v_offset = k_offset + self.kv_proj_total_size
        lora_b_v = lora_b[:, v_offset +
                          self.kv_proj_shard_size * self.kv_shard_id:v_offset +
                          self.kv_proj_shard_size * (self.kv_shard_id + 1)]
        lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is QKVParallelLinear and len(
            packed_modules_list) == 1
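

# Editorial sketch (not part of the original module): the column offsets used
# by slice_lora_b above.  In the unsharded qkv_proj output the columns are
# laid out as [Q | K | V]; each rank keeps its own Q head block plus the K/V
# block of its kv-head replica group.  ``_example_qkv_offsets`` is a
# hypothetical helper for illustration only.
def _example_qkv_offsets(q_total: int, kv_total: int, q_shard: int,
                         kv_shard: int, q_shard_id: int,
                         kv_shard_id: int) -> Tuple[int, int, int]:
    """Return the start columns of this rank's Q, K and V shards."""
    q_start = q_shard * q_shard_id
    k_start = q_total + kv_shard * kv_shard_id
    v_start = q_total + kv_total + kv_shard * kv_shard_id
    return q_start, k_start, v_start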


class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 3 sublayers (slices)
    packed together in qkv proj fashion
    (q_proj + k_proj + v_proj -> qkv_proj).

    This means we have 3 LoRAs, each applied to one slice of the layer.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        self.lora_config = lora_config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.q_shard_id = self.tp_rank
        self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas

        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))
        # q, k, v
        self.lora_a_stacked = (
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )
        self.lora_b_stacked = (
            torch.zeros(
                max_loras,
                1,
                self.q_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
            torch.zeros(
                max_loras,
                1,
                self.kv_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
            torch.zeros(
                max_loras,
                1,
                self.kv_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )
        self.output_slices = (
            self.q_proj_shard_size,
            self.kv_proj_shard_size,
            self.kv_proj_shard_size,
        )
        self.packed_indices: Optional[torch.Tensor] = None
        self.standard_indices: Optional[torch.Tensor] = None
        # lazily initialized.
        self.indices: torch.Tensor
        self.indices_len: List[int]

    def reset_lora(self, index: int):
        self.lora_a_stacked[0][index] = 0
        self.lora_b_stacked[0][index] = 0
        self.lora_a_stacked[1][index] = 0
        self.lora_b_stacked[1][index] = 0
        self.lora_a_stacked[2][index] = 0
        self.lora_b_stacked[2][index] = 0

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        return lora_a

    def slice_lora_b(
        self, lora_b: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        lora_b_q, lora_b_k, lora_b_v = None, None, None
        if lora_b[0] is not None:
            lora_b_q = lora_b[0][:, self.q_proj_shard_size *
                                 self.q_shard_id:self.q_proj_shard_size *
                                 (self.q_shard_id + 1)]
        if lora_b[1] is not None:
            lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
                                 self.kv_shard_id:self.kv_proj_shard_size *
                                 (self.kv_shard_id + 1)]
        if lora_b[2] is not None:
            lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
                                 self.kv_shard_id:self.kv_proj_shard_size *
                                 (self.kv_shard_id + 1)]
        lora_b = [lora_b_q, lora_b_k, lora_b_v]
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        if lora_b[0] is not None:
            lora_b_q = lora_b[0]
            self.lora_b_stacked[0][
                index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
                    lora_b_q.T, non_blocking=True)
        if lora_b[1] is not None:
            lora_b_k = lora_b[1]
            self.lora_b_stacked[1][
                index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
                    lora_b_k.T, non_blocking=True)
        if lora_b[2] is not None:
            lora_b_v = lora_b[2]
            self.lora_b_stacked[2][
                index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
                    lora_b_v.T, non_blocking=True)

        if lora_a[0] is not None:
            self.lora_a_stacked[0][
                index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
                    lora_a[0].T, non_blocking=True)
        if lora_a[1] is not None:
            self.lora_a_stacked[1][
                index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
                    lora_a[1].T, non_blocking=True)
        if lora_a[2] is not None:
            self.lora_a_stacked[2][
                index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
                    lora_a[2].T, non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        self.punica_wrapper.add_lora_packed_nslice(output, x,
                                                   self.lora_a_stacked,
                                                   self.lora_b_stacked, 1.0,
                                                   self.output_slices)
        return output

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return (type(source_layer) is QKVParallelLinear
                and len(packed_modules_list) == 3)


class RowParallelLinearWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: RowParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.input_size = self.base_layer.input_size_per_partition
        self.output_size = self.base_layer.output_size
        self.device = _get_lora_device(self.base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        self.lora_config = lora_config
        self.tp_rank = get_tensor_model_parallel_rank()
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.input_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        tp_size = get_tensor_model_parallel_world_size()
        lora_b_output_size_per_partition = (
            self.output_size if not lora_config.fully_sharded_loras else
            divide(self.output_size, tp_size))

        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_b_output_size_per_partition,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        shard_size = self.input_size
        start_idx = tensor_model_parallel_rank * shard_size
        end_idx = (tensor_model_parallel_rank + 1) * shard_size
        lora_a = lora_a[start_idx:end_idx, :]
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)

        if self.base_layer.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def apply(self, x: torch.Tensor) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x)
        self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
                                     self.lora_b_stacked, 1.0)
        return output

    def forward(self, input_):
        """Forward of RowParallelLinear

        Args:
            input_: tensor whose last dimension is `input_size`. If
                    `input_is_parallel` is set, then the last dimension
                    is `input_size // tp_size`.

        Returns:
            - output
            - bias
        """
        # Set up backprop all-reduce.
        if self.base_layer.input_is_parallel:
            input_parallel = input_
        else:
            # TODO: simplify code below
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.base_layer.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        output_parallel = self.apply(input_parallel)
        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.base_layer.skip_bias_add:
            output = (output_ + self.base_layer.bias
                      if self.base_layer.bias is not None else output_)
            output_bias = None
        else:
            output = output_
            output_bias = self.base_layer.bias
        return output, output_bias

    @property
    def weight(self):
        return (self.base_layer.weight if hasattr(self.base_layer, "weight")
                else self.base_layer.qweight)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return type(source_layer) is RowParallelLinear
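

# Editorial sketch (not part of the original module): in the row-parallel case
# slice_lora_a above shards LoRA A along the *input* (row) dimension and
# leaves LoRA B whole, mirroring how the base weight itself is partitioned.
# ``_example_row_shard`` is a hypothetical helper for illustration only.
def _example_row_shard(lora_a: torch.Tensor, tp_rank: int,
                       shard_size: int) -> torch.Tensor:
    """Return the block of input rows owned by ``tp_rank``."""
    start_idx = tp_rank * shard_size
    return lora_a[start_idx:start_idx + shard_size, :]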


class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
    """
    LoRA wrapper for LogitsProcessor, with extra logic to handle the
    application of the LoRA adapter and added LoRA vocabulary.

    Args:
        base_layer: LogitsProcessor layer
        hidden_size: hidden size of the model
        dtype: data type of the model
        device: device of the model
        sharded_to_full_mapping: index mapping from sharded vocab to full vocab
            received from base_layer.get_sharded_to_full_mapping(). If None,
            no reindexing will be done.
    """

    def __init__(self, base_layer: LogitsProcessor, hidden_size: int,
                 dtype: torch.dtype, device: torch.device,
                 sharded_to_full_mapping: Optional[List[int]]) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.sharded_to_full_mapping = sharded_to_full_mapping

    @property
    def logits_as_input(self):
        return self.base_layer.logits_as_input

    @property
    def vocab_size(self):
        return self.base_layer.vocab_size

    @property
    def scale(self):
        return self.base_layer.scale

    @property
    def soft_cap(self):
        return self.base_layer.soft_cap

    @property
    def use_gather(self):
        return self.base_layer.use_gather

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        return self.base_layer.include_gpu_probs_tensor

    @property
    def should_modify_greedy_probs_inplace(self):
        return self.base_layer.should_modify_greedy_probs_inplace

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        # TODO: Verify if this condition can be further relaxed
        if not (32000 <= self.base_layer.vocab_size <= 257024):
            raise ValueError("When using LoRA, vocab size must satisfy "
                             "32000 <= vocab_size <= 257024, "
                             f"but got {self.base_layer.vocab_size}.")
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                # Pad for kernel compatibility
                math.ceil(self.base_layer.vocab_size /
                          lora_config.lora_vocab_padding_size) *
                lora_config.lora_vocab_padding_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.embeddings_tensors = torch.full(
            (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
            fill_value=float("-inf"),
            dtype=self.dtype,
            device=self.device,
        )
        if self.sharded_to_full_mapping is not None:
            self.sharded_to_full_mapping_gpu = torch.tensor(
                self.sharded_to_full_mapping,
                device=self.device,
                dtype=torch.long)
        else:
            self.sharded_to_full_mapping_gpu = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = float("-inf")

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1], ] = embeddings_tensor

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: VocabParallelEmbedding,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        # Get the logits for the next tokens.
        logits = lm_head.linear_method.apply(lm_head, hidden_states)
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        if logits is None:
            return None

        if self.sharded_to_full_mapping_gpu is not None:
            # Reindex the full logits tensor to ensure a 1:1 mapping between
            # index and token_id.
            # Example for:
            #   org_vocab_size = 4
            #   added_vocab_size = 2
            #   pad_to_size = 8
            #   tp_size = 2
            # indices:  [0, 1, 2,  3, 4, 5, 6,  7]
            # token_id: [0, 1, 4, -1, 2, 3, 5, -1]
            # Therefore, the mapping is expected to be
            # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex we get:
            # indices:  [0, 1, 2, 3, 4, 5,  6,  7]
            # token_id: [0, 1, 2, 3, 4, 5, -1, -1]
            logits = logits[:, self.sharded_to_full_mapping_gpu]

        lora_logits = torch.empty(
            self.embeddings_tensors.shape[0] + 1,
            self.embeddings_tensors.shape[1],
            hidden_states.shape[0],
            dtype=self.embeddings_tensors.dtype,
            device=self.embeddings_tensors.device,
        )
        torch.matmul(self.embeddings_tensors,
                     hidden_states.T,
                     out=lora_logits[:-1])
        lora_logits[-1] = float("-inf")
        lora_logits = lora_logits.mT
        indices_padded = self.punica_wrapper.sampler_indices_padded
        lora_logits = (lora_logits.reshape(
            lora_logits.shape[0] * lora_logits.shape[1],
            lora_logits.shape[2],
        ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"),
                                                      posinf=float("inf"),
                                                      neginf=float("-inf")))
        logits[:,
               self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
               lora_logits.shape[1]] = lora_logits

        # LogitsProcessorWithLoRA always uses bgmv.
        self.punica_wrapper.add_lora_logits(logits, hidden_states,
                                            self.lora_a_stacked,
                                            self.lora_b_stacked, 1.0)

        # Remove paddings in vocab (if any).
        logits = logits[:, :self.base_layer.vocab_size]
        return logits

    def forward(self, *args, **kwargs):
        return type(self.base_layer).forward(self, *args, **kwargs)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # Special handling for the LogitsProcessor.
        return False
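

# Editorial sketch (not part of the original module): the dense computation
# behind the extra-vocab logits in _get_logits above.  The added-vocabulary
# embeddings act as an extra lm_head block whose logits are written into the
# columns after ``org_vocab_size``.  ``_example_extra_vocab_logits`` is a
# hypothetical helper that ignores padding, the -inf sentinel row and the
# index_select over sampler indices.
def _example_extra_vocab_logits(hidden_states: torch.Tensor,
                                extra_vocab_embeddings: torch.Tensor
                                ) -> torch.Tensor:
    """(num_tokens, hidden) x (extra_vocab, hidden) -> (num_tokens, extra_vocab)."""
    return hidden_states @ extra_vocab_embeddings.transpose(0, 1)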


class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
    """Implements RoPE-scaled embeddings with linear scaling for
    multiple LoRA adapters with a specialized kernel.

    Replace LinearScalingRotaryEmbedding with MultiLinearScalingRotaryEmbedding
    which can handle multiple LoRA adapters in a specialized kernel.
    """

    def __init__(self, base_layer: RotaryEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer

    @property
    def scaling_factors(self):
        return self.base_layer.scaling_factors

    @property
    def rotary_dim(self):
        return self.base_layer.rotary_dim

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        scaling_factors = (list(lora_config.long_lora_scaling_factors)
                           if lora_config.long_lora_scaling_factors else [])
        base_scaling_factor = (self.base_layer.scaling_factor if isinstance(
            self.base_layer, LinearScalingRotaryEmbedding) else 1.0)
        scaling_factors = sorted(
            list(set([base_scaling_factor] + scaling_factors)))
        self.base_layer = LinearScalingRotaryEmbedding(
            self.base_layer.head_size,
            self.base_layer.rotary_dim,
            self.base_layer.max_position_embeddings,
            self.base_layer.base,
            self.base_layer.is_neox_style,
            scaling_factors,
            self.base_layer.dtype,
        )

    def reset_lora(self, index: int):
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        ...

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.base_layer(
            positions,
            query,
            key,
            offsets=self.punica_wrapper.long_lora_indices,
        )

    @property
    def scaling_factor_to_offset(self) -> Dict[float, int]:
        return self.base_layer.scaling_factor_to_offset

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        return (type(source_layer) is LinearScalingRotaryEmbedding
                or type(source_layer) is RotaryEmbedding)

    def extra_repr(self) -> str:
        return self.base_layer.extra_repr()
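

# Editorial sketch (not part of the original module): how create_lora_weights
# above merges rope scaling factors before rebuilding the base layer: the base
# layer's own factor (1.0 for a plain RotaryEmbedding) is combined with the
# configured long-LoRA factors, de-duplicated and sorted.
# ``_example_merge_scaling_factors`` is a hypothetical helper for illustration.
def _example_merge_scaling_factors(
        base_factor: float,
        long_lora_factors: Optional[List[float]]) -> List[float]:
    factors = list(long_lora_factors) if long_lora_factors else []
    return sorted(set([base_factor] + factors))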