# pylint: disable=unused-argument
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

from aphrodite.adapter_commons.layers import AdapterMapping
from aphrodite.common.config import LoRAConfig
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   split_tensor_along_last_dim,
                                   tensor_model_parallel_all_gather,
                                   tensor_model_parallel_all_reduce,
                                   tensor_model_parallel_gather)
from aphrodite.distributed.utils import divide
from aphrodite.lora.punica import add_lora, add_lora_slice, bgmv
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import (
    LinearScalingRotaryEmbedding, RotaryEmbedding)
from aphrodite.modeling.layers.vocab_parallel_embedding import \
    VocabParallelEmbedding

if TYPE_CHECKING:
    pass


def _get_lora_device(base_layer: nn.Module) -> torch.device:
    # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
    """Returns the device for where to place the LoRA tensors."""
    # UnquantizedLinear
    if hasattr(base_layer, "weight"):
        return base_layer.weight.device
    # GPTQ/AWQ/SqueezeLLM
    elif hasattr(base_layer, "qweight"):
        return base_layer.qweight.device
    # marlin
    elif hasattr(base_layer, "B"):
        return base_layer.B.device
    else:
        raise ValueError(f"Unsupported base layer: {base_layer}")


def _not_fully_sharded_can_replace(can_replace):
    """
    Decorator which adds the condition of not using fully sharded LoRAs;
    intended to wrap can_replace_layer().
    """

    def dec(*args, **kwargs):
        decorate = kwargs.pop('decorate') if 'decorate' in kwargs else True
        condition = (not kwargs['lora_config'].fully_sharded_loras
                     if decorate else True)
        return can_replace(*args, **kwargs) and condition

    return dec


def _apply_lora(
    x: torch.Tensor,
    lora_a_stacked: torch.Tensor,
    lora_b_stacked: torch.Tensor,
    indices: torch.Tensor,
    output: torch.Tensor,
):
    """Applies lora to each input.

    This method applies all loras to each input. It uses the
    indices vector to determine which lora yields the
    correct output. An index of -1 means no lora should be
    applied. This method adds the final lora results to the
    output.

    Input shapes:
        x:               (batch_size, hidden_dim)
        lora_a_stacked:  (num_loras, lora_rank, hidden_dim)
        lora_b_stacked:  (num_loras, output_dim, lora_rank)
        indices:         (batch_size)
        output:          (batch_size, output_dim)
    """
    org_output = output
    x = x.view(-1, x.shape[-1])
    output = output.view(-1, output.shape[-1])
    indices = indices.view(-1)
    add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0)
    return output.view_as(org_output)
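

# Editor's sketch (hypothetical helper, not used by the engine): a plain
# PyTorch reference for what the fused punica `add_lora` call above computes,
# assuming the 3D shapes documented in `_apply_lora`.
def _reference_apply_lora_sketch(x: torch.Tensor,
                                 lora_a_stacked: torch.Tensor,
                                 lora_b_stacked: torch.Tensor,
                                 indices: torch.Tensor,
                                 output: torch.Tensor) -> torch.Tensor:
    for row, lora_idx in enumerate(indices.tolist()):
        if lora_idx < 0:
            # An index of -1 means no LoRA is applied to this token.
            continue
        a = lora_a_stacked[lora_idx]  # (lora_rank, hidden_dim)
        b = lora_b_stacked[lora_idx]  # (output_dim, lora_rank)
        # Accumulate the low-rank delta: output += B @ (A @ x).
        output[row] += b @ (a @ x[row])
    return output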


def _apply_lora_packed_nslice(
    x: torch.Tensor,
    lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    indices: torch.Tensor,
    output: torch.Tensor,
    output_slices: Tuple[int, ...],
):
    """Applies lora to each input.

    This method applies all loras to each input. It uses the
    indices vector to determine which lora yields the
    correct output. An index of -1 means no lora should be
    applied. This method adds the final lora results to the
    output.

    This method is used for layers that are composed of multiple sublayers
    (slices) packed together.

    Input shapes:
        x:               (batch_size, hidden_dim)
        lora_a_stacked:  3 element tuple of (num_loras, lora_rank, hidden_dim)
        lora_b_stacked:  3 element tuple of (num_loras, output_dim, lora_rank)
        indices:         (batch_size)
        output:          (batch_size, q_slice_size + 2*kv_slice_size)
        output_slices:   n element tuple of (slice_size...),
                         where n is the number of slices
    """
    org_output = output
    x = x.view(-1, x.shape[-1])
    output = output.view(-1, output.shape[-1])
    indices = indices.view(-1)
    offset_left = 0
    for slice_idx in range(len(output_slices)):
        add_lora_slice(output, x, lora_a_stacked[slice_idx],
                       lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left,
                       output_slices[slice_idx])
        offset_left += output_slices[slice_idx]
    return output.view_as(org_output)
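

# Editor's note on `output_slices` above: for a merged QKV projection the
# packed output is written as
#     q -> output[:, 0:q_size]
#     k -> output[:, q_size:q_size + kv_size]
#     v -> output[:, q_size + kv_size:q_size + 2 * kv_size]
# with output_slices == (q_size, kv_size, kv_size); each slice has its own
# (lora_a, lora_b) pair in the stacked tuples, and `offset_left` walks the
# column offsets in that order.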


@dataclass
class LoRAMapping(AdapterMapping):
    pass


class BaseLayerWithLoRA(nn.Module):

    def slice_lora_a(
        self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
        """Slice lora a if splitting for tensor parallelism."""
        ...

    def slice_lora_b(
        self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
        """Slice lora b if splitting with tensor parallelism."""
        ...

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        """Initializes lora matrices."""
        ...

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Overwrites lora tensors at index."""
        ...

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        long_lora_indices: torch.Tensor,
        indices_len: List[int],
    ):
        """Sets the mapping indices."""
        ...

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        raise NotImplementedError
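

# Editor's sketch of the expected life cycle of a BaseLayerWithLoRA subclass
# (hypothetical driver code; the real orchestration lives in the LoRA model
# manager, and the index tensors below are assumed to be precomputed):
#
#     layer.create_lora_weights(max_loras=8, lora_config=lora_config)
#     layer.set_lora(0, lora_a, lora_b, embeddings_tensor=None)
#     layer.set_mapping(base_indices, sampler_indices, sampler_indices_padded,
#                       embeddings_indices, long_lora_indices, indices_len)
#     out = layer(x)       # linear wrappers return (output, output_bias)
#     layer.reset_lora(0)  # zero slot 0 so it can be reused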


class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.embeddings_slice: Optional[Tuple[int, int]]
        self.embeddings_weights: Optional[torch.Tensor]

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:

        if self.base_layer.num_added_embeddings_per_partition > 0:
            # We can start adding lora weights
            self.embeddings_weights = self.base_layer.weight.data[
                self.base_layer.num_org_embeddings_per_partition:self.
                base_layer.num_org_embeddings_per_partition +
                self.base_layer.num_added_embeddings_per_partition]
            self.embeddings_slice = (
                self.base_layer.shard_indices.added_vocab_start_index -
                self.base_layer.org_vocab_size,
                self.base_layer.shard_indices.added_vocab_end_index -
                self.base_layer.org_vocab_size)
            self.base_layer.weight.data[
                self.base_layer.num_org_embeddings_per_partition:].fill_(0)
        else:
            self.embeddings_slice = None
            self.embeddings_weights = None

        self.embeddings_tensors = torch.zeros(
            (
                max_loras,
                lora_config.lora_extra_vocab_size,
                self.base_layer.embedding_dim,
            ),
            dtype=self.base_layer.weight.dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                self.base_layer.org_vocab_size +
                lora_config.lora_extra_vocab_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.embedding_dim,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked_2d = self.lora_a_stacked.view(
            self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
            self.lora_a_stacked.shape[2],
        )
        # Lazily initialized.
        self.indices: torch.Tensor
        self.indices_len: List[int]
        self.embeddings_indices: torch.Tensor

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
            lora_a, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1]].copy_(embeddings_tensor, non_blocking=True)
            if self.embeddings_slice is not None:
                # TODO(yard1): Optimize this copy, we don't need to copy
                # everything, just the modified part
                embeddings = self.embeddings_tensors.view(
                    self.embeddings_tensors.shape[0] *
                    self.embeddings_tensors.shape[1],
                    self.embeddings_tensors.shape[2]
                )[self.embeddings_slice[0]:self.embeddings_slice[1]]
                assert self.embeddings_weights is not None
                self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        long_lora_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.embeddings_indices = embeddings_indices
        self.indices_len = indices_len

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        added_tokens_mask = x > self.base_layer.org_vocab_size - 1
        embedding_len = self.indices_len[3]
        indices = self.embeddings_indices[1][:embedding_len].view_as(x)
        full_lora_a_embeddings = F.embedding(
            x + indices,
            self.lora_a_stacked_2d,
        )
        indices = self.embeddings_indices[0][:embedding_len].view_as(x)
        full_output = self.base_layer.forward(
            x.add_(indices * added_tokens_mask))

        full_output_org = full_output
        if full_output.ndim == 3:
            full_output = full_output.view(
                full_output.shape[0] * full_output.shape[1], -1)
        if full_lora_a_embeddings.ndim == 3:
            full_lora_a_embeddings = full_lora_a_embeddings.view(
                full_lora_a_embeddings.shape[0] *
                full_lora_a_embeddings.shape[1], -1)
        bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        return full_output.view_as(full_output_org)
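
    # Editor's note on forward() above (hedged reading of the index tensors):
    # embeddings_indices[1] appears to hold a per-token offset of
    # lora_index * (org_vocab_size + lora_extra_vocab_size), so `x + indices`
    # selects the calling adapter's block in the flattened
    # `lora_a_stacked_2d` table, while embeddings_indices[0] shifts only the
    # added-vocab token ids so the base embedding lookup lands on that
    # adapter's rows of the appended vocabulary region.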

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is VocabParallelEmbedding


class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
    """
    LoRA on top of ColumnParallelLinear layer.

    LoRA B is sliced for tensor parallelism.
    """

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size = self.base_layer.input_size
        self.output_size = self.base_layer.output_size_per_partition
        self.device = _get_lora_device(self.base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.lora_config = lora_config
        self.tp_size = get_tensor_model_parallel_world_size()
        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))
        self.lora_a_stacked = torch.zeros(
            max_loras,
            1,
            lora_a_output_size_per_partition,
            self.input_size,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            max_loras,
            1,
            self.output_size,
            lora_config.max_lora_rank,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.output_dim = self.lora_b_stacked.shape[2]
        # lazily initialized.
        self.indices: torch.Tensor
        self.indices_len: List[int]

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        shard_size = self.output_dim
        start_idx = tensor_model_parallel_rank * shard_size
        end_idx = (tensor_model_parallel_rank + 1) * shard_size
        lora_b = lora_b[:, start_idx:end_idx]
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        long_lora_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.indices_len = indices_len

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        _apply_lora(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
        )
        return output

    def forward(self, input_):
        """Forward of ColumnParallelLinear

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)

        # Matrix multiply.
        output_parallel = self.apply(input_, bias)
        if self.base_layer.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is ColumnParallelLinear or (
            type(source_layer) is MergedColumnParallelLinear
            and len(packed_modules_list) == 1)
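

# Editor's sketch (hypothetical, not used by the engine): how a full LoRA B
# matrix, laid out as (lora_rank, output_size), would be reduced to the shard
# kept by one tensor-parallel rank, mirroring slice_lora_b() above.
def _example_column_parallel_lora_b_shard(lora_b: torch.Tensor, tp_rank: int,
                                          tp_size: int) -> torch.Tensor:
    shard_size = lora_b.shape[1] // tp_size
    # Each rank keeps a contiguous block of output columns; LoRA A is kept
    # whole because the input is replicated for column-parallel layers.
    return lora_b[:, tp_rank * shard_size:(tp_rank + 1) * shard_size]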


class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
    packed together (e.g. gate_proj + up_proj -> gate_up_proj).

    This means we have 2 LoRAs, each applied to one half of the layer.

    Both slices must have the same size.
    """

    def __init__(self, base_layer: MergedColumnParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.lora_config = lora_config
        n_slices = 2
        if not (len(self.base_layer.output_sizes) == n_slices
                and self.base_layer.output_sizes[0]
                == self.base_layer.output_sizes[1]):
            raise ValueError(
                "LoRAColumnParallelLinear2Slice requires 2 slices with "
                "the same size.")
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                self.output_size // 2,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(n_slices))

        self.output_dim = self.lora_b_stacked[0].shape[2]
        # Lazily initialized.
        self.indices: torch.Tensor

    def reset_lora(self, index: int):
        self.lora_a_stacked[0][index] = 0
        self.lora_a_stacked[1][index] = 0
        self.lora_b_stacked[0][index] = 0
        self.lora_b_stacked[1][index] = 0

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        return lora_a

    def slice_lora_b(
        self, lora_b: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        if lora_b[0] is None or lora_b[1] is None:
            return lora_b
        shard_size = self.output_dim
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        lora_b = [
            lora_b[0][:, start_idx:end_idx], lora_b[1][:, start_idx:end_idx]
        ]
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)

        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        if lora_a[0] is not None:
            self.lora_a_stacked[0][
                index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
                    lora_a[0].T, non_blocking=True)
            self.lora_b_stacked[0][
                index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
                    lora_b[0].T, non_blocking=True)
        if lora_a[1] is not None:
            self.lora_a_stacked[1][
                index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
                    lora_a[1].T, non_blocking=True)
            self.lora_b_stacked[1][
                index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
                    lora_b[1].T, non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        _apply_lora_packed_nslice(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
            (self.output_dim, self.output_dim),
        )
        return output

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is MergedColumnParallelLinear and len(
            packed_modules_list) == 2


class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """
    ColumnParallelLinear layer that is specifically designed for
    qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
    only contain a single LoRA within their qkv_proj layer.

    During inference with Tensor Parallel, the weights of lora_b
    must be accurately partitioned according to the respective ranks.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        self.tp_size = get_tensor_model_parallel_world_size()
        self.q_proj_total_size = (self.base_layer.total_num_heads *
                                  self.base_layer.head_size)
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
                                   self.base_layer.head_size)

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        tp_rank = get_tensor_model_parallel_rank()
        self.q_shard_id = tp_rank
        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
        lora_b_q = lora_b[:, self.q_proj_shard_size *
                          self.q_shard_id:self.q_proj_shard_size *
                          (self.q_shard_id + 1)]
        k_offset = self.q_proj_total_size
        lora_b_k = lora_b[:, k_offset +
                          self.kv_proj_shard_size * self.kv_shard_id:k_offset +
                          self.kv_proj_shard_size * (self.kv_shard_id + 1)]
        v_offset = k_offset + self.kv_proj_total_size
        lora_b_v = lora_b[:, v_offset +
                          self.kv_proj_shard_size * self.kv_shard_id:v_offset +
                          self.kv_proj_shard_size * (self.kv_shard_id + 1)]
        lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is QKVParallelLinear and len(
            packed_modules_list) == 1
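

# Editor's note (illustrative numbers only): with total_num_heads=32,
# total_num_kv_heads=8, head_size=128 and tp_size=4, each rank keeps a Q block
# of 32 // 4 * 128 = 1024 columns and K/V blocks of 8 // 4 * 128 = 256 columns,
# so slice_lora_b() above concatenates those three per-rank column blocks taken
# from the packed (q | k | v) layout of lora_b.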


class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 3 sublayers (slices)
    packed together in qkv proj fashion
    (q_proj + k_proj + v_proj -> qkv_proj).

    This means we have 3 LoRAs, each applied to one slice of the layer.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.lora_config = lora_config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.q_shard_id = self.tp_rank
        self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas

        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))
        # q, k, v
        self.lora_a_stacked = (
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )
        self.lora_b_stacked = (
            torch.zeros(
                max_loras,
                1,
                self.q_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
            torch.zeros(
                max_loras,
                1,
                self.kv_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
            torch.zeros(
                max_loras,
                1,
                self.kv_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )

        self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size,
                              self.kv_proj_shard_size)
        self.packed_indices: Optional[torch.Tensor] = None
        self.standard_indices: Optional[torch.Tensor] = None
        # lazily initialized.
        self.indices_len: List[int]

    def reset_lora(self, index: int):
        self.lora_a_stacked[0][index] = 0
        self.lora_b_stacked[0][index] = 0
        self.lora_a_stacked[1][index] = 0
        self.lora_b_stacked[1][index] = 0
        self.lora_a_stacked[2][index] = 0
        self.lora_b_stacked[2][index] = 0

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        return lora_a

    def slice_lora_b(
        self, lora_b: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        lora_b_q, lora_b_k, lora_b_v = None, None, None
        if lora_b[0] is not None:
            lora_b_q = lora_b[0][:, self.q_proj_shard_size *
                                 self.q_shard_id:self.q_proj_shard_size *
                                 (self.q_shard_id + 1)]
        if lora_b[1] is not None:
            lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
                                 self.kv_shard_id:self.kv_proj_shard_size *
                                 (self.kv_shard_id + 1)]
        if lora_b[2] is not None:
            lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
                                 self.kv_shard_id:self.kv_proj_shard_size *
                                 (self.kv_shard_id + 1)]
        lora_b = [lora_b_q, lora_b_k, lora_b_v]
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        if lora_b[0] is not None:
            lora_b_q = lora_b[0]
            self.lora_b_stacked[0][
                index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
                    lora_b_q.T, non_blocking=True)
        if lora_b[1] is not None:
            lora_b_k = lora_b[1]
            self.lora_b_stacked[1][
                index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
                    lora_b_k.T, non_blocking=True)
        if lora_b[2] is not None:
            lora_b_v = lora_b[2]
            self.lora_b_stacked[2][
                index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
                    lora_b_v.T, non_blocking=True)

        if lora_a[0] is not None:
            self.lora_a_stacked[0][
                index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
                    lora_a[0].T, non_blocking=True)
        if lora_a[1] is not None:
            self.lora_a_stacked[1][
                index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
                    lora_a[1].T, non_blocking=True)
        if lora_a[2] is not None:
            self.lora_a_stacked[2][
                index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
                    lora_a[2].T, non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        _apply_lora_packed_nslice(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
            self.output_slices,
        )
        return output

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is QKVParallelLinear and len(
            packed_modules_list) == 3


class RowParallelLinearWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: RowParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.input_size = self.base_layer.input_size_per_partition
        self.output_size = self.base_layer.output_size
        self.device = _get_lora_device(self.base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.lora_config = lora_config
        self.tp_rank = get_tensor_model_parallel_rank()
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.input_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        tp_size = get_tensor_model_parallel_world_size()
        lora_b_output_size_per_partition = (
            self.output_size if not lora_config.fully_sharded_loras else
            divide(self.output_size, tp_size))

        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_b_output_size_per_partition,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        # Lazily initialized
        self.indices: torch.Tensor
        self.indices_len: List[int]

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        shard_size = self.input_size
        start_idx = tensor_model_parallel_rank * shard_size
        end_idx = (tensor_model_parallel_rank + 1) * shard_size
        lora_a = lora_a[start_idx:end_idx, :]
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)

        if self.base_layer.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        long_lora_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.indices_len = indices_len

    def apply(self, x: torch.Tensor) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x)
        _apply_lora(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
        )
        return output

    def forward(self, input_):
        """Forward of RowParallelLinear

        Args:
            input_: tensor whose last dimension is `input_size`. If
                    `input_is_parallel` is set, then the last dimension
                    is `input_size // tp_size`.

        Returns:
            - output
            - bias
        """
        # Set up backprop all-reduce.
        if self.base_layer.input_is_parallel:
            input_parallel = input_
        else:
            # TODO: simplify code below
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.base_layer.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        output_parallel = self.apply(input_parallel)
        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.base_layer.skip_bias_add:
            output = (output_ + self.base_layer.bias
                      if self.base_layer.bias is not None else output_)
            output_bias = None
        else:
            output = output_
            output_bias = self.base_layer.bias
        return output, output_bias

    @property
    def weight(self):
        return self.base_layer.weight if hasattr(
            self.base_layer, "weight") else self.base_layer.qweight

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is RowParallelLinear
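

# Editor's sketch (hypothetical, not used by the engine): for a row-parallel
# layer the activations arrive already sharded, so it is LoRA A, laid out as
# (input_size, lora_rank), that is split along its input dimension, mirroring
# slice_lora_a() above; LoRA B stays replicated unless fully sharded LoRA is
# enabled.
def _example_row_parallel_lora_a_shard(lora_a: torch.Tensor, tp_rank: int,
                                       tp_size: int) -> torch.Tensor:
    shard_size = lora_a.shape[0] // tp_size
    return lora_a[tp_rank * shard_size:(tp_rank + 1) * shard_size, :]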


class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
    """
    LoRA wrapper for LogitsProcessor, with extra logic to handle the
    application of the LoRA adapter and added LoRA vocabulary.

    Args:
        base_layer: LogitsProcessor layer
        hidden_size: hidden size of the model
        dtype: data type of the model
        device: device of the model
        sharded_to_full_mapping: index mapping from sharded vocab to full vocab
            received from base_layer.get_sharded_to_full_mapping(). If None,
            no reindexing will be done.
    """

    def __init__(self, base_layer: LogitsProcessor, hidden_size: int,
                 dtype: torch.dtype, device: torch.device,
                 sharded_to_full_mapping: Optional[List[int]]) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.sharded_to_full_mapping = sharded_to_full_mapping

    @property
    def logits_as_input(self):
        return self.base_layer.logits_as_input

    @property
    def vocab_size(self):
        return self.base_layer.vocab_size

    @property
    def scale(self):
        return self.base_layer.scale

    @property
    def soft_cap(self):
        return self.base_layer.soft_cap

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        return self.base_layer.include_gpu_probs_tensor

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
        if not (32000 <= self.base_layer.vocab_size <= 128512):
            raise ValueError("When using LoRA, vocab size must be "
                             "32000 <= vocab_size <= 128512")
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                # Pad for kernel compatibility
                math.ceil(self.base_layer.vocab_size /
                          lora_config.lora_vocab_padding_size) *
                lora_config.lora_vocab_padding_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.embeddings_tensors = torch.full(
            (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
            fill_value=float("-inf"),
            dtype=self.dtype,
            device=self.device,
        )
        if self.sharded_to_full_mapping is not None:
            self.sharded_to_full_mapping_gpu = torch.tensor(
                self.sharded_to_full_mapping,
                device=self.device,
                dtype=torch.long)
        else:
            self.sharded_to_full_mapping_gpu = None
        # Lazily initialized.
        self.indices: torch.Tensor
        self.indices_len: List[int]
        self.indices_padded: torch.Tensor

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = float("-inf")

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1], ] = embeddings_tensor

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        long_lora_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = sampler_indices
        self.indices_padded = sampler_indices_padded
        self.indices_len = indices_len

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: VocabParallelEmbedding,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        # Get the logits for the next tokens.
        logits = lm_head.linear_method.apply(lm_head, hidden_states)
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        if logits is None:
            return None

        if self.sharded_to_full_mapping_gpu is not None:
            # Reindex full logits tensor to ensure 1:1 mapping between
            # index and token_id
            # Example for:
            #   org_vocab_size = 4
            #   added_vocab_size = 2
            #   pad_to_size = 8
            #   tp_size = 2
            #
            # indices:  [0, 1, 2,  3, 4, 5, 6,  7]
            # token_id: [0, 1, 4, -1, 2, 3, 5, -1]
            #
            # Therefore, the mapping is expected to be:
            # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex,
            # we get:
            # indices:  [0, 1, 2, 3, 4, 5,  6,  7]
            # token_id: [0, 1, 2, 3, 4, 5, -1, -1]
            logits = logits[:, self.sharded_to_full_mapping_gpu]

        lora_logits = torch.empty(
            self.embeddings_tensors.shape[0] + 1,
            self.embeddings_tensors.shape[1],
            hidden_states.shape[0],
            dtype=self.embeddings_tensors.dtype,
            device=self.embeddings_tensors.device,
        )
        torch.matmul(self.embeddings_tensors,
                     hidden_states.T,
                     out=lora_logits[:-1])
        lora_logits[-1] = float("-inf")
        lora_logits = lora_logits.mT
        lora_logits = (lora_logits.reshape(
            lora_logits.shape[0] * lora_logits.shape[1],
            lora_logits.shape[2],
        ).index_select(0,
                       self.indices_padded[:self.indices_len[2]]).nan_to_num_(
                           nan=float("-inf"),
                           posinf=float("inf"),
                           neginf=float("-inf")))
        logits[:,
               self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
               lora_logits.shape[1]] = lora_logits

        _apply_lora(
            hidden_states,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[1]],
            logits,
        )

        # Remove paddings in vocab (if any).
        logits = logits[:, :self.base_layer.vocab_size]
        return logits

    def forward(self, *args, **kwargs):
        return type(self.base_layer).forward(self, *args, **kwargs)

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # Special handling for the LogitsProcessor.
        return False
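

# Editor's note on _get_logits above (hedged reading): lora_logits is built
# with one extra, all -inf row at index max_loras, and sampler_indices_padded
# can point unused slots at that row, so requests without a matching adapter
# pick up -inf logits for the added vocabulary instead of another adapter's
# embeddings; the nan_to_num_ call then normalizes any NaNs produced along the
# way back to -inf before the sampler sees them.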


class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
    """Implements RoPE-scaled embeddings with linear scaling for
    multiple LoRA adapters with a specialized kernel.

    Replaces LinearScalingRotaryEmbedding with a LinearScalingRotaryEmbedding
    built over multiple scaling factors, which can handle multiple LoRA
    adapters in a specialized kernel.
    """

    def __init__(self, base_layer: RotaryEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        # Lazily initialized
        self.long_lora_indices: torch.Tensor
        self.indices_len: List[int]

    @property
    def scaling_factors(self):
        return self.base_layer.scaling_factors

    @property
    def rotary_dim(self):
        return self.base_layer.rotary_dim

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        scaling_factors = list(
            lora_config.long_lora_scaling_factors
        ) if lora_config.long_lora_scaling_factors else []
        base_scaling_factor = (self.base_layer.scaling_factor if isinstance(
            self.base_layer, LinearScalingRotaryEmbedding) else 1.0)
        scaling_factors = sorted(
            list(set([base_scaling_factor] + scaling_factors)))
        self.base_layer = LinearScalingRotaryEmbedding(
            self.base_layer.head_size,
            self.base_layer.rotary_dim,
            self.base_layer.max_position_embeddings,
            self.base_layer.base,
            self.base_layer.is_neox_style,
            scaling_factors,
            self.base_layer.dtype,
        )

    def reset_lora(self, index: int):
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        ...

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        long_lora_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.long_lora_indices = long_lora_indices
        self.indices_len = indices_len

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.base_layer(
            positions,
            query,
            key,
            offsets=self.long_lora_indices[:self.indices_len[4]])

    @property
    def scaling_factor_to_offset(self) -> Dict[float, int]:
        return self.base_layer.scaling_factor_to_offset

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        return (type(source_layer) is LinearScalingRotaryEmbedding
                or type(source_layer) is RotaryEmbedding)

    def extra_repr(self) -> str:
        return self.base_layer.extra_repr()
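

# Editor's note (hedged): forward() above passes
# long_lora_indices[:indices_len[4]] as per-token `offsets`, which the
# multi-scaling-factor LinearScalingRotaryEmbedding appears to use to shift
# each token into the cos/sin cache region built for that request's scaling
# factor (see scaling_factor_to_offset).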