# pylint: disable=unused-argument
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

from aphrodite.common.config import LoRAConfig
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   split_tensor_along_last_dim,
                                   tensor_model_parallel_all_gather,
                                   tensor_model_parallel_all_reduce,
                                   tensor_model_parallel_gather)
from aphrodite.distributed.utils import divide
from aphrodite.lora.punica import add_lora, add_lora_slice, bgmv
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import (
    LinearScalingRotaryEmbedding, RotaryEmbedding)
from aphrodite.modeling.layers.vocab_parallel_embedding import \
    VocabParallelEmbedding

if TYPE_CHECKING:
    pass


def _get_lora_device(base_layer: nn.Module) -> torch.device:
    # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
    """Returns the device for where to place the LoRA tensors."""
    # unquantized Linear
    if hasattr(base_layer, "weight"):
        return base_layer.weight.device
    # GPTQ/AWQ/SqueezeLLM
    elif hasattr(base_layer, "qweight"):
        return base_layer.qweight.device
    # marlin
    elif hasattr(base_layer, "B"):
        return base_layer.B.device
    else:
        raise ValueError(f"Unsupported base layer: {base_layer}")


def _not_fully_sharded_can_replace(can_replace):
    """
    decorator which adds the condition of not using fully sharded loras
    intended to wrap can_replace_layer()
    """

    def dec(*args, **kwargs):
        decorate = kwargs.pop('decorate') if 'decorate' in kwargs else True
        condition = (not kwargs['lora_config'].fully_sharded_loras
                     if decorate else True)
        return can_replace(*args, **kwargs) and condition

    return dec


def _apply_lora(
    x: torch.Tensor,
    lora_a_stacked: torch.Tensor,
    lora_b_stacked: torch.Tensor,
    indices: torch.Tensor,
    output: torch.Tensor,
):
    """Applies lora to each input.

    This method applies all loras to each input. It uses the
    indices vector to determine which lora yields the
    correct output. An index of -1 means no lora should be
    applied. This method adds the final lora results to the
    output.

    Input shapes:
        x:               (batch_size, hidden_dim)
        lora_a_stacked:  (num_loras, lora_rank, hidden_dim)
        lora_b_stacked:  (num_loras, output_dim, lora_rank)
        indices:         (batch_size)
        output:          (batch_size, output_dim)
    """
    org_output = output
    x = x.view(-1, x.shape[-1])
    output = output.view(-1, output.shape[-1])
    indices = indices.view(-1)
    add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0)
    return output.view_as(org_output)

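
# Illustrative reference only (not used by the layers below): a minimal
# pure-PyTorch sketch of the computation that `_apply_lora` delegates to the
# Punica `add_lora` kernel, assuming the 3-D shapes documented above and that
# row i of `x` uses adapter `indices[i]`, with -1 meaning "no LoRA". Handy for
# reasoning about shapes or CPU-side sanity checks; the real kernel fuses this
# per-token gather on the GPU.
def _apply_lora_reference(
    x: torch.Tensor,
    lora_a_stacked: torch.Tensor,
    lora_b_stacked: torch.Tensor,
    indices: torch.Tensor,
    output: torch.Tensor,
) -> torch.Tensor:
    for i, lora_idx in enumerate(indices.tolist()):
        if lora_idx < 0:
            # No LoRA for this token; leave the base output untouched.
            continue
        # (hidden_dim) @ (hidden_dim, lora_rank) -> (lora_rank)
        shrunk = x[i] @ lora_a_stacked[lora_idx].T
        # (lora_rank) @ (lora_rank, output_dim) -> (output_dim), added in place
        output[i] += shrunk @ lora_b_stacked[lora_idx].T
    return output
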

def _apply_lora_packed_nslice(
    x: torch.Tensor,
    lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    indices: torch.Tensor,
    output: torch.Tensor,
    output_slices: Tuple[int, ...],
):
    """Applies lora to each input.

    This method applies all loras to each input. It uses the
    indices vector to determine which lora yields the
    correct output. An index of -1 means no lora should be
    applied. This method adds the final lora results to the
    output.

    This method is used for layers that are composed of multiple sublayers
    (slices) packed together.

    Input shapes:
        x:               (batch_size, hidden_dim)
        lora_a_stacked:  3 element tuple of (num_loras, lora_rank, hidden_dim)
        lora_b_stacked:  3 element tuple of (num_loras, output_dim, lora_rank)
        indices:         (batch_size)
        output:          (batch_size, q_slice_size + 2*kv_slice_size)
        output_slices:   n element tuple of (slice_size, ...),
                         where n is the number of slices
    """
    org_output = output
    x = x.view(-1, x.shape[-1])
    output = output.view(-1, output.shape[-1])
    indices = indices.view(-1)
    offset_left = 0
    for slice_idx in range(len(output_slices)):
        add_lora_slice(output, x, lora_a_stacked[slice_idx],
                       lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left,
                       output_slices[slice_idx])
        offset_left += output_slices[slice_idx]
    return output.view_as(org_output)

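
# Illustrative reference only: the packed-slice variant walks the output
# columns slice by slice, each slice owning a contiguous block of columns.
# A minimal sketch, assuming the 3-D shapes from the docstring above and
# reusing `_apply_lora_reference`; `offset` plays the role of `offset_left`
# in the kernel call.
def _apply_lora_packed_nslice_reference(
    x: torch.Tensor,
    lora_a_stacked: Tuple[torch.Tensor, ...],
    lora_b_stacked: Tuple[torch.Tensor, ...],
    indices: torch.Tensor,
    output: torch.Tensor,
    output_slices: Tuple[int, ...],
) -> torch.Tensor:
    offset = 0
    for slice_idx, slice_size in enumerate(output_slices):
        # Apply the per-slice LoRA to its block of output columns in place.
        _apply_lora_reference(x, lora_a_stacked[slice_idx],
                              lora_b_stacked[slice_idx], indices,
                              output[:, offset:offset + slice_size])
        offset += slice_size
    return output
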

@dataclass
class LoRAMapping:
    # Per every token in input_ids:
    index_mapping: Tuple[int, ...]
    # Per sampled token:
    prompt_mapping: Tuple[int, ...]

    def __post_init__(self):
        self.index_mapping = tuple(self.index_mapping)
        self.prompt_mapping = tuple(self.prompt_mapping)


class BaseLayerWithLoRA(nn.Module):

    def slice_lora_a(
        self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
        """Slice lora a if splitting for tensor parallelism."""
        ...

    def slice_lora_b(
        self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
        """Slice lora b if splitting with tensor parallelism."""
        ...

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        """Initializes lora matrices."""
        ...

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Overwrites lora tensors at index."""
        ...

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        long_lora_indices: torch.Tensor,
        indices_len: List[int],
    ):
        """Sets the mapping indices."""
        ...

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        raise NotImplementedError


class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.embeddings_slice: Optional[Tuple[int, int]]
        self.embeddings_weights: Optional[torch.Tensor]

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        lora_vocab_start_idx = self.base_layer.org_vocab_size
        weights_idx = None
        if self.base_layer.vocab_end_index > lora_vocab_start_idx:
            # We can start adding lora weights
            weights_idx = max(
                lora_vocab_start_idx - self.base_layer.vocab_start_index, 0)
            self.embeddings_slice = (self.base_layer.vocab_start_index -
                                     self.base_layer.org_vocab_size +
                                     weights_idx,
                                     self.base_layer.vocab_end_index -
                                     self.base_layer.org_vocab_size)
            self.embeddings_weights = self.base_layer.weight.data[weights_idx:]
            self.embeddings_weights.fill_(0)
        else:
            self.embeddings_slice = None
            self.embeddings_weights = None
        self.embeddings_tensors = torch.zeros(
            (
                max_loras,
                lora_config.lora_extra_vocab_size,
                self.base_layer.embedding_dim,
            ),
            dtype=self.base_layer.weight.dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                self.base_layer.org_vocab_size +
                lora_config.lora_extra_vocab_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.embedding_dim,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked_2d = self.lora_a_stacked.view(
            self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
            self.lora_a_stacked.shape[2],
        )
        # Lazily initialized.
        self.indices: torch.Tensor
        self.indices_len: List[int]
        self.embeddings_indices: torch.Tensor

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
            lora_a, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1]].copy_(embeddings_tensor, non_blocking=True)
            if self.embeddings_slice is not None:
                # TODO(yard1): Optimize this copy, we don't need to copy
                # everything, just the modified part
                embeddings = self.embeddings_tensors.view(
                    self.embeddings_tensors.shape[0] *
                    self.embeddings_tensors.shape[1],
                    self.embeddings_tensors.shape[2],
                )[self.embeddings_slice[0]:self.embeddings_slice[1]]
                assert self.embeddings_weights is not None
                self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        long_lora_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.embeddings_indices = embeddings_indices
        self.indices_len = indices_len

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        added_tokens_mask = x > self.base_layer.org_vocab_size - 1
        embedding_len = self.indices_len[3]
        indices = self.embeddings_indices[1][:embedding_len].view_as(x)
        full_lora_a_embeddings = F.embedding(
            x + indices,
            self.lora_a_stacked_2d,
        )
        indices = self.embeddings_indices[0][:embedding_len].view_as(x)
        full_output = self.base_layer.forward(
            x.add_(indices * added_tokens_mask))

        full_output_org = full_output
        if full_output.ndim == 3:
            full_output = full_output.view(
                full_output.shape[0] * full_output.shape[1], -1)
        if full_lora_a_embeddings.ndim == 3:
            full_lora_a_embeddings = full_lora_a_embeddings.view(
                full_lora_a_embeddings.shape[0] *
                full_lora_a_embeddings.shape[1], -1)
        bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        return full_output.view_as(full_output_org)

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is VocabParallelEmbedding


class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
    """
    LoRA on top of ColumnParallelLinear layer.

    LoRA B is sliced for tensor parallelism.
    """

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size = self.base_layer.input_size
        self.output_size = self.base_layer.output_size_per_partition
        self.device = _get_lora_device(self.base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.lora_config = lora_config
        self.tp_size = get_tensor_model_parallel_world_size()
        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))
        self.lora_a_stacked = torch.zeros(
            max_loras,
            1,
            lora_a_output_size_per_partition,
            self.input_size,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            max_loras,
            1,
            self.output_size,
            lora_config.max_lora_rank,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.output_dim = self.lora_b_stacked.shape[2]
        # Lazily initialized.
        self.indices: torch.Tensor
        self.indices_len: List[int]

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        shard_size = self.output_dim
        start_idx = tensor_model_parallel_rank * shard_size
        end_idx = (tensor_model_parallel_rank + 1) * shard_size
        lora_b = lora_b[:, start_idx:end_idx]
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        long_lora_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.indices_len = indices_len

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        _apply_lora(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
        )
        return output

    def forward(self, input_):
        """Forward of ColumnParallelLinear

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)

        # Matrix multiply.
        output_parallel = self.apply(input_, bias)
        if self.base_layer.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is ColumnParallelLinear or (
            type(source_layer) is MergedColumnParallelLinear
            and len(packed_modules_list) == 1)


class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
    packed together (e.g. gate_proj + up_proj -> gate_up_proj).

    This means we have 2 LoRAs, each applied to one half of the layer.

    Both slices must have the same size.
    """

    def __init__(self, base_layer: MergedColumnParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.lora_config = lora_config
        n_slices = 2
        if not (len(self.base_layer.output_sizes) == n_slices
                and self.base_layer.output_sizes[0]
                == self.base_layer.output_sizes[1]):
            raise ValueError(
                "LoRAColumnParallelLinear2Slice requires 2 slices with "
                "the same size.")
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                self.output_size // 2,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(n_slices))

        self.output_dim = self.lora_b_stacked[0].shape[2]
        # Lazily initialized.
        self.indices: torch.Tensor

    def reset_lora(self, index: int):
        self.lora_a_stacked[0][index] = 0
        self.lora_a_stacked[1][index] = 0
        self.lora_b_stacked[0][index] = 0
        self.lora_b_stacked[1][index] = 0

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        return lora_a

    def slice_lora_b(
        self, lora_b: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        if lora_b[0] is None or lora_b[1] is None:
            return lora_b
        shard_size = self.output_dim
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        lora_b = [
            lora_b[0][:, start_idx:end_idx], lora_b[1][:, start_idx:end_idx]
        ]
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        if lora_a[0] is not None:
            self.lora_a_stacked[0][
                index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
                    lora_a[0].T, non_blocking=True)
            self.lora_b_stacked[0][
                index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
                    lora_b[0].T, non_blocking=True)
        if lora_a[1] is not None:
            self.lora_a_stacked[1][
                index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
                    lora_a[1].T, non_blocking=True)
            self.lora_b_stacked[1][
                index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
                    lora_b[1].T, non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        _apply_lora_packed_nslice(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
            (self.output_dim, self.output_dim),
        )
        return output

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is MergedColumnParallelLinear and len(
            packed_modules_list) == 2


class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """
    ColumnParallelLinear layer that is specifically designed for
    qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
    only contain a single LoRA within their qkv_proj layer.

    During inference with Tensor Parallel, the weights of lora_b
    must be accurately partitioned according to the respective ranks.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        self.tp_size = get_tensor_model_parallel_world_size()
        self.q_proj_total_size = (self.base_layer.total_num_heads *
                                  self.base_layer.head_size)
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
                                   self.base_layer.head_size)

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.tp_size > 1:
            tp_rank = get_tensor_model_parallel_rank()
            self.q_shard_id = tp_rank
            self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
            lora_b_q = lora_b[:, self.q_proj_shard_size *
                              self.q_shard_id:self.q_proj_shard_size *
                              (self.q_shard_id + 1)]
            k_offset = self.q_proj_total_size
            lora_b_k = lora_b[:, k_offset + self.kv_proj_shard_size *
                              self.kv_shard_id:k_offset +
                              self.kv_proj_shard_size * (self.kv_shard_id + 1)]
            v_offset = k_offset + self.kv_proj_total_size
            lora_b_v = lora_b[:, v_offset + self.kv_proj_shard_size *
                              self.kv_shard_id:v_offset +
                              self.kv_proj_shard_size * (self.kv_shard_id + 1)]
            lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is QKVParallelLinear and len(
            packed_modules_list) == 1

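
# Illustrative only: where the q/k/v column ranges computed in
# QKVParallelLinearWithLora.set_lora land for a hypothetical Llama-like shard
# (head_size=128, 32 query heads and 8 KV heads in total, tensor parallel
# size 4, so 8 query heads and 2 KV heads per rank, num_kv_head_replicas=1).
# All numbers are assumptions for the example, not values read from any model.
def _example_qkv_lora_b_ranges(tp_rank: int = 1) -> Dict[str, Tuple[int, int]]:
    head_size, total_num_heads, total_num_kv_heads, tp_size = 128, 32, 8, 4
    q_proj_total_size = total_num_heads * head_size        # 4096
    kv_proj_total_size = total_num_kv_heads * head_size    # 1024
    q_proj_shard_size = q_proj_total_size // tp_size       # 1024 per rank
    kv_proj_shard_size = kv_proj_total_size // tp_size     # 256 per rank
    q_shard_id, kv_shard_id = tp_rank, tp_rank  # num_kv_head_replicas == 1
    k_offset = q_proj_total_size
    v_offset = k_offset + kv_proj_total_size
    # Each entry is the [start, end) column range of the packed lora_b that
    # this rank keeps for its q, k, and v shards.
    return {
        "q": (q_proj_shard_size * q_shard_id,
              q_proj_shard_size * (q_shard_id + 1)),
        "k": (k_offset + kv_proj_shard_size * kv_shard_id,
              k_offset + kv_proj_shard_size * (kv_shard_id + 1)),
        "v": (v_offset + kv_proj_shard_size * kv_shard_id,
              v_offset + kv_proj_shard_size * (kv_shard_id + 1)),
    }
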

class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 3 sublayers (slices)
    packed together in qkv proj fashion
    (q_proj + k_proj + v_proj -> qkv_proj).

    This means we have 3 LoRAs, each applied to one slice of the layer.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.lora_config = lora_config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.q_shard_id = self.tp_rank
        self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas

        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))
        # q, k, v
        self.lora_a_stacked = (
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )
        self.lora_b_stacked = (
            torch.zeros(
                max_loras,
                1,
                self.q_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
            torch.zeros(
                max_loras,
                1,
                self.kv_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
            torch.zeros(
                max_loras,
                1,
                self.kv_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )

        self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size,
                              self.kv_proj_shard_size)
        self.packed_indices: Optional[torch.Tensor] = None
        self.standard_indices: Optional[torch.Tensor] = None
        # Lazily initialized.
        self.indices_len: List[int]

    def reset_lora(self, index: int):
        self.lora_a_stacked[0][index] = 0
        self.lora_b_stacked[0][index] = 0
        self.lora_a_stacked[1][index] = 0
        self.lora_b_stacked[1][index] = 0
        self.lora_a_stacked[2][index] = 0
        self.lora_b_stacked[2][index] = 0

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        return lora_a

    def slice_lora_b(
        self, lora_b: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        lora_b_q, lora_b_k, lora_b_v = None, None, None
        if lora_b[0] is not None:
            lora_b_q = lora_b[0][:, self.q_proj_shard_size *
                                 self.q_shard_id:self.q_proj_shard_size *
                                 (self.q_shard_id + 1)]
        if lora_b[1] is not None:
            lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
                                 self.kv_shard_id:self.kv_proj_shard_size *
                                 (self.kv_shard_id + 1)]
        if lora_b[2] is not None:
            lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
                                 self.kv_shard_id:self.kv_proj_shard_size *
                                 (self.kv_shard_id + 1)]
        lora_b = [lora_b_q, lora_b_k, lora_b_v]
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        if lora_b[0] is not None:
            lora_b_q = lora_b[0]
            self.lora_b_stacked[0][
                index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
                    lora_b_q.T, non_blocking=True)
        if lora_b[1] is not None:
            lora_b_k = lora_b[1]
            self.lora_b_stacked[1][
                index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
                    lora_b_k.T, non_blocking=True)
        if lora_b[2] is not None:
            lora_b_v = lora_b[2]
            self.lora_b_stacked[2][
                index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
                    lora_b_v.T, non_blocking=True)

        if lora_a[0] is not None:
            self.lora_a_stacked[0][
                index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
                    lora_a[0].T, non_blocking=True)
        if lora_a[1] is not None:
            self.lora_a_stacked[1][
                index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
                    lora_a[1].T, non_blocking=True)
        if lora_a[2] is not None:
            self.lora_a_stacked[2][
                index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
                    lora_a[2].T, non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        _apply_lora_packed_nslice(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
            self.output_slices,
        )
        return output

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is QKVParallelLinear and len(
            packed_modules_list) == 3


class RowParallelLinearWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: RowParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.input_size = self.base_layer.input_size_per_partition
        self.output_size = self.base_layer.output_size
        self.device = _get_lora_device(self.base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.lora_config = lora_config
        self.tp_rank = get_tensor_model_parallel_rank()
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.input_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        tp_size = get_tensor_model_parallel_world_size()
        lora_b_output_size_per_partition = (
            self.output_size if not lora_config.fully_sharded_loras else
            divide(self.output_size, tp_size))
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_b_output_size_per_partition,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        # Lazily initialized.
        self.indices: torch.Tensor
        self.indices_len: List[int]

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        shard_size = self.input_size
        start_idx = tensor_model_parallel_rank * shard_size
        end_idx = (tensor_model_parallel_rank + 1) * shard_size
        lora_a = lora_a[start_idx:end_idx, :]
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.base_layer.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        long_lora_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.indices_len = indices_len

    def apply(self, x: torch.Tensor) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x)
        _apply_lora(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
        )
        return output

    def forward(self, input_):
        """Forward of RowParallelLinear

        Args:
            input_: tensor whose last dimension is `input_size`. If
                    `input_is_parallel` is set, then the last dimension
                    is `input_size // tp_size`.

        Returns:
            - output
            - bias
        """
        # Set up backprop all-reduce.
        if self.base_layer.input_is_parallel:
            input_parallel = input_
        else:
            # TODO: simplify code below
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.base_layer.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        output_parallel = self.apply(input_parallel)
        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.base_layer.skip_bias_add:
            output = (output_ + self.base_layer.bias
                      if self.base_layer.bias is not None else output_)
            output_bias = None
        else:
            output = output_
            output_bias = self.base_layer.bias
        return output, output_bias

    @property
    def weight(self):
        return self.base_layer.weight if hasattr(
            self.base_layer, "weight") else self.base_layer.qweight

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is RowParallelLinear


class LogitsProcessorWithLoRA(BaseLayerWithLoRA):

    def __init__(
        self,
        base_layer: LogitsProcessor,
        hidden_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device

    @property
    def logits_as_input(self):
        return self.base_layer.logits_as_input

    @property
    def vocab_size(self):
        return self.base_layer.vocab_size

    @property
    def scale(self):
        return self.base_layer.scale

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        return self.base_layer.include_gpu_probs_tensor

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
        if not (32000 <= self.base_layer.vocab_size <= 128512):
            raise ValueError("When using LoRA, vocab size must satisfy "
                             "32000 <= vocab_size <= 128512")
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                # Pad for kernel compatibility
                math.ceil(self.base_layer.vocab_size /
                          lora_config.lora_vocab_padding_size) *
                lora_config.lora_vocab_padding_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.embeddings_tensors = torch.full(
            (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
            fill_value=float("-inf"),
            dtype=self.dtype,
            device=self.device,
        )
        # Lazily initialized.
        self.indices: torch.Tensor
        self.indices_len: List[int]
        self.indices_padded: torch.Tensor

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = float("-inf")

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1], ] = embeddings_tensor

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        long_lora_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = sampler_indices
        self.indices_padded = sampler_indices_padded
        self.indices_len = indices_len

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        embedding: torch.Tensor,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        # Get the logits for the next tokens.
        logits = torch.matmul(hidden_states, embedding.t())
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        if logits is None:
            return None

        lora_logits = torch.empty(
            self.embeddings_tensors.shape[0] + 1,
            self.embeddings_tensors.shape[1],
            hidden_states.shape[0],
            dtype=self.embeddings_tensors.dtype,
            device=self.embeddings_tensors.device,
        )
        torch.matmul(self.embeddings_tensors,
                     hidden_states.T,
                     out=lora_logits[:-1])
        lora_logits[-1] = float("-inf")
        lora_logits = lora_logits.mT
        lora_logits = (lora_logits.reshape(
            lora_logits.shape[0] * lora_logits.shape[1],
            lora_logits.shape[2],
        ).index_select(0,
                       self.indices_padded[:self.indices_len[2]]).nan_to_num_(
                           nan=float("-inf"),
                           posinf=float("inf"),
                           neginf=float("-inf")))
        logits[:,
               self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
               lora_logits.shape[1]] = lora_logits

        _apply_lora(
            hidden_states,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[1]],
            logits,
        )

        # Remove paddings in vocab (if any).
        logits = logits[:, :self.base_layer.vocab_size]
        return logits

    def forward(self, *args, **kwargs):
        return type(self.base_layer).forward(self, *args, **kwargs)

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # Special handling for the LogitsProcessor.
        return False


class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
    """Implements RoPE-scaled embeddings with linear scaling for
    multiple LoRA adapters with a specialized kernel.

    Replaces LinearScalingRotaryEmbedding with
    MultiLinearScalingRotaryEmbedding, which can handle multiple LoRA
    adapters in a specialized kernel.
    """

    def __init__(self, base_layer: RotaryEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        # Lazily initialized.
        self.long_lora_indices: torch.Tensor
        self.indices_len: List[int]

    @property
    def scaling_factors(self):
        return self.base_layer.scaling_factors

    @property
    def rotary_dim(self):
        return self.base_layer.rotary_dim

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        scaling_factors = list(
            lora_config.long_lora_scaling_factors
        ) if lora_config.long_lora_scaling_factors else []
        base_scaling_factor = (self.base_layer.scaling_factor if isinstance(
            self.base_layer, LinearScalingRotaryEmbedding) else 1.0)
        scaling_factors = sorted(
            list(set([base_scaling_factor] + scaling_factors)))
        self.base_layer = LinearScalingRotaryEmbedding(
            self.base_layer.head_size,
            self.base_layer.rotary_dim,
            self.base_layer.max_position_embeddings,
            self.base_layer.base,
            self.base_layer.is_neox_style,
            scaling_factors,
            self.base_layer.dtype,
        )

    def reset_lora(self, index: int):
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        ...

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        long_lora_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.long_lora_indices = long_lora_indices
        self.indices_len = indices_len

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.base_layer(
            positions,
            query,
            key,
            offsets=self.long_lora_indices[:self.indices_len[4]])

    @property
    def scaling_factor_to_offset(self) -> Dict[float, int]:
        return self.base_layer.scaling_factor_to_offset

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        return (type(source_layer) is LinearScalingRotaryEmbedding
                or type(source_layer) is RotaryEmbedding)

    def extra_repr(self) -> str:
        return self.base_layer.extra_repr()