# layers.py

# pylint: disable=unused-argument
import inspect
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Type

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

from aphrodite.common.config import LoRAConfig
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   split_tensor_along_last_dim,
                                   tensor_model_parallel_all_gather,
                                   tensor_model_parallel_all_reduce,
                                   tensor_model_parallel_gather)
from aphrodite.lora.punica import add_lora, add_lora_slice, bgmv
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)

if TYPE_CHECKING:
    pass


def _get_lora_device(base_layer: nn.Module) -> torch.device:
    """Identify the device for positioning the LoRA tensors."""
    device = None
    try:
        device = base_layer.weight.device
    except AttributeError:
        try:
            linear_weights = base_layer.linear_weights
            if isinstance(linear_weights, dict):
                tensor_values = [
                    v for v in linear_weights.values()
                    if isinstance(v, torch.Tensor)
                ]
                if tensor_values:
                    device = tensor_values[0].device
        except AttributeError:
            pass
    if device is None:
        raise ValueError(f"Base layer not supported: {base_layer}")
    return device


def _apply_lora(
    x: torch.Tensor,
    lora_a_stacked: torch.Tensor,
    lora_b_stacked: torch.Tensor,
    indices: torch.Tensor,
    output: torch.Tensor,
):
    """Applies lora to each input.

    This method applies all loras to each input. It uses the
    indices vector to determine which lora yields the
    correct output. An index of -1 means no lora should be
    applied. This method adds the final lora results to the
    output.

    Input shapes:
        x:              (batch_size, hidden_dim)
        lora_a_stacked: (num_loras, lora_rank, hidden_dim)
        lora_b_stacked: (num_loras, output_dim, lora_rank)
        indices:        (batch_size)
        output:         (batch_size, output_dim)
    """
    org_output = output
    x = x.view(-1, x.shape[-1])
    output = output.view(-1, output.shape[-1])
    indices = indices.view(-1)
    add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0)
    return output.view_as(org_output)
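

# Editorial sketch (not part of the original file, not called anywhere in this
# module): a pure-PyTorch reference for what the fused `add_lora` kernel
# computes, following the shape contract documented in `_apply_lora` above.
# The real kernel fuses both matmuls and the per-token dispatch on the GPU.
def _reference_apply_lora(x: torch.Tensor, lora_a_stacked: torch.Tensor,
                          lora_b_stacked: torch.Tensor,
                          indices: torch.Tensor,
                          output: torch.Tensor) -> torch.Tensor:
    for i, lora_idx in enumerate(indices.tolist()):
        if lora_idx == -1:
            # -1 means this token gets no LoRA delta.
            continue
        a = lora_a_stacked[lora_idx]  # (lora_rank, hidden_dim)
        b = lora_b_stacked[lora_idx]  # (output_dim, lora_rank)
        output[i] += b @ (a @ x[i])
    return output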


def _apply_lora_packed_nslice(
    x: torch.Tensor,
    lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    indices: torch.Tensor,
    output: torch.Tensor,
    output_slices: Tuple[int, ...],
):
    """Applies lora to each input.

    This method applies all loras to each input. It uses the
    indices vector to determine which lora yields the
    correct output. An index of -1 means no lora should be
    applied. This method adds the final lora results to the
    output.

    This method is used for layers that are composed of multiple sublayers
    (slices) packed together.

    Input shapes:
        x:              (batch_size, hidden_dim)
        lora_a_stacked: 3 element tuple of (num_loras, lora_rank, hidden_dim)
        lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank)
        indices:        (batch_size)
        output:         (batch_size, q_slice_size + 2*kv_slice_size)
        output_slices:  n element tuple of slice sizes, where n is the
                        number of slices
    """
    org_output = output
    x = x.view(-1, x.shape[-1])
    output = output.view(-1, output.shape[-1])
    indices = indices.view(-1)
    offset_left = 0
    for slice_idx in range(len(output_slices)):
        add_lora_slice(output, x, lora_a_stacked[slice_idx],
                       lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left,
                       output_slices[slice_idx])
        offset_left += output_slices[slice_idx]
    return output.view_as(org_output)
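

# Editorial sketch (not part of the original file, not called anywhere in this
# module): shows how `offset_left` walks across the packed output columns.
# The shard sizes below (q=1024, kv=256) are hypothetical.
def _example_packed_slice_offsets():
    output_slices = (1024, 256, 256)
    offsets = []
    offset_left = 0
    for size in output_slices:
        # Each slice's LoRA delta lands in output[:, offset:offset + size].
        offsets.append((offset_left, offset_left + size))
        offset_left += size
    return offsets  # [(0, 1024), (1024, 1280), (1280, 1536)]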


@dataclass
class LoRAMapping:
    # Per every token in input_ids:
    index_mapping: Tuple[int, ...]
    # Per sampled token:
    prompt_mapping: Tuple[int, ...]

    def __post_init__(self):
        self.index_mapping = tuple(self.index_mapping)
        self.prompt_mapping = tuple(self.prompt_mapping)
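

# Editorial sketch (not part of the original file, not called anywhere): builds
# a mapping for a hypothetical batch of two prompts (3 tokens + 2 tokens).
# The concrete id values are placeholders; their interpretation is defined by
# the LoRA manager that consumes this mapping.
def _example_lora_mapping() -> LoRAMapping:
    return LoRAMapping(index_mapping=(1, 1, 1, 0, 0),
                       prompt_mapping=(1, 0))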


class BaseLayerWithLoRA(nn.Module):

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        """Initializes lora matrices."""
        ...

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Overwrites lora tensors at index."""
        ...

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        """Sets the mapping indices."""
        ...

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        raise NotImplementedError


class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:

        lora_vocab_start_idx = self.base_layer.org_vocab_size
        weights_idx = None
        if self.base_layer.vocab_end_index > lora_vocab_start_idx:
            # We can start adding lora weights
            weights_idx = max(
                lora_vocab_start_idx - self.base_layer.vocab_start_index, 0)
            self.embeddings_slice = (self.base_layer.vocab_start_index -
                                     self.base_layer.org_vocab_size +
                                     weights_idx,
                                     self.base_layer.vocab_end_index -
                                     self.base_layer.org_vocab_size)
            self.embeddings_weights = self.base_layer.weight.data[weights_idx:]
            self.embeddings_weights.fill_(0)
        else:
            self.embeddings_slice = None
            self.embeddings_weights = None

        self.embeddings_tensors = torch.zeros(
            (
                max_loras,
                lora_config.lora_extra_vocab_size,
                self.base_layer.embedding_dim,
            ),
            dtype=self.base_layer.weight.dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                self.base_layer.org_vocab_size +
                lora_config.lora_extra_vocab_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.embedding_dim,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked_2d = self.lora_a_stacked.view(
            self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
            self.lora_a_stacked.shape[2],
        )
        self.indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None
        self.embeddings_indices = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
            lora_a, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1]].copy_(embeddings_tensor, non_blocking=True)
            if self.embeddings_slice is not None:
                # TODO(yard1): Optimize this copy, we don't need to copy
                # everything, just the modified part
                embeddings = self.embeddings_tensors.view(
                    self.embeddings_tensors.shape[0] *
                    self.embeddings_tensors.shape[1],
                    self.embeddings_tensors.shape[2]
                )[self.embeddings_slice[0]:self.embeddings_slice[1]]
                self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.embeddings_indices = embeddings_indices
        self.indices_len = indices_len

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        added_tokens_mask = x > self.base_layer.org_vocab_size - 1
        embedding_len = self.indices_len[3]
        indices = self.embeddings_indices[1][:embedding_len].view_as(x)
        full_lora_a_embeddings = F.embedding(
            x + indices,
            self.lora_a_stacked_2d,
        )
        indices = self.embeddings_indices[0][:embedding_len].view_as(x)
        full_output = self.base_layer.forward(
            x.add_(indices * added_tokens_mask))

        full_output_org = full_output
        if full_output.ndim == 3:
            full_output = full_output.view(
                full_output.shape[0] * full_output.shape[1], -1)
        if full_lora_a_embeddings.ndim == 3:
            full_lora_a_embeddings = full_lora_a_embeddings.view(
                full_lora_a_embeddings.shape[0] *
                full_lora_a_embeddings.shape[1], -1)
        bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        return full_output.view_as(full_output_org)

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is VocabParallelEmbedding


class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.tp_size = get_tensor_model_parallel_world_size()

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.lora_a_stacked = torch.zeros(
            max_loras,
            1,
            lora_config.max_lora_rank,
            self.base_layer.weight.shape[1],
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            max_loras,
            1,
            self.base_layer.weight.shape[0],
            lora_config.max_lora_rank,
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )

        self.indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None
        self.output_dim = self.lora_b_stacked.shape[2]

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)

        if self.tp_size > 1:
            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
            shard_size = self.output_dim
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            lora_b = lora_b[:, start_idx:end_idx]

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.indices_len = indices_len

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        _apply_lora(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
        )
        return output

    def forward(self, input_):
        """Forward of ColumnParallelLinear

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)

        # Matrix multiply.
        output_parallel = self.apply(input_, bias)
        if self.base_layer.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @property
    def linear_weights(self):
        return self.base_layer.linear_weights

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is ColumnParallelLinear or (
            type(source_layer) is MergedColumnParallelLinear
            and len(packed_modules_list) == 1)


class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
    packed together (e.g. gate_proj + up_proj -> gate_up_proj).

    This means we have 2 LoRAs, each applied to one half of the layer.

    Both slices must have the same size.
    """

    def __init__(self, base_layer: MergedColumnParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        n_slices = 2
        if not (len(self.base_layer.output_sizes) == n_slices
                and self.base_layer.output_sizes[0]
                == self.base_layer.output_sizes[1]):
            raise ValueError(
                "LoRAColumnParallelLinear2Slice requires 2 slices with "
                "the same size.")
        self.tp_size = get_tensor_model_parallel_world_size()
        device = _get_lora_device(self.base_layer)

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.input_size,
                dtype=lora_config.lora_dtype,
                device=device,
            ) for _ in range(n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                self.base_layer.output_size // 2,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=device,
            ) for _ in range(n_slices))

        self.indices: Optional[torch.Tensor] = None
        self.output_dim = self.lora_b_stacked[0].shape[2]

    def reset_lora(self, index: int):
        self.lora_a_stacked[0][index] = 0
        self.lora_a_stacked[1][index] = 0
        self.lora_b_stacked[0][index] = 0
        self.lora_b_stacked[1][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)

        if self.tp_size > 1:
            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
            shard_size = self.output_dim
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            lora_b = (lora_b[0][:, start_idx:end_idx],
                      lora_b[1][:, start_idx:end_idx])

        if lora_a[0] is not None:
            self.lora_a_stacked[0][
                index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
                    lora_a[0].T, non_blocking=True)
            self.lora_b_stacked[0][
                index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
                    lora_b[0].T, non_blocking=True)
        if lora_a[1] is not None:
            self.lora_a_stacked[1][
                index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
                    lora_a[1].T, non_blocking=True)
            self.lora_b_stacked[1][
                index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
                    lora_b[1].T, non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        _apply_lora_packed_nslice(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
            (self.output_dim, self.output_dim),
        )
        return output

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is MergedColumnParallelLinear and len(
            packed_modules_list) == 2
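

# Editorial sketch (not part of the original file, not called anywhere): with a
# hypothetical per-slice width `output_dim`, slice 0 (gate_proj) occupies the
# packed output columns [0, output_dim) and slice 1 (up_proj) occupies
# [output_dim, 2 * output_dim); under tensor parallelism each rank copies only
# its own `output_dim`-wide shard of every slice's lora_b, mirroring the
# slicing done in `MergedColumnParallelLinearWithLoRA.set_lora` above.
def _example_merged_column_slices(
        output_dim: int = 5504) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    gate_cols = (0, output_dim)
    up_cols = (output_dim, 2 * output_dim)
    return gate_cols, up_cols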


class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """
    ColumnParallelLinear layer that is specifically designed for
    qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
    only contain a single LoRA within their qkv_proj layer.

    During inference with Tensor Parallel, the weights of lora_b
    must be accurately partitioned according to the respective ranks.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        self.tp_size = get_tensor_model_parallel_world_size()
        self.q_proj_total_size = (self.base_layer.total_num_heads *
                                  self.base_layer.head_size)
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
                                   self.base_layer.head_size)

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.tp_size > 1:
            tp_rank = get_tensor_model_parallel_rank()
            self.q_shard_id = tp_rank
            self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
            lora_b_q = lora_b[:, self.q_proj_shard_size *
                              self.q_shard_id:self.q_proj_shard_size *
                              (self.q_shard_id + 1)]
            k_offset = self.q_proj_total_size
            lora_b_k = lora_b[:, k_offset + self.kv_proj_shard_size *
                              self.kv_shard_id:k_offset +
                              self.kv_proj_shard_size * (self.kv_shard_id + 1)]
            v_offset = k_offset + self.kv_proj_total_size
            lora_b_v = lora_b[:, v_offset + self.kv_proj_shard_size *
                              self.kv_shard_id:v_offset +
                              self.kv_proj_shard_size * (self.kv_shard_id + 1)]
            lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is QKVParallelLinear and len(
            packed_modules_list) == 1
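

# Editorial sketch (not part of the original file, not called anywhere): the
# column offsets used when slicing a single packed qkv lora_b under tensor
# parallelism, mirroring the arithmetic in `QKVParallelLinearWithLora.set_lora`
# above. All default sizes and shard ids below are hypothetical.
def _example_qkv_lora_b_offsets(q_proj_total_size: int = 4096,
                                kv_proj_total_size: int = 1024,
                                q_proj_shard_size: int = 2048,
                                kv_proj_shard_size: int = 512,
                                q_shard_id: int = 0,
                                kv_shard_id: int = 0):
    q_cols = (q_proj_shard_size * q_shard_id,
              q_proj_shard_size * (q_shard_id + 1))
    k_offset = q_proj_total_size
    k_cols = (k_offset + kv_proj_shard_size * kv_shard_id,
              k_offset + kv_proj_shard_size * (kv_shard_id + 1))
    v_offset = k_offset + kv_proj_total_size
    v_cols = (v_offset + kv_proj_shard_size * kv_shard_id,
              v_offset + kv_proj_shard_size * (kv_shard_id + 1))
    return q_cols, k_cols, v_cols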


class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 3 sublayers (slices)
    packed together in qkv proj fashion
    (q_proj + k_proj + v_proj -> qkv_proj).

    This means we have 3 LoRAs, each applied to one slice of the layer.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.q_shard_id = tp_rank
        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
        device = _get_lora_device(self.base_layer)

        # q, k, v
        self.lora_a_stacked = (
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.input_size,
                dtype=lora_config.lora_dtype,
                device=device,
            ),
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.input_size,
                dtype=lora_config.lora_dtype,
                device=device,
            ),
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.input_size,
                dtype=lora_config.lora_dtype,
                device=device,
            ),
        )
        self.lora_b_stacked = (
            torch.zeros(
                max_loras,
                1,
                self.q_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=device,
            ),
            torch.zeros(
                max_loras,
                1,
                self.kv_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=device,
            ),
            torch.zeros(
                max_loras,
                1,
                self.kv_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=device,
            ),
        )

        self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size,
                              self.kv_proj_shard_size)
        self.packed_indices: Optional[torch.Tensor] = None
        self.standard_indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[0][index] = 0
        self.lora_b_stacked[0][index] = 0
        self.lora_a_stacked[1][index] = 0
        self.lora_b_stacked[1][index] = 0
        self.lora_a_stacked[2][index] = 0
        self.lora_b_stacked[2][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)

        if self.tp_size > 1:
            if lora_b[0] is not None:
                lora_b_q = lora_b[0][:, self.q_proj_shard_size *
                                     self.q_shard_id:self.q_proj_shard_size *
                                     (self.q_shard_id + 1)]
                self.lora_b_stacked[0][
                    index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
                        lora_b_q.T, non_blocking=True)
            if lora_b[1] is not None:
                lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
                                     self.kv_shard_id:self.kv_proj_shard_size *
                                     (self.kv_shard_id + 1)]
                self.lora_b_stacked[1][
                    index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
                        lora_b_k.T, non_blocking=True)
            if lora_b[2] is not None:
                lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
                                     self.kv_shard_id:self.kv_proj_shard_size *
                                     (self.kv_shard_id + 1)]
                self.lora_b_stacked[2][
                    index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
                        lora_b_v.T, non_blocking=True)
        else:
            if lora_b[0] is not None:
                self.lora_b_stacked[0][
                    index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
                        lora_b[0].T, non_blocking=True)
            if lora_b[1] is not None:
                self.lora_b_stacked[1][
                    index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
                        lora_b[1].T, non_blocking=True)
            if lora_b[2] is not None:
                self.lora_b_stacked[2][
                    index, 0, :lora_b[2].shape[1], :lora_b[2].shape[0]].copy_(
                        lora_b[2].T, non_blocking=True)

        if lora_a[0] is not None:
            self.lora_a_stacked[0][
                index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
                    lora_a[0].T, non_blocking=True)
        if lora_a[1] is not None:
            self.lora_a_stacked[1][
                index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
                    lora_a[1].T, non_blocking=True)
        if lora_a[2] is not None:
            self.lora_a_stacked[2][
                index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
                    lora_a[2].T, non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        _apply_lora_packed_nslice(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
            self.output_slices,
        )
        return output

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is QKVParallelLinear and len(
            packed_modules_list) == 3


class RowParallelLinearWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: RowParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        device = _get_lora_device(self.base_layer)
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.input_size,
            ),
            dtype=lora_config.lora_dtype,
            device=device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.output_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=device,
        )
        self.indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)

        if self.base_layer.tp_size > 1:
            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
            shard_size = self.base_layer.weight.shape[1]
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            lora_a = lora_a[start_idx:end_idx, :]

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.indices_len = indices_len

    def apply(self, x: torch.Tensor) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x)
        _apply_lora(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
        )
        return output

    def forward(self, input_):
        """Forward of RowParallelLinear

        Args:
            input_: tensor whose last dimension is `input_size`. If
                    `input_is_parallel` is set, then the last dimension
                    is `input_size // tp_size`.

        Returns:
            - output
            - bias
        """
        # Set up backprop all-reduce.
        if self.base_layer.input_is_parallel:
            input_parallel = input_
        else:
            # TODO: simplify code below
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.base_layer.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        output_parallel = self.apply(input_parallel)
        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.base_layer.skip_bias_add:
            output = (output_ + self.base_layer.bias
                      if self.base_layer.bias is not None else output_)
            output_bias = None
        else:
            output = output_
            output_bias = self.base_layer.bias
        return output, output_bias

    @property
    def weight(self):
        return self.base_layer.weight

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is RowParallelLinear


class LogitsProcessorWithLoRA(BaseLayerWithLoRA):

    def __init__(
        self,
        base_layer: LogitsProcessor,
        hidden_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device

    @property
    def logits_as_input(self):
        return self.base_layer.logits_as_input

    @property
    def vocab_size(self):
        return self.base_layer.vocab_size

    @property
    def scale(self):
        return self.base_layer.scale

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        return self.base_layer.include_gpu_probs_tensor

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
        if not (32000 <= self.base_layer.vocab_size <= 128512):
            raise ValueError("When using LoRA, vocab size must be "
                             "32000 <= vocab_size <= 128512")
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                # Pad for kernel compatibility
                math.ceil(self.base_layer.vocab_size /
                          lora_config.lora_vocab_padding_size) *
                lora_config.lora_vocab_padding_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.embeddings_tensors = torch.full(
            (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
            fill_value=float("-inf"),
            dtype=self.dtype,
            device=self.device,
        )
        self.indices = None
        self.indices_padded = None
        self.indices_len = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = float("-inf")

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1]] = embeddings_tensor

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = sampler_indices
        self.indices_padded = sampler_indices_padded
        self.indices_len = indices_len

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: torch.Tensor,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        # Get the logits for the next tokens.
        logits = lm_head(hidden_states)
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        if logits is None:
            return None

        lora_logits = torch.empty(
            self.embeddings_tensors.shape[0] + 1,
            self.embeddings_tensors.shape[1],
            hidden_states.shape[0],
            dtype=self.embeddings_tensors.dtype,
            device=self.embeddings_tensors.device,
        )
        torch.matmul(self.embeddings_tensors,
                     hidden_states.T,
                     out=lora_logits[:-1])
        lora_logits[-1] = float("-inf")
        lora_logits = lora_logits.mT
        lora_logits = (lora_logits.reshape(
            lora_logits.shape[0] * lora_logits.shape[1],
            lora_logits.shape[2],
        ).index_select(0,
                       self.indices_padded[:self.indices_len[2]]).nan_to_num_(
                           nan=float("-inf"),
                           posinf=float("inf"),
                           neginf=float("-inf")))
        logits[:,
               self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
               lora_logits.shape[1]] = lora_logits

        _apply_lora(
            hidden_states,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[1]],
            logits,
        )

        # Remove paddings in vocab (if any).
        logits = logits[:, :self.base_layer.vocab_size]
        return logits

    def forward(self, *args, **kwargs):
        return type(self.base_layer).forward(self, *args, **kwargs)

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # Special handling for the LogitsProcessor.
        return False
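

# Editorial sketch (not part of the original file, not called anywhere): the
# column layout of a logits row produced by `_get_logits` above, with a
# hypothetical extra vocab size. Columns [0, org_vocab_size) hold the base
# lm_head logits, the next `lora_extra_vocab_size` columns hold the selected
# LoRA's extra-vocab logits, and anything beyond `vocab_size` is padding that
# gets stripped at the end.
def _example_logits_layout(org_vocab_size: int = 32000,
                           extra_vocab_size: int = 256):
    base_cols = (0, org_vocab_size)
    extra_cols = (org_vocab_size, org_vocab_size + extra_vocab_size)
    return base_cols, extra_cols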


_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
    cls
    for cls in globals().values() if inspect.isclass(cls)
    and issubclass(cls, BaseLayerWithLoRA) and cls is not BaseLayerWithLoRA
}


def from_layer(layer: nn.Module,
               max_loras: int,
               lora_config: LoRAConfig,
               packed_modules_list: List,
               model_config: Optional[PretrainedConfig] = None) -> nn.Module:
    for lora_cls in _all_lora_classes:
        if lora_cls.can_replace_layer(layer, lora_config, packed_modules_list,
                                      model_config):
            ret = lora_cls(layer)
            ret.create_lora_weights(max_loras, lora_config, model_config)
            return ret
    return layer
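

# Editorial usage sketch (hypothetical argument values, not part of the
# original file, not called anywhere): `from_layer` is the hook a model loader
# would call to wrap each eligible module. If no registered subclass reports
# `can_replace_layer(...) == True`, the original layer is returned unchanged.
def _example_wrap_layer(layer: nn.Module, max_loras: int,
                        lora_config: LoRAConfig) -> nn.Module:
    # "gate_proj"/"up_proj" is just an illustrative packed-module grouping.
    return from_layer(layer,
                      max_loras=max_loras,
                      lora_config=lora_config,
                      packed_modules_list=["gate_proj", "up_proj"])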


def from_layer_logits_processor(
    layer: LogitsProcessor,
    lm_head: ParallelLMHead,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA:
    ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
                                  lm_head.weight.dtype, lm_head.weight.device)
    ret.create_lora_weights(max_loras, lora_config, model_config)
    return ret