# pylint: disable=unused-argument
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

from aphrodite.common.config import LoRAConfig
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   split_tensor_along_last_dim,
                                   tensor_model_parallel_all_gather,
                                   tensor_model_parallel_all_reduce,
                                   tensor_model_parallel_gather)
from aphrodite.distributed.utils import divide
from aphrodite.lora.punica import add_lora, add_lora_slice, bgmv
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import (
    LinearScalingRotaryEmbedding, RotaryEmbedding)
from aphrodite.modeling.layers.vocab_parallel_embedding import \
    VocabParallelEmbedding

if TYPE_CHECKING:
    pass


def _get_lora_device(base_layer: nn.Module) -> torch.device:
    # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
    """Returns the device on which the LoRA tensors should be placed."""
    # UnquantizedLinear
    if hasattr(base_layer, "weight"):
        return base_layer.weight.device
    # GPTQ/AWQ/SqueezeLLM
    elif hasattr(base_layer, "qweight"):
        return base_layer.qweight.device
    # Marlin
    elif hasattr(base_layer, "B"):
        return base_layer.B.device
    else:
        raise ValueError(f"Unsupported base layer: {base_layer}")


def _not_fully_sharded_can_replace(can_replace):
    """
    Decorator that adds the condition that fully sharded LoRAs are not in use;
    intended to wrap can_replace_layer().
    """

    def dec(*args, **kwargs):
        decorate = kwargs.pop("decorate", True)
        condition = (not kwargs["lora_config"].fully_sharded_loras
                     if decorate else True)
        return can_replace(*args, **kwargs) and condition

    return dec


def _apply_lora(
    x: torch.Tensor,
    lora_a_stacked: torch.Tensor,
    lora_b_stacked: torch.Tensor,
    indices: torch.Tensor,
    output: torch.Tensor,
):
    """Applies lora to each input.

    This method applies all loras to each input. It uses the
    indices vector to determine which lora yields the
    correct output. An index of -1 means no lora should be
    applied. This method adds the final lora results to the
    output.

    Input shapes:
        x:              (batch_size, hidden_dim)
        lora_a_stacked: (num_loras, lora_rank, hidden_dim)
        lora_b_stacked: (num_loras, output_dim, lora_rank)
        indices:        (batch_size)
        output:         (batch_size, output_dim)
    """
    org_output = output
    x = x.view(-1, x.shape[-1])
    output = output.view(-1, output.shape[-1])
    indices = indices.view(-1)
    add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0)
    return output.view_as(org_output)
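
# For _apply_lora, a concrete (hypothetical) example: with batch_size=3,
# hidden_dim=4096, lora_rank=16 and output_dim=4096, x is (3, 4096),
# lora_a_stacked[i] is (16, 4096) and lora_b_stacked[i] is (4096, 16), so row j
# of output receives x[j] @ A[idx].T @ B[idx].T where idx = indices[j], and
# nothing is added when indices[j] == -1.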


def _apply_lora_packed_nslice(
    x: torch.Tensor,
    lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    indices: torch.Tensor,
    output: torch.Tensor,
    output_slices: Tuple[int, ...],
):
    """Applies lora to each input.

    This method applies all loras to each input. It uses the
    indices vector to determine which lora yields the
    correct output. An index of -1 means no lora should be
    applied. This method adds the final lora results to the
    output.

    This method is used for layers that are composed of multiple sublayers
    (slices) packed together.

    Input shapes:
        x:              (batch_size, hidden_dim)
        lora_a_stacked: 3 element tuple of (num_loras, lora_rank, hidden_dim)
        lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank)
        indices:        (batch_size)
        output:         (batch_size, q_slice_size + 2*kv_slice_size)
        output_slices:  n element tuple of (slice_size...),
                        where n is the number of slices
    """
    org_output = output
    x = x.view(-1, x.shape[-1])
    output = output.view(-1, output.shape[-1])
    indices = indices.view(-1)
    offset_left = 0
    for slice_idx in range(len(output_slices)):
        add_lora_slice(output, x, lora_a_stacked[slice_idx],
                       lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left,
                       output_slices[slice_idx])
        offset_left += output_slices[slice_idx]
    return output.view_as(org_output)
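
# For _apply_lora_packed_nslice, e.g. in MergedQKVParallelLinearWithLora below,
# output_slices is (q_proj_shard_size, kv_proj_shard_size, kv_proj_shard_size):
# each slice's LoRA delta is written into its own column range of the packed
# output, starting at the running offset_left.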


@dataclass
class LoRAMapping:
    # Per every token in input_ids:
    index_mapping: Tuple[int, ...]
    # Per sampled token:
    prompt_mapping: Tuple[int, ...]

    def __post_init__(self):
        self.index_mapping = tuple(self.index_mapping)
        self.prompt_mapping = tuple(self.prompt_mapping)


class BaseLayerWithLoRA(nn.Module):

    def slice_lora_a(
        self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
        """Slice lora a if splitting for tensor parallelism."""
        ...

    def slice_lora_b(
        self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
        """Slice lora b if splitting with tensor parallelism."""
        ...

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        """Initializes lora matrices."""
        ...

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Overwrites lora tensors at index."""
        ...

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        long_lora_indices: torch.Tensor,
        indices_len: List[int],
    ):
        """Sets the mapping indices."""
        ...
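        # Note (inferred from how these slots are consumed in this file):
        # indices_len holds the effective lengths of the index tensors above,
        # i.e. [0] base_indices, [1] sampler_indices,
        # [2] sampler_indices_padded, [3] embeddings_indices,
        # [4] long_lora_indices.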

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        raise NotImplementedError


class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.embeddings_slice: Optional[Tuple[int, int]]
        self.embeddings_weights: Optional[torch.Tensor]

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        if self.base_layer.num_added_embeddings_per_partition > 0:
            # We can start adding lora weights
            self.embeddings_weights = self.base_layer.weight.data[
                self.base_layer.num_org_embeddings_per_partition:self.
                base_layer.num_org_embeddings_per_partition +
                self.base_layer.num_added_embeddings_per_partition]
            self.embeddings_slice = (
                self.base_layer.shard_indices.added_vocab_start_index -
                self.base_layer.org_vocab_size,
                self.base_layer.shard_indices.added_vocab_end_index -
                self.base_layer.org_vocab_size)
            self.base_layer.weight.data[
                self.base_layer.num_org_embeddings_per_partition:].fill_(0)
        else:
            self.embeddings_slice = None
            self.embeddings_weights = None

        self.embeddings_tensors = torch.zeros(
            (
                max_loras,
                lora_config.lora_extra_vocab_size,
                self.base_layer.embedding_dim,
            ),
            dtype=self.base_layer.weight.dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                self.base_layer.org_vocab_size +
                lora_config.lora_extra_vocab_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.embedding_dim,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked_2d = self.lora_a_stacked.view(
            self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
            self.lora_a_stacked.shape[2],
        )
        # Lazily initialized.
        self.indices: torch.Tensor
        self.indices_len: List[int]
        self.embeddings_indices: torch.Tensor

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
            lora_a, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1]].copy_(embeddings_tensor, non_blocking=True)
            if self.embeddings_slice is not None:
                # TODO(yard1): Optimize this copy, we don't need to copy
                # everything, just the modified part
                embeddings = self.embeddings_tensors.view(
                    self.embeddings_tensors.shape[0] *
                    self.embeddings_tensors.shape[1],
                    self.embeddings_tensors.shape[2]
                )[self.embeddings_slice[0]:self.embeddings_slice[1]]
                assert self.embeddings_weights is not None
                self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        long_lora_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.embeddings_indices = embeddings_indices
        self.indices_len = indices_len

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        added_tokens_mask = x > self.base_layer.org_vocab_size - 1
        embedding_len = self.indices_len[3]
        indices = self.embeddings_indices[1][:embedding_len].view_as(x)
        full_lora_a_embeddings = F.embedding(
            x + indices,
            self.lora_a_stacked_2d,
        )
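        # lora_a_stacked_2d is the (max_loras * vocab, rank) view of
        # lora_a_stacked, so the per-token offsets from embeddings_indices[1]
        # (added to x above) pick out the row block belonging to each token's
        # LoRA during the lookup.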
        indices = self.embeddings_indices[0][:embedding_len].view_as(x)
        full_output = self.base_layer.forward(
            x.add_(indices * added_tokens_mask))
        full_output_org = full_output
        if full_output.ndim == 3:
            full_output = full_output.view(
                full_output.shape[0] * full_output.shape[1], -1)
        if full_lora_a_embeddings.ndim == 3:
            full_lora_a_embeddings = full_lora_a_embeddings.view(
                full_lora_a_embeddings.shape[0] *
                full_lora_a_embeddings.shape[1], -1)
        bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        return full_output.view_as(full_output_org)

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is VocabParallelEmbedding


class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
    """
    LoRA on top of ColumnParallelLinear layer.

    LoRA B is sliced for tensor parallelism.
    """

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size = self.base_layer.input_size
        self.output_size = self.base_layer.output_size_per_partition
        self.device = _get_lora_device(self.base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.lora_config = lora_config
        self.tp_size = get_tensor_model_parallel_world_size()
        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))
        self.lora_a_stacked = torch.zeros(
            max_loras,
            1,
            lora_a_output_size_per_partition,
            self.input_size,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            max_loras,
            1,
            self.output_size,
            lora_config.max_lora_rank,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.output_dim = self.lora_b_stacked.shape[2]
        # Lazily initialized.
        self.indices: torch.Tensor
        self.indices_len: List[int]

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        shard_size = self.output_dim
        start_idx = tensor_model_parallel_rank * shard_size
        end_idx = (tensor_model_parallel_rank + 1) * shard_size
        lora_b = lora_b[:, start_idx:end_idx]
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        long_lora_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.indices_len = indices_len

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        _apply_lora(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
        )
        return output
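
    # apply() computes the base projection through the layer's quant_method
    # and then adds the per-row LoRA delta (x @ A.T @ B.T for each row's
    # active adapter) in place via _apply_lora.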

    def forward(self, input_):
        """Forward of ColumnParallelLinear

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)

        # Matrix multiply.
        output_parallel = self.apply(input_, bias)
        if self.base_layer.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is ColumnParallelLinear or (
            type(source_layer) is MergedColumnParallelLinear
            and len(packed_modules_list) == 1)


class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
    packed together (e.g. gate_proj + up_proj -> gate_up_proj).

    This means we have 2 LoRAs, each applied to one half of the layer.
    Both slices must have the same size.
    """

    def __init__(self, base_layer: MergedColumnParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.lora_config = lora_config
        n_slices = 2
        if not (len(self.base_layer.output_sizes) == n_slices
                and self.base_layer.output_sizes[0]
                == self.base_layer.output_sizes[1]):
            raise ValueError(
                "LoRAColumnParallelLinear2Slice requires 2 slices with "
                "the same size.")
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                self.output_size // 2,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(n_slices))
        self.output_dim = self.lora_b_stacked[0].shape[2]
        # Lazily initialized.
        self.indices: torch.Tensor

    def reset_lora(self, index: int):
        self.lora_a_stacked[0][index] = 0
        self.lora_a_stacked[1][index] = 0
        self.lora_b_stacked[0][index] = 0
        self.lora_b_stacked[1][index] = 0

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        return lora_a

    def slice_lora_b(
        self, lora_b: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        if lora_b[0] is None or lora_b[1] is None:
            return lora_b
        shard_size = self.output_dim
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        lora_b = [
            lora_b[0][:, start_idx:end_idx], lora_b[1][:, start_idx:end_idx]
        ]
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
        if lora_a[0] is not None:
            self.lora_a_stacked[0][
                index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
                    lora_a[0].T, non_blocking=True)
            self.lora_b_stacked[0][
                index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
                    lora_b[0].T, non_blocking=True)
        if lora_a[1] is not None:
            self.lora_a_stacked[1][
                index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
                    lora_a[1].T, non_blocking=True)
            self.lora_b_stacked[1][
                index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
                    lora_b[1].T, non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        _apply_lora_packed_nslice(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
            (self.output_dim, self.output_dim),
        )
        return output

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is MergedColumnParallelLinear and len(
            packed_modules_list) == 2


class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """
    ColumnParallelLinear layer that is specifically designed for
    qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
    only contain a single LoRA within their qkv_proj layer.

    During inference with Tensor Parallel, the weights of lora_b
    must be accurately partitioned according to the respective ranks.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        self.tp_size = get_tensor_model_parallel_world_size()
        self.q_proj_total_size = (self.base_layer.total_num_heads *
                                  self.base_layer.head_size)
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
                                   self.base_layer.head_size)

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        tp_rank = get_tensor_model_parallel_rank()
        self.q_shard_id = tp_rank
        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
        lora_b_q = lora_b[:, self.q_proj_shard_size *
                          self.q_shard_id:self.q_proj_shard_size *
                          (self.q_shard_id + 1)]
        k_offset = self.q_proj_total_size
        lora_b_k = lora_b[:, k_offset +
                          self.kv_proj_shard_size * self.kv_shard_id:k_offset +
                          self.kv_proj_shard_size * (self.kv_shard_id + 1)]
        v_offset = k_offset + self.kv_proj_total_size
        lora_b_v = lora_b[:, v_offset +
                          self.kv_proj_shard_size * self.kv_shard_id:v_offset +
                          self.kv_proj_shard_size * (self.kv_shard_id + 1)]
        lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
        return lora_b
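
    # slice_lora_b assumes the single lora_b packs the q, k and v output
    # columns contiguously (q | k | v); each tensor-parallel rank carves out
    # its q shard and its kv shard (shared across num_kv_head_replicas ranks)
    # and re-concatenates them in the same order.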

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is QKVParallelLinear and len(
            packed_modules_list) == 1


class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 3 sublayers (slices)
    packed together in qkv proj fashion
    (q_proj + k_proj + v_proj -> qkv_proj).

    This means we have 3 LoRAs, each applied to one slice of the layer.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.lora_config = lora_config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.q_shard_id = self.tp_rank
        self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))
        # q, k, v
        self.lora_a_stacked = (
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )
        self.lora_b_stacked = (
            torch.zeros(
                max_loras,
                1,
                self.q_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
            torch.zeros(
                max_loras,
                1,
                self.kv_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
            torch.zeros(
                max_loras,
                1,
                self.kv_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )
        self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size,
                              self.kv_proj_shard_size)
        self.packed_indices: Optional[torch.Tensor] = None
        self.standard_indices: Optional[torch.Tensor] = None
        # Lazily initialized.
        self.indices_len: List[int]

    def reset_lora(self, index: int):
        self.lora_a_stacked[0][index] = 0
        self.lora_b_stacked[0][index] = 0
        self.lora_a_stacked[1][index] = 0
        self.lora_b_stacked[1][index] = 0
        self.lora_a_stacked[2][index] = 0
        self.lora_b_stacked[2][index] = 0

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        return lora_a

    def slice_lora_b(
        self, lora_b: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        lora_b_q, lora_b_k, lora_b_v = None, None, None
        if lora_b[0] is not None:
            lora_b_q = lora_b[0][:, self.q_proj_shard_size *
                                 self.q_shard_id:self.q_proj_shard_size *
                                 (self.q_shard_id + 1)]
        if lora_b[1] is not None:
            lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
                                 self.kv_shard_id:self.kv_proj_shard_size *
                                 (self.kv_shard_id + 1)]
        if lora_b[2] is not None:
            lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
                                 self.kv_shard_id:self.kv_proj_shard_size *
                                 (self.kv_shard_id + 1)]
        lora_b = [lora_b_q, lora_b_k, lora_b_v]
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
        if lora_b[0] is not None:
            lora_b_q = lora_b[0]
            self.lora_b_stacked[0][
                index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
                    lora_b_q.T, non_blocking=True)
        if lora_b[1] is not None:
            lora_b_k = lora_b[1]
            self.lora_b_stacked[1][
                index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
                    lora_b_k.T, non_blocking=True)
        if lora_b[2] is not None:
            lora_b_v = lora_b[2]
            self.lora_b_stacked[2][
                index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
                    lora_b_v.T, non_blocking=True)
        if lora_a[0] is not None:
            self.lora_a_stacked[0][
                index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
                    lora_a[0].T, non_blocking=True)
        if lora_a[1] is not None:
            self.lora_a_stacked[1][
                index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
                    lora_a[1].T, non_blocking=True)
        if lora_a[2] is not None:
            self.lora_a_stacked[2][
                index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
                    lora_a[2].T, non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        _apply_lora_packed_nslice(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
            self.output_slices,
        )
        return output

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is QKVParallelLinear and len(
            packed_modules_list) == 3


class RowParallelLinearWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: RowParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.input_size = self.base_layer.input_size_per_partition
        self.output_size = self.base_layer.output_size
        self.device = _get_lora_device(self.base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.lora_config = lora_config
        self.tp_rank = get_tensor_model_parallel_rank()
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.input_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        tp_size = get_tensor_model_parallel_world_size()
        lora_b_output_size_per_partition = (
            self.output_size if not lora_config.fully_sharded_loras else
            divide(self.output_size, tp_size))
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_b_output_size_per_partition,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        # Lazily initialized.
        self.indices: torch.Tensor
        self.indices_len: List[int]

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        shard_size = self.input_size
        start_idx = tensor_model_parallel_rank * shard_size
        end_idx = (tensor_model_parallel_rank + 1) * shard_size
        lora_a = lora_a[start_idx:end_idx, :]
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.base_layer.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        long_lora_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.indices_len = indices_len

    def apply(self, x: torch.Tensor) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x)
        _apply_lora(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
        )
        return output

    def forward(self, input_):
        """Forward of RowParallelLinear

        Args:
            input_: tensor whose last dimension is `input_size`. If
                    `input_is_parallel` is set, then the last dimension
                    is `input_size // tp_size`.

        Returns:
            - output
            - bias
        """
        # Set up backprop all-reduce.
        if self.base_layer.input_is_parallel:
            input_parallel = input_
        else:
            # TODO: simplify code below
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.base_layer.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        output_parallel = self.apply(input_parallel)
        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.base_layer.skip_bias_add:
            output = (output_ + self.base_layer.bias
                      if self.base_layer.bias is not None else output_)
            output_bias = None
        else:
            output = output_
            output_bias = self.base_layer.bias
        return output, output_bias

    @property
    def weight(self):
        return self.base_layer.weight if hasattr(
            self.base_layer, "weight") else self.base_layer.qweight

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is RowParallelLinear


class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
    """
    LoRA wrapper for LogitsProcessor, with extra logic to handle the
    application of the LoRA adapter and added LoRA vocabulary.

    Args:
        base_layer: LogitsProcessor layer
        hidden_size: hidden size of the model
        dtype: data type of the model
        device: device of the model
        sharded_to_full_mapping: index mapping from sharded vocab to full vocab
            received from base_layer.get_sharded_to_full_mapping(). If None,
            no reindexing will be done.
    """

    def __init__(self, base_layer: LogitsProcessor, hidden_size: int,
                 dtype: torch.dtype, device: torch.device,
                 sharded_to_full_mapping: Optional[List[int]]) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.sharded_to_full_mapping = sharded_to_full_mapping

    @property
    def logits_as_input(self):
        return self.base_layer.logits_as_input

    @property
    def vocab_size(self):
        return self.base_layer.vocab_size

    @property
    def scale(self):
        return self.base_layer.scale

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        return self.base_layer.include_gpu_probs_tensor

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
        if not (32000 <= self.base_layer.vocab_size <= 128512):
            raise ValueError("When using LoRA, vocab size must be "
                             "32000 <= vocab_size <= 128512")
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                # Pad for kernel compatibility
                math.ceil(self.base_layer.vocab_size /
                          lora_config.lora_vocab_padding_size) *
                lora_config.lora_vocab_padding_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.embeddings_tensors = torch.full(
            (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
            fill_value=float("-inf"),
            dtype=self.dtype,
            device=self.device,
        )
        if self.sharded_to_full_mapping is not None:
            self.sharded_to_full_mapping_gpu = torch.tensor(
                self.sharded_to_full_mapping,
                device=self.device,
                dtype=torch.long)
        else:
            self.sharded_to_full_mapping_gpu = None
        # Lazily initialized.
        self.indices: torch.Tensor
        self.indices_len: List[int]
        self.indices_padded: torch.Tensor

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = float("-inf")

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1]] = embeddings_tensor

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        long_lora_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = sampler_indices
        self.indices_padded = sampler_indices_padded
        self.indices_len = indices_len

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: VocabParallelEmbedding,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        # Get the logits for the next tokens.
        logits = lm_head.linear_method.apply(lm_head, hidden_states)
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        if logits is None:
            return None

        if self.sharded_to_full_mapping_gpu is not None:
            # Reindex full logits tensor to ensure 1:1 mapping between
            # index and token_id.
            # Example for:
            #   org_vocab_size = 4
            #   added_vocab_size = 2
            #   pad_to_size = 8
            #   tp_size = 2
            #
            # indices:  [0, 1, 2,  3, 4, 5, 6,  7]
            # token_id: [0, 1, 4, -1, 2, 3, 5, -1]
            #
            # Therefore, the mapping is expected to be
            # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex, we get:
            # indices:  [0, 1, 2, 3, 4, 5,  6,  7]
            # token_id: [0, 1, 2, 3, 4, 5, -1, -1]
            logits = logits[:, self.sharded_to_full_mapping_gpu]

        lora_logits = torch.empty(
            self.embeddings_tensors.shape[0] + 1,
            self.embeddings_tensors.shape[1],
            hidden_states.shape[0],
            dtype=self.embeddings_tensors.dtype,
            device=self.embeddings_tensors.device,
        )
        torch.matmul(self.embeddings_tensors,
                     hidden_states.T,
                     out=lora_logits[:-1])
        lora_logits[-1] = float("-inf")
        lora_logits = lora_logits.mT
        lora_logits = (lora_logits.reshape(
            lora_logits.shape[0] * lora_logits.shape[1],
            lora_logits.shape[2],
        ).index_select(0,
                       self.indices_padded[:self.indices_len[2]]).nan_to_num_(
                           nan=float("-inf"),
                           posinf=float("inf"),
                           neginf=float("-inf")))
        logits[:,
               self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
               lora_logits.shape[1]] = lora_logits

        _apply_lora(
            hidden_states,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[1]],
            logits,
        )

        # Remove paddings in vocab (if any).
        logits = logits[:, :self.base_layer.vocab_size]
        return logits
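
    # _get_logits: compute the sharded logits, gather them to the full vocab,
    # reindex via sharded_to_full_mapping (when provided), splice in logits
    # for the LoRA-added vocabulary computed from embeddings_tensors, add the
    # LoRA delta on the hidden states, then trim any vocab padding.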

    def forward(self, *args, **kwargs):
        return type(self.base_layer).forward(self, *args, **kwargs)

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # Special handling for the LogitsProcessor.
        return False


class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
    """Implements RoPE-scaled embeddings with linear scaling for
    multiple LoRA adapters with a specialized kernel.

    Replace LinearScalingRotaryEmbedding with MultiLinearScalingRotaryEmbedding
    which can handle multiple LoRA adapters in a specialized kernel.
    """

    def __init__(self, base_layer: RotaryEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        # Lazily initialized.
        self.long_lora_indices: torch.Tensor
        self.indices_len: List[int]

    @property
    def scaling_factors(self):
        return self.base_layer.scaling_factors

    @property
    def rotary_dim(self):
        return self.base_layer.rotary_dim

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        scaling_factors = list(
            lora_config.long_lora_scaling_factors
        ) if lora_config.long_lora_scaling_factors else []
        base_scaling_factor = (self.base_layer.scaling_factor if isinstance(
            self.base_layer, LinearScalingRotaryEmbedding) else 1.0)
        scaling_factors = sorted(
            list(set([base_scaling_factor] + scaling_factors)))
        self.base_layer = LinearScalingRotaryEmbedding(
            self.base_layer.head_size,
            self.base_layer.rotary_dim,
            self.base_layer.max_position_embeddings,
            self.base_layer.base,
            self.base_layer.is_neox_style,
            scaling_factors,
            self.base_layer.dtype,
        )

    def reset_lora(self, index: int):
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        ...

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        long_lora_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.long_lora_indices = long_lora_indices
        self.indices_len = indices_len

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.base_layer(
            positions,
            query,
            key,
            offsets=self.long_lora_indices[:self.indices_len[4]])
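
    # The offsets passed to the base layer come from long_lora_indices;
    # together with scaling_factor_to_offset (below) they appear to shift each
    # position into the region of the rotary cos/sin cache built for that
    # sequence's scaling factor.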

    @property
    def scaling_factor_to_offset(self) -> Dict[float, int]:
        return self.base_layer.scaling_factor_to_offset

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        return type(source_layer) is LinearScalingRotaryEmbedding or type(
            source_layer) is RotaryEmbedding

    def extra_repr(self) -> str:
        return self.base_layer.extra_repr()