# layers.py

# pylint: disable=unused-argument
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

from aphrodite.common.config import LoRAConfig
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   split_tensor_along_last_dim,
                                   tensor_model_parallel_all_gather,
                                   tensor_model_parallel_all_reduce,
                                   tensor_model_parallel_gather)
from aphrodite.distributed.utils import divide
from aphrodite.lora.punica import add_lora, add_lora_slice, bgmv
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import (
    LinearScalingRotaryEmbedding, RotaryEmbedding)
from aphrodite.modeling.layers.vocab_parallel_embedding import \
    VocabParallelEmbedding

if TYPE_CHECKING:
    pass

def _get_lora_device(base_layer: nn.Module) -> torch.device:
    # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
    """Returns the device for where to place the LoRA tensors."""
    # UnquantizedLinear
    if hasattr(base_layer, "weight"):
        return base_layer.weight.device
    # GPTQ/AWQ/SqueezeLLM
    elif hasattr(base_layer, "qweight"):
        return base_layer.qweight.device
    # Marlin
    elif hasattr(base_layer, "B"):
        return base_layer.B.device
    else:
        raise ValueError(f"Unsupported base layer: {base_layer}")

def _not_fully_sharded_can_replace(can_replace):
    """Decorator that adds the condition that fully sharded LoRAs are not
    in use. Intended to wrap can_replace_layer().
    """

    def dec(*args, **kwargs):
        decorate = kwargs.pop("decorate", True)
        condition = (not kwargs["lora_config"].fully_sharded_loras
                     if decorate else True)
        return can_replace(*args, **kwargs) and condition

    return dec

def _apply_lora(
    x: torch.Tensor,
    lora_a_stacked: torch.Tensor,
    lora_b_stacked: torch.Tensor,
    indices: torch.Tensor,
    output: torch.Tensor,
):
    """Applies lora to each input.

    This method applies all loras to each input. It uses the
    indices vector to determine which lora yields the
    correct output. An index of -1 means no lora should be
    applied. This method adds the final lora results to the
    output.

    Input shapes:
        x:              (batch_size, hidden_dim)
        lora_a_stacked: (num_loras, lora_rank, hidden_dim)
        lora_b_stacked: (num_loras, output_dim, lora_rank)
        indices:        (batch_size)
        output:         (batch_size, output_dim)
    """
    org_output = output
    x = x.view(-1, x.shape[-1])
    output = output.view(-1, output.shape[-1])
    indices = indices.view(-1)
    add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0)
    return output.view_as(org_output)
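
# NOTE: roughly, for a token t routed to LoRA i (indices[t] == i), add_lora
# computes output[t] += x[t] @ lora_a_stacked[i].T @ lora_b_stacked[i].T with
# scale 1.0, and leaves rows whose index is -1 untouched.
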
def _apply_lora_packed_nslice(
    x: torch.Tensor,
    lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    indices: torch.Tensor,
    output: torch.Tensor,
    output_slices: Tuple[int, ...],
):
    """Applies lora to each input.

    This method applies all loras to each input. It uses the
    indices vector to determine which lora yields the
    correct output. An index of -1 means no lora should be
    applied. This method adds the final lora results to the
    output.

    This method is used for layers that are composed of multiple sublayers
    (slices) packed together.

    Input shapes:
        x:              (batch_size, hidden_dim)
        lora_a_stacked: 3 element tuple of (num_loras, lora_rank, hidden_dim)
        lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank)
        indices:        (batch_size)
        output:         (batch_size, q_slice_size + 2*kv_slice_size)
        output_slices:  n element tuple of (slice_size, ...),
                        where n is the number of slices
    """
    org_output = output
    x = x.view(-1, x.shape[-1])
    output = output.view(-1, output.shape[-1])
    indices = indices.view(-1)
    offset_left = 0
    for slice_idx in range(len(output_slices)):
        add_lora_slice(output, x, lora_a_stacked[slice_idx],
                       lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left,
                       output_slices[slice_idx])
        offset_left += output_slices[slice_idx]
    return output.view_as(org_output)
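
# NOTE: as an assumed example for a packed qkv_proj with a q shard of 1024 and
# kv shards of 256 per rank, output_slices would be (1024, 256, 256), and
# add_lora_slice writes the three LoRA results into output columns
# [0:1024), [1024:1280) and [1280:1536) respectively.
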
@dataclass
class LoRAMapping:
    # Per every token in input_ids:
    index_mapping: Tuple[int, ...]
    # Per sampled token:
    prompt_mapping: Tuple[int, ...]

    def __post_init__(self):
        self.index_mapping = tuple(self.index_mapping)
        self.prompt_mapping = tuple(self.prompt_mapping)
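
# NOTE: an assumed example: for a batch of two sequences with 3 and 2 prompt
# tokens using LoRA ids 1 and 0, index_mapping would be (1, 1, 1, 0, 0)
# (one entry per input token) and prompt_mapping would be (1, 0)
# (one entry per token that will be sampled).
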
class BaseLayerWithLoRA(nn.Module):

    def slice_lora_a(
        self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
        """Slice lora a if splitting for tensor parallelism."""
        ...

    def slice_lora_b(
        self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
        """Slice lora b if splitting with tensor parallelism."""
        ...

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        """Initializes lora matrices."""
        ...

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Overwrites lora tensors at index."""
        ...

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        long_lora_indices: torch.Tensor,
        indices_len: List[int],
    ):
        """Sets the mapping indices."""
        ...

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        raise NotImplementedError

class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.embeddings_slice: Optional[Tuple[int, int]]
        self.embeddings_weights: Optional[torch.Tensor]

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        if self.base_layer.num_added_embeddings_per_partition > 0:
            # We can start adding lora weights
            self.embeddings_weights = self.base_layer.weight.data[
                self.base_layer.num_org_embeddings_per_partition:self.
                base_layer.num_org_embeddings_per_partition +
                self.base_layer.num_added_embeddings_per_partition]
            self.embeddings_slice = (
                self.base_layer.shard_indices.added_vocab_start_index -
                self.base_layer.org_vocab_size,
                self.base_layer.shard_indices.added_vocab_end_index -
                self.base_layer.org_vocab_size)
            self.base_layer.weight.data[
                self.base_layer.num_org_embeddings_per_partition:].fill_(0)
        else:
            self.embeddings_slice = None
            self.embeddings_weights = None

        self.embeddings_tensors = torch.zeros(
            (
                max_loras,
                lora_config.lora_extra_vocab_size,
                self.base_layer.embedding_dim,
            ),
            dtype=self.base_layer.weight.dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                self.base_layer.org_vocab_size +
                lora_config.lora_extra_vocab_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.embedding_dim,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked_2d = self.lora_a_stacked.view(
            self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
            self.lora_a_stacked.shape[2],
        )
        # Lazily initialized.
        self.indices: torch.Tensor
        self.indices_len: List[int]
        self.embeddings_indices: torch.Tensor

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
            lora_a, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1]].copy_(embeddings_tensor, non_blocking=True)
            if self.embeddings_slice is not None:
                # TODO(yard1): Optimize this copy, we don't need to copy
                # everything, just the modified part
                embeddings = self.embeddings_tensors.view(
                    self.embeddings_tensors.shape[0] *
                    self.embeddings_tensors.shape[1],
                    self.embeddings_tensors.shape[2]
                )[self.embeddings_slice[0]:self.embeddings_slice[1]]
                assert self.embeddings_weights is not None
                self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        long_lora_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.embeddings_indices = embeddings_indices
        self.indices_len = indices_len

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        added_tokens_mask = x > self.base_layer.org_vocab_size - 1
        embedding_len = self.indices_len[3]
        indices = self.embeddings_indices[1][:embedding_len].view_as(x)
        full_lora_a_embeddings = F.embedding(
            x + indices,
            self.lora_a_stacked_2d,
        )
        indices = self.embeddings_indices[0][:embedding_len].view_as(x)
        full_output = self.base_layer.forward(
            x.add_(indices * added_tokens_mask))

        full_output_org = full_output
        if full_output.ndim == 3:
            full_output = full_output.view(
                full_output.shape[0] * full_output.shape[1], -1)
        if full_lora_a_embeddings.ndim == 3:
            full_lora_a_embeddings = full_lora_a_embeddings.view(
                full_lora_a_embeddings.shape[0] *
                full_lora_a_embeddings.shape[1], -1)
        bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        return full_output.view_as(full_output_org)

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is VocabParallelEmbedding
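
# NOTE: roughly, in the forward above, tokens at or above org_vocab_size
# (LoRA-added vocabulary) are shifted by embeddings_indices so that the base
# embedding lookup hits the per-adapter rows that set_lora copied into
# embeddings_weights, while lora_a_stacked_2d provides a flattened
# (max_loras * vocab, rank) view that F.embedding can index with a single
# per-token offset.
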

class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
    """LoRA on top of ColumnParallelLinear layer.

    LoRA B is sliced for tensor parallelism.
    """

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size = self.base_layer.input_size
        self.output_size = self.base_layer.output_size_per_partition
        self.device = _get_lora_device(self.base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.lora_config = lora_config
        self.tp_size = get_tensor_model_parallel_world_size()
        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))
        self.lora_a_stacked = torch.zeros(
            max_loras,
            1,
            lora_a_output_size_per_partition,
            self.input_size,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            max_loras,
            1,
            self.output_size,
            lora_config.max_lora_rank,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.output_dim = self.lora_b_stacked.shape[2]
        # Lazily initialized.
        self.indices: torch.Tensor
        self.indices_len: List[int]

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        shard_size = self.output_dim
        start_idx = tensor_model_parallel_rank * shard_size
        end_idx = (tensor_model_parallel_rank + 1) * shard_size
        lora_b = lora_b[:, start_idx:end_idx]
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        long_lora_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.indices_len = indices_len

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        _apply_lora(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
        )
        return output

    def forward(self, input_):
        """Forward of ColumnParallelLinear

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)

        # Matrix multiply.
        output_parallel = self.apply(input_, bias)
        if self.base_layer.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is ColumnParallelLinear or (
            type(source_layer) is MergedColumnParallelLinear
            and len(packed_modules_list) == 1)
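
# NOTE: for the column-parallel layers above, only lora_b needs slicing:
# lora_a maps the full input to the (replicated) low-rank space, while each
# rank keeps the shard of lora_b columns that matches its shard of the base
# weight, i.e. rank r keeps lora_b[:, r * output_dim:(r + 1) * output_dim].
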

class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
    packed together (e.g. gate_proj + up_proj -> gate_up_proj).

    This means we have 2 LoRAs, each applied to one half of the layer.

    Both slices must have the same size.
    """

    def __init__(self, base_layer: MergedColumnParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.lora_config = lora_config
        n_slices = 2
        if not (len(self.base_layer.output_sizes) == n_slices
                and self.base_layer.output_sizes[0]
                == self.base_layer.output_sizes[1]):
            raise ValueError(
                "LoRAColumnParallelLinear2Slice requires 2 slices with "
                "the same size.")
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                self.output_size // 2,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(n_slices))

        self.output_dim = self.lora_b_stacked[0].shape[2]
        # Lazily initialized.
        self.indices: torch.Tensor

    def reset_lora(self, index: int):
        self.lora_a_stacked[0][index] = 0
        self.lora_a_stacked[1][index] = 0
        self.lora_b_stacked[0][index] = 0
        self.lora_b_stacked[1][index] = 0

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        return lora_a

    def slice_lora_b(
        self, lora_b: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        if lora_b[0] is None or lora_b[1] is None:
            return lora_b
        shard_size = self.output_dim
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        lora_b = [
            lora_b[0][:, start_idx:end_idx], lora_b[1][:, start_idx:end_idx]
        ]
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)

        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        if lora_a[0] is not None:
            self.lora_a_stacked[0][
                index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
                    lora_a[0].T, non_blocking=True)
            self.lora_b_stacked[0][
                index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
                    lora_b[0].T, non_blocking=True)
        if lora_a[1] is not None:
            self.lora_a_stacked[1][
                index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
                    lora_a[1].T, non_blocking=True)
            self.lora_b_stacked[1][
                index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
                    lora_b[1].T, non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        _apply_lora_packed_nslice(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
            (self.output_dim, self.output_dim),
        )
        return output

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is MergedColumnParallelLinear and len(
            packed_modules_list) == 2

class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is specifically designed for
    qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
    only contain a single LoRA within their qkv_proj layer.

    During inference with Tensor Parallel, the weights of lora_b
    must be accurately partitioned according to the respective ranks.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        self.tp_size = get_tensor_model_parallel_world_size()
        self.q_proj_total_size = (self.base_layer.total_num_heads *
                                  self.base_layer.head_size)
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
                                   self.base_layer.head_size)

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.tp_size > 1:
            tp_rank = get_tensor_model_parallel_rank()
            self.q_shard_id = tp_rank
            self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
            lora_b_q = lora_b[:, self.q_proj_shard_size *
                              self.q_shard_id:self.q_proj_shard_size *
                              (self.q_shard_id + 1)]
            k_offset = self.q_proj_total_size
            lora_b_k = lora_b[:, k_offset + self.kv_proj_shard_size *
                              self.kv_shard_id:k_offset +
                              self.kv_proj_shard_size * (self.kv_shard_id + 1)]
            v_offset = k_offset + self.kv_proj_total_size
            lora_b_v = lora_b[:, v_offset + self.kv_proj_shard_size *
                              self.kv_shard_id:v_offset +
                              self.kv_proj_shard_size * (self.kv_shard_id + 1)]
            lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is QKVParallelLinear and len(
            packed_modules_list) == 1
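
# NOTE: with grouped-query attention the kv shard index above is assumed to be
# tp_rank // num_kv_head_replicas; e.g. with tp_size=4 and 2 total KV heads,
# each KV head is replicated on 2 ranks, so ranks 0-1 slice kv shard 0 and
# ranks 2-3 slice kv shard 1 when partitioning lora_b for k_proj and v_proj.
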

class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 3 sublayers (slices)
    packed together in qkv proj fashion
    (q_proj + k_proj + v_proj -> qkv_proj).

    This means we have 3 LoRAs, each applied to one slice of the layer.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.lora_config = lora_config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.q_shard_id = self.tp_rank
        self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas

        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))
        # q, k, v
        self.lora_a_stacked = (
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )
        self.lora_b_stacked = (
            torch.zeros(
                max_loras,
                1,
                self.q_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
            torch.zeros(
                max_loras,
                1,
                self.kv_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
            torch.zeros(
                max_loras,
                1,
                self.kv_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )

        self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size,
                              self.kv_proj_shard_size)
        self.packed_indices: Optional[torch.Tensor] = None
        self.standard_indices: Optional[torch.Tensor] = None
        # Lazily initialized.
        self.indices_len: List[int]

    def reset_lora(self, index: int):
        self.lora_a_stacked[0][index] = 0
        self.lora_b_stacked[0][index] = 0
        self.lora_a_stacked[1][index] = 0
        self.lora_b_stacked[1][index] = 0
        self.lora_a_stacked[2][index] = 0
        self.lora_b_stacked[2][index] = 0

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        return lora_a

    def slice_lora_b(
        self, lora_b: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        lora_b_q, lora_b_k, lora_b_v = None, None, None
        if lora_b[0] is not None:
            lora_b_q = lora_b[0][:, self.q_proj_shard_size *
                                 self.q_shard_id:self.q_proj_shard_size *
                                 (self.q_shard_id + 1)]
        if lora_b[1] is not None:
            lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
                                 self.kv_shard_id:self.kv_proj_shard_size *
                                 (self.kv_shard_id + 1)]
        if lora_b[2] is not None:
            lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
                                 self.kv_shard_id:self.kv_proj_shard_size *
                                 (self.kv_shard_id + 1)]
        lora_b = [lora_b_q, lora_b_k, lora_b_v]
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        if lora_b[0] is not None:
            lora_b_q = lora_b[0]
            self.lora_b_stacked[0][
                index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
                    lora_b_q.T, non_blocking=True)
        if lora_b[1] is not None:
            lora_b_k = lora_b[1]
            self.lora_b_stacked[1][
                index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
                    lora_b_k.T, non_blocking=True)
        if lora_b[2] is not None:
            lora_b_v = lora_b[2]
            self.lora_b_stacked[2][
                index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
                    lora_b_v.T, non_blocking=True)

        if lora_a[0] is not None:
            self.lora_a_stacked[0][
                index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
                    lora_a[0].T, non_blocking=True)
        if lora_a[1] is not None:
            self.lora_a_stacked[1][
                index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
                    lora_a[1].T, non_blocking=True)
        if lora_a[2] is not None:
            self.lora_a_stacked[2][
                index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
                    lora_a[2].T, non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        _apply_lora_packed_nslice(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
            self.output_slices,
        )
        return output

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is QKVParallelLinear and len(
            packed_modules_list) == 3

class RowParallelLinearWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: RowParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.input_size = self.base_layer.input_size_per_partition
        self.output_size = self.base_layer.output_size
        self.device = _get_lora_device(self.base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.lora_config = lora_config
        self.tp_rank = get_tensor_model_parallel_rank()
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.input_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        tp_size = get_tensor_model_parallel_world_size()
        lora_b_output_size_per_partition = (
            self.output_size if not lora_config.fully_sharded_loras else
            divide(self.output_size, tp_size))
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_b_output_size_per_partition,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        # Lazily initialized.
        self.indices: torch.Tensor
        self.indices_len: List[int]

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        shard_size = self.input_size
        start_idx = tensor_model_parallel_rank * shard_size
        end_idx = (tensor_model_parallel_rank + 1) * shard_size
        lora_a = lora_a[start_idx:end_idx, :]
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)

        if self.base_layer.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        long_lora_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.indices_len = indices_len

    def apply(self, x: torch.Tensor) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x)
        _apply_lora(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
        )
        return output

    def forward(self, input_):
        """Forward of RowParallelLinear

        Args:
            input_: tensor whose last dimension is `input_size`. If
                    `input_is_parallel` is set, then the last dimension
                    is `input_size // tp_size`.

        Returns:
            - output
            - bias
        """
        # Set up backprop all-reduce.
        if self.base_layer.input_is_parallel:
            input_parallel = input_
        else:
            # TODO: simplify code below
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.base_layer.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        output_parallel = self.apply(input_parallel)
        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.base_layer.skip_bias_add:
            output = (output_ + self.base_layer.bias
                      if self.base_layer.bias is not None else output_)
            output_bias = None
        else:
            output = output_
            output_bias = self.base_layer.bias
        return output, output_bias

    @property
    def weight(self):
        return self.base_layer.weight if hasattr(
            self.base_layer, "weight") else self.base_layer.qweight

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is RowParallelLinear
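
# NOTE: the row-parallel case above is the mirror image of the column-parallel
# one: lora_a is sliced along the input dimension (rows
# [r * input_size:(r + 1) * input_size] for rank r) and lora_b is kept whole,
# since each rank's partial low-rank product is summed by the same all-reduce
# that merges the base layer's partial outputs.
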

class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper for LogitsProcessor, with extra logic to handle the
    application of the LoRA adapter and added LoRA vocabulary.

    Args:
        base_layer: LogitsProcessor layer
        hidden_size: hidden size of the model
        dtype: data type of the model
        device: device of the model
        sharded_to_full_mapping: index mapping from sharded vocab to full vocab
            received from base_layer.get_sharded_to_full_mapping(). If None,
            no reindexing will be done.
    """

    def __init__(self, base_layer: LogitsProcessor, hidden_size: int,
                 dtype: torch.dtype, device: torch.device,
                 sharded_to_full_mapping: Optional[List[int]]) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.sharded_to_full_mapping = sharded_to_full_mapping

    @property
    def logits_as_input(self):
        return self.base_layer.logits_as_input

    @property
    def vocab_size(self):
        return self.base_layer.vocab_size

    @property
    def scale(self):
        return self.base_layer.scale

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        return self.base_layer.include_gpu_probs_tensor

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
        if not (32000 <= self.base_layer.vocab_size <= 128512):
            raise ValueError("When using LoRA, vocab size must be "
                             "32000 <= vocab_size <= 128512")
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                # Pad for kernel compatibility
                math.ceil(self.base_layer.vocab_size /
                          lora_config.lora_vocab_padding_size) *
                lora_config.lora_vocab_padding_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.embeddings_tensors = torch.full(
            (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
            fill_value=float("-inf"),
            dtype=self.dtype,
            device=self.device,
        )
        if self.sharded_to_full_mapping is not None:
            self.sharded_to_full_mapping_gpu = torch.tensor(
                self.sharded_to_full_mapping,
                device=self.device,
                dtype=torch.long)
        else:
            self.sharded_to_full_mapping_gpu = None
        # Lazily initialized.
        self.indices: torch.Tensor
        self.indices_len: List[int]
        self.indices_padded: torch.Tensor

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = float("-inf")

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1]] = embeddings_tensor

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        long_lora_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = sampler_indices
        self.indices_padded = sampler_indices_padded
        self.indices_len = indices_len

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        embedding: torch.Tensor,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        # Get the logits for the next tokens.
        logits = torch.matmul(hidden_states, embedding.t())
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        if logits is None:
            return None

        if self.sharded_to_full_mapping_gpu is not None:
            # Reindex full logits tensor to ensure 1:1 mapping between
            # index and token_id.
            # Example for:
            #   org_vocab_size = 4
            #   added_vocab_size = 2
            #   pad_to_size = 8
            #   tp_size = 2
            # indices:  [0, 1, 2,  3, 4, 5, 6,  7]
            # token_id: [0, 1, 4, -1, 2, 3, 5, -1]
            # Therefore, the mapping is expected to be
            # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex, we get:
            # indices:  [0, 1, 2, 3, 4, 5,  6,  7]
            # token_id: [0, 1, 2, 3, 4, 5, -1, -1]
            logits = logits[:, self.sharded_to_full_mapping_gpu]

        lora_logits = torch.empty(
            self.embeddings_tensors.shape[0] + 1,
            self.embeddings_tensors.shape[1],
            hidden_states.shape[0],
            dtype=self.embeddings_tensors.dtype,
            device=self.embeddings_tensors.device,
        )
        torch.matmul(self.embeddings_tensors,
                     hidden_states.T,
                     out=lora_logits[:-1])
        lora_logits[-1] = float("-inf")
        lora_logits = lora_logits.mT
        lora_logits = (lora_logits.reshape(
            lora_logits.shape[0] * lora_logits.shape[1],
            lora_logits.shape[2],
        ).index_select(0,
                       self.indices_padded[:self.indices_len[2]]).nan_to_num_(
                           nan=float("-inf"),
                           posinf=float("inf"),
                           neginf=float("-inf")))
        logits[:,
               self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
               lora_logits.shape[1]] = lora_logits

        _apply_lora(
            hidden_states,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[1]],
            logits,
        )

        # Remove paddings in vocab (if any).
        logits = logits[:, :self.base_layer.vocab_size]
        return logits

    def forward(self, *args, **kwargs):
        return type(self.base_layer).forward(self, *args, **kwargs)

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # Special handling for the LogitsProcessor.
        return False
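
# NOTE: roughly, _get_logits above first gathers the sharded logits, then
# writes each adapter's extra-vocabulary logits (from embeddings_tensors) into
# columns [org_vocab_size:org_vocab_size + lora_extra_vocab_size), and finally
# adds the low-rank lm_head correction via _apply_lora; the extra -inf row in
# lora_logits keeps the added vocabulary of inactive adapters from being
# sampled.
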
class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
    """Implements RoPE-scaled embeddings with linear scaling for
    multiple LoRA adapters with a specialized kernel.

    Replaces LinearScalingRotaryEmbedding with MultiLinearScalingRotaryEmbedding,
    which can handle multiple LoRA adapters in a specialized kernel.
    """

    def __init__(self, base_layer: RotaryEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        # Lazily initialized.
        self.long_lora_indices: torch.Tensor
        self.indices_len: List[int]

    @property
    def scaling_factors(self):
        return self.base_layer.scaling_factors

    @property
    def rotary_dim(self):
        return self.base_layer.rotary_dim

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        scaling_factors = list(
            lora_config.long_lora_scaling_factors
        ) if lora_config.long_lora_scaling_factors else []
        base_scaling_factor = (self.base_layer.scaling_factor if isinstance(
            self.base_layer, LinearScalingRotaryEmbedding) else 1.0)
        scaling_factors = sorted(
            list(set([base_scaling_factor] + scaling_factors)))
        self.base_layer = LinearScalingRotaryEmbedding(
            self.base_layer.head_size,
            self.base_layer.rotary_dim,
            self.base_layer.max_position_embeddings,
            self.base_layer.base,
            self.base_layer.is_neox_style,
            scaling_factors,
            self.base_layer.dtype,
        )

    def reset_lora(self, index: int):
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        ...

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        long_lora_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.long_lora_indices = long_lora_indices
        self.indices_len = indices_len

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.base_layer(
            positions,
            query,
            key,
            offsets=self.long_lora_indices[:self.indices_len[4]])

    @property
    def scaling_factor_to_offset(self) -> Dict[float, int]:
        return self.base_layer.scaling_factor_to_offset

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        return type(source_layer) is LinearScalingRotaryEmbedding or type(
            source_layer) is RotaryEmbedding

    def extra_repr(self) -> str:
        return self.base_layer.extra_repr()