layers.py

# pylint: disable=unused-argument
import inspect
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Type

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

from aphrodite.common.config import LoRAConfig
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   split_tensor_along_last_dim,
                                   tensor_model_parallel_all_gather,
                                   tensor_model_parallel_all_reduce,
                                   tensor_model_parallel_gather)
from aphrodite.lora.punica import add_lora, add_lora_slice, bgmv
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)

if TYPE_CHECKING:
    pass


def _get_lora_device(base_layer: nn.Module) -> torch.device:
    """Identify the device for positioning the LoRA tensors."""
    device = None
    try:
        device = base_layer.weight.device
    except AttributeError:
        try:
            linear_weights = base_layer.linear_weights
            if isinstance(linear_weights, dict):
                tensor_values = [
                    v for v in linear_weights.values()
                    if isinstance(v, torch.Tensor)
                ]
                if tensor_values:
                    device = tensor_values[0].device
        except AttributeError:
            pass
    if device is None:
        raise ValueError(f"Base layer not supported: {base_layer}")
    return device


def _apply_lora(
    x: torch.Tensor,
    lora_a_stacked: torch.Tensor,
    lora_b_stacked: torch.Tensor,
    indices: torch.Tensor,
    output: torch.Tensor,
):
    """Applies lora to each input.

    This method applies all loras to each input. It uses the
    indices vector to determine which lora yields the
    correct output. An index of -1 means no lora should be
    applied. This method adds the final lora results to the
    output.

    Input shapes:
        x:              (batch_size, hidden_dim)
        lora_a_stacked: (num_loras, lora_rank, hidden_dim)
        lora_b_stacked: (num_loras, output_dim, lora_rank)
        indices:        (batch_size)
        output:         (batch_size, output_dim)
    """
    org_output = output
    x = x.view(-1, x.shape[-1])
    output = output.view(-1, output.shape[-1])
    indices = indices.view(-1)
    add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0)
    return output.view_as(org_output)


def _apply_lora_packed_nslice(
    x: torch.Tensor,
    lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    indices: torch.Tensor,
    output: torch.Tensor,
    output_slices: Tuple[int, ...],
):
    """Applies lora to each input.

    This method applies all loras to each input. It uses the
    indices vector to determine which lora yields the
    correct output. An index of -1 means no lora should be
    applied. This method adds the final lora results to the
    output.

    This method is used for layers that are composed of multiple sublayers
    (slices) packed together.

    Input shapes:
        x:              (batch_size, hidden_dim)
        lora_a_stacked: 3 element tuple of (num_loras, lora_rank, hidden_dim)
        lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank)
        indices:        (batch_size)
        output:         (batch_size, q_slice_size + 2*kv_slice_size)
        output_slices:  n element tuple of (slice_size...),
                        where n is the number of slices
    """
    org_output = output
    x = x.view(-1, x.shape[-1])
    output = output.view(-1, output.shape[-1])
    indices = indices.view(-1)
    offset_left = 0
    for slice_idx in range(len(output_slices)):
        add_lora_slice(output, x, lora_a_stacked[slice_idx],
                       lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left,
                       output_slices[slice_idx])
        offset_left += output_slices[slice_idx]
    return output.view_as(org_output)


@dataclass
class LoRAMapping:
    # One entry per token in input_ids:
    index_mapping: Tuple[int, ...]
    # One entry per sampled token:
    prompt_mapping: Tuple[int, ...]
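    # Example values (illustrative, not from the original source): for two
    # sequences of lengths 3 and 2 that use LoRA ids 1 and 0 respectively,
    # index_mapping could be (1, 1, 1, 0, 0) and prompt_mapping (1, 0).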

    def __post_init__(self):
        self.index_mapping = tuple(self.index_mapping)
        self.prompt_mapping = tuple(self.prompt_mapping)


class BaseLayerWithLoRA(nn.Module):

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        """Initializes lora matrices."""
        ...

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Overwrites lora tensors at index."""
        ...

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        """Sets the mapping indices."""
        ...

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        raise NotImplementedError
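

# Rough lifecycle of a LoRA-wrapped layer (summary comment added for
# readability): create_lora_weights() allocates the stacked A/B buffers once,
# set_lora() / reset_lora() populate or clear a single adapter slot, and
# set_mapping() installs the per-batch index tensors that the punica kernels
# consume at forward time.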


class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        lora_vocab_start_idx = self.base_layer.org_vocab_size
        weights_idx = None
        if self.base_layer.vocab_end_index > lora_vocab_start_idx:
            # We can start adding lora weights
            weights_idx = max(
                lora_vocab_start_idx - self.base_layer.vocab_start_index, 0)
            self.embeddings_slice = (self.base_layer.vocab_start_index -
                                     self.base_layer.org_vocab_size +
                                     weights_idx,
                                     self.base_layer.vocab_end_index -
                                     self.base_layer.org_vocab_size)
            self.embeddings_weights = self.base_layer.weight.data[weights_idx:]
            self.embeddings_weights.fill_(0)
        else:
            self.embeddings_slice = None
            self.embeddings_weights = None
        self.embeddings_tensors = torch.zeros(
            (
                max_loras,
                lora_config.lora_extra_vocab_size,
                self.base_layer.embedding_dim,
            ),
            dtype=self.base_layer.weight.dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                self.base_layer.org_vocab_size +
                lora_config.lora_extra_vocab_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.embedding_dim,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked_2d = self.lora_a_stacked.view(
            self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
            self.lora_a_stacked.shape[2],
        )
        self.indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None
        self.embeddings_indices = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
            lora_a, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1]].copy_(embeddings_tensor, non_blocking=True)
            if self.embeddings_slice is not None:
                # TODO(yard1): Optimize this copy, we don't need to copy
                # everything, just the modified part
                embeddings = self.embeddings_tensors.view(
                    self.embeddings_tensors.shape[0] *
                    self.embeddings_tensors.shape[1],
                    self.embeddings_tensors.shape[2],
                )[self.embeddings_slice[0]:self.embeddings_slice[1]]
                self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.embeddings_indices = embeddings_indices
        self.indices_len = indices_len

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        added_tokens_mask = x > self.base_layer.org_vocab_size - 1
        embedding_len = self.indices_len[3]
        indices = self.embeddings_indices[1][:embedding_len].view_as(x)
        full_lora_a_embeddings = F.embedding(
            x + indices,
            self.lora_a_stacked_2d,
        )
        indices = self.embeddings_indices[0][:embedding_len].view_as(x)
        full_output = self.base_layer.forward(
            x.add_(indices * added_tokens_mask))

        full_output_org = full_output
        if full_output.ndim == 3:
            full_output = full_output.view(
                full_output.shape[0] * full_output.shape[1], -1)
        if full_lora_a_embeddings.ndim == 3:
            full_lora_a_embeddings = full_lora_a_embeddings.view(
                full_lora_a_embeddings.shape[0] *
                full_lora_a_embeddings.shape[1], -1)
        bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        return full_output.view_as(full_output_org)
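
    # Note (comment added for readability): set_lora() copies each adapter's
    # extra-vocab rows into the tail of the base embedding weight via
    # embeddings_weights, while forward() uses embeddings_indices both to
    # look up the per-LoRA A matrix in lora_a_stacked_2d and to remap added
    # token ids onto those tail rows before bgmv() adds the B-side
    # contribution on top of the base embedding output.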

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is VocabParallelEmbedding


class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.tp_size = get_tensor_model_parallel_world_size()

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.lora_a_stacked = torch.zeros(
            max_loras,
            1,
            lora_config.max_lora_rank,
            self.base_layer.weight.shape[1],
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            max_loras,
            1,
            self.base_layer.weight.shape[0],
            lora_config.max_lora_rank,
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None
        self.output_dim = self.lora_b_stacked.shape[2]

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.tp_size > 1:
            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
            shard_size = self.output_dim
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            lora_b = lora_b[:, start_idx:end_idx]
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.indices_len = indices_len

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer, x, bias)
        _apply_lora(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
        )
        return output

    def forward(self, input_):
        """Forward of ColumnParallelLinear

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)

        # Matrix multiply.
        output_parallel = self.apply_weights(input_, bias)
        if self.base_layer.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is ColumnParallelLinear or (
            type(source_layer) is MergedColumnParallelLinear
            and len(packed_modules_list) == 1)


class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
    packed together (e.g. gate_proj + up_proj -> gate_up_proj).

    This means we have 2 LoRAs, each applied to one half of the layer.

    Both slices must have the same size.
    """

    def __init__(self, base_layer: MergedColumnParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        n_slices = 2
        if not (len(self.base_layer.output_sizes) == n_slices
                and self.base_layer.output_sizes[0]
                == self.base_layer.output_sizes[1]):
            raise ValueError(
                "LoRAColumnParallelLinear2Slice requires 2 slices with "
                "the same size.")
        self.tp_size = get_tensor_model_parallel_world_size()
        device = _get_lora_device(self.base_layer)

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.input_size,
                dtype=lora_config.lora_dtype,
                device=device,
            ) for _ in range(n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                self.base_layer.output_size // 2,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=device,
            ) for _ in range(n_slices))

        self.indices: Optional[torch.Tensor] = None
        self.output_dim = self.lora_b_stacked[0].shape[2]

    def reset_lora(self, index: int):
        self.lora_a_stacked[0][index] = 0
        self.lora_a_stacked[1][index] = 0
        self.lora_b_stacked[0][index] = 0
        self.lora_b_stacked[1][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)

        if self.tp_size > 1:
            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
            shard_size = self.output_dim
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            lora_b = lora_b[0][:,
                               start_idx:end_idx], lora_b[1][:,
                                                             start_idx:end_idx]

        if lora_a[0] is not None:
            self.lora_a_stacked[0][
                index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
                    lora_a[0].T, non_blocking=True)
            self.lora_b_stacked[0][
                index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
                    lora_b[0].T, non_blocking=True)
        if lora_a[1] is not None:
            self.lora_a_stacked[1][
                index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
                    lora_a[1].T, non_blocking=True)
            self.lora_b_stacked[1][
                index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
                    lora_b[1].T, non_blocking=True)

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer, x, bias)
        _apply_lora_packed_nslice(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
            (self.output_dim, self.output_dim),
        )
        return output

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is MergedColumnParallelLinear and len(
            packed_modules_list) == 2


class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """
    ColumnParallelLinear layer that is specifically designed for
    qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
    only contain a single LoRA within their qkv_proj layer.

    During inference with Tensor Parallel, the weights of lora_b
    must be accurately partitioned according to the respective ranks.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        self.tp_size = get_tensor_model_parallel_world_size()
        self.q_proj_total_size = (self.base_layer.total_num_heads *
                                  self.base_layer.head_size)
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
                                   self.base_layer.head_size)

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.tp_size > 1:
            tp_rank = get_tensor_model_parallel_rank()
            self.q_shard_id = tp_rank
            self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
            lora_b_q = lora_b[:, self.q_proj_shard_size *
                              self.q_shard_id:self.q_proj_shard_size *
                              (self.q_shard_id + 1)]
            k_offset = self.q_proj_total_size
            lora_b_k = lora_b[:, k_offset + self.kv_proj_shard_size *
                              self.kv_shard_id:k_offset +
                              self.kv_proj_shard_size * (self.kv_shard_id + 1)]
            v_offset = k_offset + self.kv_proj_total_size
            lora_b_v = lora_b[:, v_offset + self.kv_proj_shard_size *
                              self.kv_shard_id:v_offset +
                              self.kv_proj_shard_size * (self.kv_shard_id + 1)]
            lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is QKVParallelLinear and len(
            packed_modules_list) == 1


class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 3 sublayers (slices)
    packed together in qkv proj fashion
    (q_proj + k_proj + v_proj -> qkv_proj).

    This means we have 3 LoRAs, each applied to one slice of the layer.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.q_shard_id = tp_rank
        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
        device = _get_lora_device(self.base_layer)

        # q, k, v
        self.lora_a_stacked = (
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.input_size,
                dtype=lora_config.lora_dtype,
                device=device,
            ),
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.input_size,
                dtype=lora_config.lora_dtype,
                device=device,
            ),
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.input_size,
                dtype=lora_config.lora_dtype,
                device=device,
            ),
        )
        self.lora_b_stacked = (
            torch.zeros(
                max_loras,
                1,
                self.q_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=device,
            ),
            torch.zeros(
                max_loras,
                1,
                self.kv_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=device,
            ),
            torch.zeros(
                max_loras,
                1,
                self.kv_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=device,
            ),
        )

        self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size,
                              self.kv_proj_shard_size)
        self.packed_indices: Optional[torch.Tensor] = None
        self.standard_indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[0][index] = 0
        self.lora_b_stacked[0][index] = 0
        self.lora_a_stacked[1][index] = 0
        self.lora_b_stacked[1][index] = 0
        self.lora_a_stacked[2][index] = 0
        self.lora_b_stacked[2][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)

        if self.tp_size > 1:
            if lora_b[0] is not None:
                lora_b_q = lora_b[0][:, self.q_proj_shard_size *
                                     self.q_shard_id:self.q_proj_shard_size *
                                     (self.q_shard_id + 1)]
                self.lora_b_stacked[0][
                    index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
                        lora_b_q.T, non_blocking=True)
            if lora_b[1] is not None:
                lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
                                     self.kv_shard_id:self.kv_proj_shard_size *
                                     (self.kv_shard_id + 1)]
                self.lora_b_stacked[1][
                    index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
                        lora_b_k.T, non_blocking=True)
            if lora_b[2] is not None:
                lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
                                     self.kv_shard_id:self.kv_proj_shard_size *
                                     (self.kv_shard_id + 1)]
                self.lora_b_stacked[2][
                    index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
                        lora_b_v.T, non_blocking=True)
        else:
            if lora_b[0] is not None:
                self.lora_b_stacked[0][
                    index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
                        lora_b[0].T, non_blocking=True)
            if lora_b[1] is not None:
                self.lora_b_stacked[1][
                    index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
                        lora_b[1].T, non_blocking=True)
            if lora_b[2] is not None:
                self.lora_b_stacked[2][
                    index, 0, :lora_b[2].shape[1], :lora_b[2].shape[0]].copy_(
                        lora_b[2].T, non_blocking=True)

        if lora_a[0] is not None:
            self.lora_a_stacked[0][
                index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
                    lora_a[0].T, non_blocking=True)
        if lora_a[1] is not None:
            self.lora_a_stacked[1][
                index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
                    lora_a[1].T, non_blocking=True)
        if lora_a[2] is not None:
            self.lora_a_stacked[2][
                index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
                    lora_a[2].T, non_blocking=True)

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer, x, bias)
        _apply_lora_packed_nslice(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
            self.output_slices,
        )
        return output

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is QKVParallelLinear and len(
            packed_modules_list) == 3


class RowParallelLinearWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: RowParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        device = _get_lora_device(self.base_layer)
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.input_size,
            ),
            dtype=lora_config.lora_dtype,
            device=device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.output_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=device,
        )
        self.indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.base_layer.tp_size > 1:
            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
            shard_size = self.base_layer.weight.shape[1]
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            lora_a = lora_a[start_idx:end_idx, :]

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.indices_len = indices_len

    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer, x)
        _apply_lora(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
        )
        return output

    def forward(self, input_):
        """Forward of RowParallelLinear

        Args:
            input_: tensor whose last dimension is `input_size`. If
                    `input_is_parallel` is set, then the last dimension
                    is `input_size // tp_size`.

        Returns:
            - output
            - bias
        """
        # Set up backprop all-reduce.
        if self.base_layer.input_is_parallel:
            input_parallel = input_
        else:
            # TODO: simplify code below
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.base_layer.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        output_parallel = self.apply_weights(input_parallel)
        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.base_layer.skip_bias_add:
            output = (output_ + self.base_layer.bias
                      if self.base_layer.bias is not None else output_)
            output_bias = None
        else:
            output = output_
            output_bias = self.base_layer.bias
        return output, output_bias

    @property
    def weight(self):
        return self.base_layer.weight

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is RowParallelLinear


class LogitsProcessorWithLoRA(BaseLayerWithLoRA):

    def __init__(
        self,
        base_layer: LogitsProcessor,
        hidden_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device

    @property
    def logits_as_input(self):
        return self.base_layer.logits_as_input

    @property
    def vocab_size(self):
        return self.base_layer.vocab_size

    @property
    def scale(self):
        return self.base_layer.scale

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        return self.base_layer.include_gpu_probs_tensor

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
        if not (32000 <= self.base_layer.vocab_size <= 128512):
            raise ValueError("When using LoRA, vocab size must be "
                             "32000 <= vocab_size <= 128512")
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                # Pad for kernel compatibility
                math.ceil(self.base_layer.vocab_size /
                          lora_config.lora_vocab_padding_size) *
                lora_config.lora_vocab_padding_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.embeddings_tensors = torch.full(
            (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
            fill_value=float("-inf"),
            dtype=self.dtype,
            device=self.device,
        )
        self.indices = None
        self.indices_padded = None
        self.indices_len = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = float("-inf")

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1], ] = embeddings_tensor

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = sampler_indices
        self.indices_padded = sampler_indices_padded
        self.indices_len = indices_len

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: torch.Tensor,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        # Get the logits for the next tokens.
        logits = lm_head(hidden_states)
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        if logits is None:
            return None

        lora_logits = torch.empty(
            self.embeddings_tensors.shape[0] + 1,
            self.embeddings_tensors.shape[1],
            hidden_states.shape[0],
            dtype=self.embeddings_tensors.dtype,
            device=self.embeddings_tensors.device,
        )
        torch.matmul(self.embeddings_tensors,
                     hidden_states.T,
                     out=lora_logits[:-1])
        lora_logits[-1] = float("-inf")
        lora_logits = lora_logits.mT
        lora_logits = (lora_logits.reshape(
            lora_logits.shape[0] * lora_logits.shape[1],
            lora_logits.shape[2],
        ).index_select(0,
                       self.indices_padded[:self.indices_len[2]]).nan_to_num_(
                           nan=float("-inf"),
                           posinf=float("inf"),
                           neginf=float("-inf")))
        logits[:,
               self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
               lora_logits.shape[1]] = lora_logits

        _apply_lora(
            hidden_states,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[1]],
            logits,
        )

        # Remove paddings in vocab (if any).
        logits = logits[:, :self.base_layer.vocab_size]
        return logits
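
    # Note (comment added for readability): the block above appends logits
    # for the LoRA-specific extra vocabulary: embeddings_tensors is matmul'd
    # against the hidden states, rows are selected per sequence via
    # indices_padded (the extra -inf slot handles padded entries), and the
    # result is written into the columns just past org_vocab_size before the
    # usual LoRA delta is added to the full logits.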

    def forward(self, *args, **kwargs):
        return type(self.base_layer).forward(self, *args, **kwargs)

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # Special handling for the LogitsProcessor.
        return False


_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
    cls
    for cls in globals().values() if inspect.isclass(cls)
    and issubclass(cls, BaseLayerWithLoRA) and cls is not BaseLayerWithLoRA
}


def from_layer(layer: nn.Module,
               max_loras: int,
               lora_config: LoRAConfig,
               packed_modules_list: List,
               model_config: Optional[PretrainedConfig] = None) -> nn.Module:
    for lora_cls in _all_lora_classes:
        if lora_cls.can_replace_layer(layer, lora_config, packed_modules_list,
                                      model_config):
            ret = lora_cls(layer)
            ret.create_lora_weights(max_loras, lora_config, model_config)
            return ret
    return layer


def from_layer_logits_processor(
    layer: LogitsProcessor,
    lm_head: ParallelLMHead,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA:
    ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
                                  lm_head.weight.dtype, lm_head.weight.device)
    ret.create_lora_weights(max_loras, lora_config, model_config)
    return ret
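

# Illustrative usage sketch (comment added for readability; the calling code
# described here is hypothetical and not part of this module): a LoRA-aware
# model manager would walk the model's submodules and call
#     module = from_layer(module, max_loras, lora_config, packed_modules_list)
# so that supported layers are swapped for their *WithLoRA counterparts, and
#     from_layer_logits_processor(logits_processor, lm_head, max_loras,
#                                 lora_config)
# for the logits processor, which always requires special handling.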