1
0

layers.py 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004
  1. # pylint: disable=unused-argument
  2. import math
  3. from dataclasses import dataclass
  4. from typing import TYPE_CHECKING, List, Optional, Tuple
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. from transformers import PretrainedConfig
  9. from aphrodite.common.config import LoRAConfig
  10. from aphrodite.lora.punica import add_lora, add_lora_slice, bgmv
  11. from aphrodite.modeling.layers.sampler import Sampler
  12. from aphrodite.modeling.megatron.communication_op import (
  13. tensor_model_parallel_all_gather,
  14. tensor_model_parallel_all_reduce,
  15. tensor_model_parallel_gather,
  16. )
  17. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  18. RowParallelLinear,
  19. QKVParallelLinear,
  20. MergedColumnParallelLinear)
  21. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  22. VocabParallelEmbedding, ParallelLMHead)
  23. from aphrodite.modeling.megatron.parallel_state import (
  24. get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
  25. from aphrodite.modeling.megatron.utils import split_tensor_along_last_dim
  26. if TYPE_CHECKING:
  27. pass
  28. def _get_lora_device(base_layer: nn.Module) -> torch.device:
  29. """Identify the device for positioning the LoRA tensors."""
  30. device = None
  31. try:
  32. device = base_layer.weight.device
  33. except AttributeError:
  34. try:
  35. linear_weights = base_layer.linear_weights
  36. if isinstance(linear_weights, dict):
  37. tensor_values = [
  38. v for v in linear_weights.values()
  39. if isinstance(v, torch.Tensor)
  40. ]
  41. if tensor_values:
  42. device = tensor_values[0].device
  43. except AttributeError:
  44. pass
  45. if device is None:
  46. raise ValueError(f"Base layer not supported: {base_layer}")
  47. return device
  def _apply_lora(
      x: torch.Tensor,
      lora_a_stacked: torch.Tensor,
      lora_b_stacked: torch.Tensor,
      indices: torch.Tensor,
      output: torch.Tensor,
  ):
      """Add each row's LoRA contribution to ``output`` and return it.

      Every row of ``x`` is routed to the adapter selected by the matching
      entry of ``indices``; an index of -1 means that row gets no LoRA.
      Leading dimensions of ``x``, ``output`` and ``indices`` are flattened
      before calling the ``add_lora`` kernel, and the result is viewed back
      to ``output``'s original shape. Callers in this module discard the
      return value, i.e. they rely on ``output`` being updated in place.

      Shapes (as allocated by the wrappers in this module):
          x: (..., hidden_dim)
          lora_a_stacked: (num_loras, 1, lora_rank, hidden_dim)
          lora_b_stacked: (num_loras, 1, output_dim, lora_rank)
          indices: one entry per flattened row
          output: (..., output_dim)
      """
      rows_in = x.view(-1, x.shape[-1])
      rows_out = output.view(-1, output.shape[-1])
      # The trailing 0 and 1.0 are forwarded unchanged to ``add_lora``
      # (presumably layer index and scale -- TODO confirm in punica).
      add_lora(rows_out, rows_in, lora_a_stacked, lora_b_stacked,
               indices.view(-1), 0, 1.0)
      return rows_out.view_as(output)
  def _apply_lora_packed_nslice(
      x: torch.Tensor,
      lora_a_stacked: Tuple[torch.Tensor, ...],
      lora_b_stacked: Tuple[torch.Tensor, ...],
      indices: torch.Tensor,
      output: torch.Tensor,
      output_slices: Tuple[int, ...],
  ):
      """Add per-slice LoRA contributions to a packed ``output`` tensor.

      Used for layers built from several sublayers (slices) packed side by
      side along the output dimension, e.g. gate/up (2 slices) or q/k/v
      (3 slices). Slice ``i`` covers ``output_slices[i]`` consecutive output
      columns starting right after slice ``i - 1`` and receives the LoRA pair
      ``(lora_a_stacked[i], lora_b_stacked[i])``. Rows are routed to adapters
      by ``indices``; an index of -1 means no LoRA for that row.

      Args:
          x: Input of shape (..., hidden_dim).
          lora_a_stacked: One stacked A tensor per slice, each of shape
              (num_loras, 1, lora_rank, hidden_dim).
          lora_b_stacked: One stacked B tensor per slice, each of shape
              (num_loras, 1, slice_size, lora_rank).
          indices: Adapter index per row; flattened before use.
          output: Packed output of shape (..., total_width); callers rely on
              it being updated in place.
          output_slices: Width of every slice -- one entry per slice
              (n entries for n slices).

      Returns:
          ``output`` viewed back to its original shape.
      """
      flat_x = x.view(-1, x.shape[-1])
      flat_out = output.view(-1, output.shape[-1])
      flat_indices = indices.view(-1)
      offset_left = 0
      for slice_idx, slice_size in enumerate(output_slices):
          add_lora_slice(flat_out, flat_x, lora_a_stacked[slice_idx],
                         lora_b_stacked[slice_idx], flat_indices, 0, 1.0,
                         offset_left, slice_size)
          # Next slice starts where this one ends.
          offset_left += slice_size
      return flat_out.view_as(output)
  @dataclass
  class LoRAMapping:
      """Adapter assignment for a batch; both mappings are stored as tuples.

      Any iterable may be passed for either field; ``__post_init__`` turns
      them into tuples.
      """

      # One entry per token in input_ids.
      index_mapping: Tuple[int, ...]
      # One entry per sampled token.
      prompt_mapping: Tuple[int, ...]

      def __post_init__(self):
          # Normalise whatever iterables were supplied into tuples.
          self.index_mapping, self.prompt_mapping = (
              tuple(self.index_mapping),
              tuple(self.prompt_mapping),
          )
  class BaseLayerWithLoRA(nn.Module):
      """Common interface of the LoRA-capable layer wrappers in this module.

      Each hook below is a no-op here (returns ``None``); concrete wrappers
      override them.
      """

      def create_lora_weights(self, max_loras: int, lora_config: LoRAConfig,
                              model_config: PretrainedConfig) -> None:
          """Allocate the stacked LoRA buffers for ``max_loras`` slots."""
          ...

      def reset_lora(self, index: int):
          """Clear the adapter held in slot ``index``."""
          ...

      def set_lora(
          self,
          index: int,
          lora_a: torch.Tensor,
          lora_b: torch.Tensor,
          embeddings_tensor: Optional[torch.Tensor],
      ):
          """Write the given adapter tensors into slot ``index``."""
          ...

      def set_mapping(
          self,
          base_indices: torch.Tensor,
          sampler_indices: torch.Tensor,
          sampler_indices_padded: torch.Tensor,
          embeddings_indices: torch.Tensor,
          indices_len: List[int],
      ):
          """Record which adapter slot each token / sampled row uses."""
          ...
  class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
      """LoRA wrapper for :class:`VocabParallelEmbedding`.

      Besides the LoRA A/B update of the embedding output, this wrapper keeps
      per-adapter embeddings for extra (LoRA-added) tokens and mirrors them
      into the rows of the base embedding weight that lie past
      ``org_vocab_size`` on this shard.
      """

      def __init__(self, base_layer: VocabParallelEmbedding) -> None:
          super().__init__()
          self.base_layer = base_layer

      def create_lora_weights(
              self,
              max_loras: int,
              lora_config: LoRAConfig,
              model_config: Optional[PretrainedConfig] = None) -> None:
          """Allocate LoRA buffers and locate this shard's extra-vocab rows."""
          base = self.base_layer
          org_vocab = base.org_vocab_size
          if base.vocab_end_index <= org_vocab:
              # This shard ends inside the original vocabulary, so none of
              # its weight rows hold LoRA-added tokens.
              self.embeddings_slice = None
              self.embeddings_weights = None
          else:
              # First local weight row belonging to the extra vocabulary.
              first_extra_row = max(org_vocab - base.vocab_start_index, 0)
              # Range inside the flattened (max_loras * extra_vocab) table
              # of embeddings_tensors that maps onto this shard's extra rows.
              self.embeddings_slice = (
                  base.vocab_start_index - org_vocab + first_extra_row,
                  base.vocab_end_index - org_vocab,
              )
              # A view into the base weight: zeroing it clears those rows.
              self.embeddings_weights = base.weight.data[first_extra_row:]
              self.embeddings_weights.fill_(0)

          extra_vocab = lora_config.lora_extra_vocab_size
          rank = lora_config.max_lora_rank
          device = base.weight.device
          self.embeddings_tensors = torch.zeros(
              (max_loras, extra_vocab, base.embedding_dim),
              dtype=base.weight.dtype,
              device=device,
          )
          self.lora_a_stacked = torch.zeros(
              (max_loras, org_vocab + extra_vocab, rank),
              dtype=lora_config.lora_dtype,
              device=device,
          )
          self.lora_b_stacked = torch.zeros(
              (max_loras, 1, base.embedding_dim, rank),
              dtype=lora_config.lora_dtype,
              device=device,
          )
          # Merge the slot and vocab axes so that forward() can gather rows
          # with F.embedding using (token id + offset from
          # embeddings_indices[1]).
          num_slots, rows_per_slot, a_rank = self.lora_a_stacked.shape
          self.lora_a_stacked_2d = self.lora_a_stacked.view(
              num_slots * rows_per_slot, a_rank)
          self.indices: Optional[torch.Tensor] = None
          self.indices_len: Optional[List[int]] = None
          self.embeddings_indices = None

      def reset_lora(self, index: int):
          """Zero the LoRA A/B buffers and extra-token embeddings of a slot."""
          for buffer in (self.lora_a_stacked, self.lora_b_stacked,
                         self.embeddings_tensors):
              buffer[index] = 0

      def set_lora(
          self,
          index: int,
          lora_a: torch.Tensor,
          lora_b: torch.Tensor,
          embeddings_tensor: Optional[torch.Tensor],
      ):
          """Load one adapter into slot ``index``.

          ``lora_a`` is copied as-is and ``lora_b`` transposed. When given,
          ``embeddings_tensor`` is stored too and, if this shard owns
          extra-vocab rows, mirrored into the base embedding weight.
          """
          self.reset_lora(index)
          a_rows, a_cols = lora_a.shape[0], lora_a.shape[1]
          self.lora_a_stacked[index, :a_rows, :a_cols].copy_(
              lora_a, non_blocking=True)
          b_rows, b_cols = lora_b.shape[0], lora_b.shape[1]
          self.lora_b_stacked[index, 0, :b_cols, :b_rows].copy_(
              lora_b.T, non_blocking=True)
          if embeddings_tensor is None:
              return
          n_extra = embeddings_tensor.shape[0]
          emb_dim = embeddings_tensor.shape[1]
          self.embeddings_tensors[index, :n_extra, :emb_dim].copy_(
              embeddings_tensor, non_blocking=True)
          if self.embeddings_slice is None:
              return
          # TODO(yard1): Optimize this copy, we don't need to copy
          # everything, just the modified part
          slots, per_slot, dim = self.embeddings_tensors.shape
          start, stop = self.embeddings_slice
          shard_rows = self.embeddings_tensors.view(slots * per_slot,
                                                    dim)[start:stop]
          self.embeddings_weights[:shard_rows.shape[0]].copy_(shard_rows)

      def set_mapping(
          self,
          base_indices: torch.Tensor,
          sampler_indices: torch.Tensor,
          sampler_indices_padded: torch.Tensor,
          embeddings_indices: torch.Tensor,
          indices_len: List[int],
      ):
          """Store the token->slot indices used by forward()."""
          self.indices = base_indices
          self.embeddings_indices = embeddings_indices
          self.indices_len = indices_len

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          """Embed ``x`` with the base layer plus each token's LoRA delta."""
          n_tokens = self.indices_len[3]
          # Ids past the original vocab are LoRA-added tokens; the mask must
          # be taken before ``x`` is modified below.
          is_added_token = x > self.base_layer.org_vocab_size - 1
          a_offsets = self.embeddings_indices[1][:n_tokens].view_as(x)
          lora_a_rows = F.embedding(
              x + a_offsets,
              self.lora_a_stacked_2d,
          )
          base_offsets = self.embeddings_indices[0][:n_tokens].view_as(x)
          # NOTE(review): ``add_`` shifts the caller's ``x`` in place (added
          # tokens are redirected to their slot's extra rows); callers must
          # not reuse ``x`` afterwards.
          base_out = self.base_layer.forward(
              x.add_(base_offsets * is_added_token))
          flat_out = base_out
          if flat_out.ndim == 3:
              flat_out = flat_out.view(
                  flat_out.shape[0] * flat_out.shape[1], -1)
          if lora_a_rows.ndim == 3:
              lora_a_rows = lora_a_rows.view(
                  lora_a_rows.shape[0] * lora_a_rows.shape[1], -1)
          bgmv(flat_out, lora_a_rows, self.lora_b_stacked,
               self.indices[:self.indices_len[0]], 0, 1.0)
          return flat_out.view_as(base_out)
  class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
      """LoRA wrapper for :class:`ColumnParallelLinear`.

      The stacked LoRA A buffer spans the base weight's input dimension
      (``weight.shape[1]``). The stacked LoRA B buffer spans this rank's
      output rows (``weight.shape[0]``). Under tensor parallelism the
      incoming LoRA B (presumably unsharded, as the merged/QKV subclasses
      also assume) is therefore narrowed per rank in :meth:`set_lora`.
      """

      def __init__(self, base_layer: ColumnParallelLinear) -> None:
          super().__init__()
          self.base_layer = base_layer

      def create_lora_weights(
              self,
              max_loras: int,
              lora_config: LoRAConfig,
              model_config: Optional[PretrainedConfig] = None) -> None:
          """Allocate zeroed LoRA A/B buffers for ``max_loras`` slots."""
          self.tp_size = get_tensor_model_parallel_world_size()
          self.lora_a_stacked = torch.zeros(
              max_loras,
              1,
              lora_config.max_lora_rank,
              self.base_layer.weight.shape[1],
              dtype=lora_config.lora_dtype,
              device=self.base_layer.weight.device,
          )
          self.lora_b_stacked = torch.zeros(
              max_loras,
              1,
              self.base_layer.weight.shape[0],
              lora_config.max_lora_rank,
              dtype=lora_config.lora_dtype,
              device=self.base_layer.weight.device,
          )
          self.indices: Optional[torch.Tensor] = None
          self.indices_len: Optional[List[int]] = None
          # Per-rank output width. Axis 1 of lora_b_stacked is the singleton
          # slice axis, so the output rows live on axis 2 (``shape[1]`` would
          # always be 1).
          self.output_dim = self.lora_b_stacked.shape[2]

      def reset_lora(self, index: int):
          """Zero the LoRA A/B buffers of slot ``index``."""
          self.lora_a_stacked[index] = 0
          self.lora_b_stacked[index] = 0

      def set_lora(
          self,
          index: int,
          lora_a: torch.Tensor,
          lora_b: torch.Tensor,
          embeddings_tensor: Optional[torch.Tensor],
      ):
          """Load one adapter into slot ``index`` (A and B stored transposed).

          With tensor parallelism only this rank's ``output_dim`` columns of
          ``lora_b`` are kept, the same sharding the merged-column variant
          applies. Without it, the full LoRA B would not fit the per-rank
          buffer.
          """
          self.reset_lora(index)
          if self.tp_size > 1:
              tp_rank = get_tensor_model_parallel_rank()
              start_idx = tp_rank * self.output_dim
              end_idx = start_idx + self.output_dim
              lora_b = lora_b[:, start_idx:end_idx]
          self.lora_a_stacked[index,
                              0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                  lora_a.T, non_blocking=True)
          self.lora_b_stacked[index,
                              0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                  lora_b.T, non_blocking=True)

      def set_mapping(
          self,
          base_indices: torch.Tensor,
          sampler_indices: torch.Tensor,
          sampler_indices_padded: torch.Tensor,
          embeddings_indices: torch.Tensor,
          indices_len: List[int],
      ):
          """Store the token->slot indices used by apply_weights()."""
          self.indices = base_indices
          self.indices_len = indices_len

      def apply_weights(self, x: torch.Tensor,
                        bias: Optional[torch.Tensor]) -> torch.Tensor:
          """Base projection of ``x`` plus the per-token LoRA delta."""
          output = self.base_layer.linear_method.apply_weights(
              self.base_layer.linear_weights, x, bias)
          _apply_lora(
              x,
              self.lora_a_stacked,
              self.lora_b_stacked,
              self.indices[:self.indices_len[0]],
              output,
          )
          return output

      def forward(self, input_):
          """Forward of ColumnParallelLinear

          Args:
              input_: Tensor whose last dimension is `input_size`.

          Returns:
              - output
              - bias
          """
          bias = (self.base_layer.bias
                  if not self.base_layer.skip_bias_add else None)

          # Matrix multiply.
          output_parallel = self.apply_weights(input_, bias)
          if self.base_layer.gather_output:
              # All-gather across the partitions.
              output = tensor_model_parallel_all_gather(output_parallel)
          else:
              output = output_parallel
          output_bias = (self.base_layer.bias
                         if self.base_layer.skip_bias_add else None)
          return output, output_bias

      @property
      def linear_weights(self):
          return self.base_layer.linear_weights
  class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
      """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
      packed together (eg. gate_proj + up_proj -> gate_up_proj).

      This means we have 2 LoRAs, each applied to one half of the layer.
      Both slices must have the same size. An adapter may supply ``None``
      for a slice it does not modify.
      """

      def __init__(self, base_layer: MergedColumnParallelLinear) -> None:
          super().__init__(base_layer)

      def create_lora_weights(
              self,
              max_loras: int,
              lora_config: LoRAConfig,
              model_config: Optional[PretrainedConfig] = None) -> None:
          """Allocate one zeroed LoRA A/B buffer pair per slice.

          Raises:
              ValueError: unless the base layer has exactly two equal-sized
                  output slices.
          """
          n_slices = 2
          output_sizes = self.base_layer.output_sizes
          if not (len(output_sizes) == n_slices
                  and output_sizes[0] == output_sizes[1]):
              raise ValueError(
                  "LoRAColumnParallelLinear2Slice requires 2 slices with "
                  "the same size.")
          self.tp_size = get_tensor_model_parallel_world_size()
          device = _get_lora_device(self.base_layer)

          self.lora_a_stacked = tuple(
              torch.zeros(
                  max_loras,
                  1,
                  lora_config.max_lora_rank,
                  self.base_layer.input_size,
                  dtype=lora_config.lora_dtype,
                  device=device,
              ) for _ in range(n_slices))
          # NOTE(review): assumes ``output_size`` is this rank's width; if it
          # is the unsharded size, the TP shard size used in set_lora would
          # be wrong -- confirm against ColumnParallelLinear.
          self.lora_b_stacked = tuple(
              torch.zeros(
                  max_loras,
                  1,
                  self.base_layer.output_size // 2,
                  lora_config.max_lora_rank,
                  dtype=lora_config.lora_dtype,
                  device=device,
              ) for _ in range(n_slices))

          self.indices: Optional[torch.Tensor] = None
          # Initialised here as in the parent class; this override used to
          # leave it unset until set_mapping().
          self.indices_len: Optional[List[int]] = None
          # Width of one slice (axis 2 of each B buffer).
          self.output_dim = self.lora_b_stacked[0].shape[2]

      def reset_lora(self, index: int):
          """Zero both slices' LoRA A/B buffers for slot ``index``."""
          for buffer in (*self.lora_a_stacked, *self.lora_b_stacked):
              buffer[index] = 0

      def set_lora(
          self,
          index: int,
          lora_a: torch.Tensor,
          lora_b: torch.Tensor,
          embeddings_tensor: Optional[torch.Tensor],
      ):
          """Load an adapter's two per-slice LoRA pairs into slot ``index``.

          ``lora_a`` and ``lora_b`` are per-slice sequences. A slice whose
          ``lora_a`` entry is ``None`` is left zeroed. Under tensor
          parallelism, every present LoRA B is narrowed to this rank's
          ``output_dim`` columns.
          """
          self.reset_lora(index)

          if self.tp_size > 1:
              tp_rank = get_tensor_model_parallel_rank()
              start_idx = tp_rank * self.output_dim
              end_idx = start_idx + self.output_dim
              # Absent slices stay None; indexing None used to raise TypeError
              # when only one of the two projections had a LoRA.
              lora_b = tuple(None if b is None else b[:, start_idx:end_idx]
                             for b in lora_b)

          for slice_idx in range(len(self.lora_a_stacked)):
              slice_a = lora_a[slice_idx]
              if slice_a is None:
                  continue
              slice_b = lora_b[slice_idx]
              self.lora_a_stacked[slice_idx][
                  index, 0, :slice_a.shape[1], :slice_a.shape[0]].copy_(
                      slice_a.T, non_blocking=True)
              self.lora_b_stacked[slice_idx][
                  index, 0, :slice_b.shape[1], :slice_b.shape[0]].copy_(
                      slice_b.T, non_blocking=True)

      def apply_weights(self, x: torch.Tensor,
                        bias: Optional[torch.Tensor]) -> torch.Tensor:
          """Base projection plus per-slice LoRA deltas on the packed output."""
          output = self.base_layer.linear_method.apply_weights(
              self.base_layer.linear_weights, x, bias)
          _apply_lora_packed_nslice(
              x,
              self.lora_a_stacked,
              self.lora_b_stacked,
              self.indices[:self.indices_len[0]],
              output,
              (self.output_dim, self.output_dim),
          )
          return output
  class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
      """ColumnParallelLinear layer that is composed of 3 sublayers (slices)
      packed together in qkv proj fashion
      (q_proj + k_proj + v_proj -> qkv_proj).

      This means we have 3 LoRAs, each applied to one slice of the layer.
      Q slice may have different shape than K and V slices (which both have
      the same shape).
      """

      def __init__(self, base_layer: QKVParallelLinear) -> None:
          super().__init__(base_layer)

      def create_lora_weights(
              self,
              max_loras: int,
              lora_config: LoRAConfig,
              model_config: Optional[PretrainedConfig] = None) -> None:
          """Allocate zeroed LoRA A/B buffers for the q, k and v slices."""
          self.tp_size = get_tensor_model_parallel_world_size()
          tp_rank = get_tensor_model_parallel_rank()
          self.q_proj_shard_size = (self.base_layer.num_heads *
                                    self.base_layer.head_size)
          self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                     self.base_layer.head_size)
          self.q_shard_id = tp_rank
          # Several ranks may share one KV shard when KV heads are replicated.
          self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas

          device = _get_lora_device(self.base_layer)
          # Output width of each packed slice, in q, k, v order.
          self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size,
                                self.kv_proj_shard_size)
          # q, k, v: all three A buffers have the same shape.
          self.lora_a_stacked = tuple(
              torch.zeros(
                  max_loras,
                  1,
                  lora_config.max_lora_rank,
                  self.base_layer.input_size,
                  dtype=lora_config.lora_dtype,
                  device=device,
              ) for _ in self.output_slices)
          self.lora_b_stacked = tuple(
              torch.zeros(
                  max_loras,
                  1,
                  slice_size,
                  lora_config.max_lora_rank,
                  dtype=lora_config.lora_dtype,
                  device=device,
              ) for slice_size in self.output_slices)

          # apply_weights() reads ``self.indices``; initialise it here too.
          self.indices: Optional[torch.Tensor] = None
          # Retained for backward compatibility; not read in this module.
          self.packed_indices: Optional[torch.Tensor] = None
          self.standard_indices: Optional[torch.Tensor] = None
          self.indices_len: Optional[List[int]] = None

      def reset_lora(self, index: int):
          """Zero the q, k and v LoRA A/B buffers for slot ``index``."""
          for buffer in (*self.lora_a_stacked, *self.lora_b_stacked):
              buffer[index] = 0

      def set_lora(
          self,
          index: int,
          lora_a: torch.Tensor,
          lora_b: torch.Tensor,
          embeddings_tensor: Optional[torch.Tensor],
      ):
          """Load an adapter's q/k/v LoRA pairs into slot ``index``.

          ``lora_a`` and ``lora_b`` are 3-element (q, k, v) sequences. ``None``
          entries are skipped. Under tensor parallelism each LoRA B is
          narrowed to this rank's q or kv shard before copying.
          """
          self.reset_lora(index)

          b_bounds = None
          if self.tp_size > 1:
              q_start = self.q_proj_shard_size * self.q_shard_id
              kv_start = self.kv_proj_shard_size * self.kv_shard_id
              q_range = (q_start, q_start + self.q_proj_shard_size)
              kv_range = (kv_start, kv_start + self.kv_proj_shard_size)
              b_bounds = (q_range, kv_range, kv_range)

          for slice_idx in range(3):
              slice_b = lora_b[slice_idx]
              if slice_b is None:
                  continue
              if b_bounds is not None:
                  start, end = b_bounds[slice_idx]
                  slice_b = slice_b[:, start:end]
              self.lora_b_stacked[slice_idx][
                  index, 0, :slice_b.shape[1], :slice_b.shape[0]].copy_(
                      slice_b.T, non_blocking=True)

          for slice_idx in range(3):
              slice_a = lora_a[slice_idx]
              if slice_a is None:
                  continue
              self.lora_a_stacked[slice_idx][
                  index, 0, :slice_a.shape[1], :slice_a.shape[0]].copy_(
                      slice_a.T, non_blocking=True)

      def apply_weights(self, x: torch.Tensor,
                        bias: Optional[torch.Tensor]) -> torch.Tensor:
          """Base projection plus per-slice LoRA deltas on the packed output."""
          output = self.base_layer.linear_method.apply_weights(
              self.base_layer.linear_weights, x, bias)
          _apply_lora_packed_nslice(
              x,
              self.lora_a_stacked,
              self.lora_b_stacked,
              self.indices[:self.indices_len[0]],
              output,
              self.output_slices,
          )
          return output
  class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
      """LoRA wrapper for :class:`RowParallelLinear`.

      Under tensor parallelism, ``set_lora`` keeps only this rank's rows of
      LoRA A, matching the partitioned input seen by ``apply_weights``.
      """

      def __init__(self, base_layer: RowParallelLinear) -> None:
          super().__init__()
          self.base_layer = base_layer

      def create_lora_weights(
              self,
              max_loras: int,
              lora_config: LoRAConfig,
              model_config: Optional[PretrainedConfig] = None) -> None:
          """Allocate zeroed LoRA A/B buffers for ``max_loras`` slots."""
          device = _get_lora_device(self.base_layer)
          rank = lora_config.max_lora_rank
          a_shape = (max_loras, 1, rank, self.base_layer.input_size)
          self.lora_a_stacked = torch.zeros(a_shape,
                                            dtype=lora_config.lora_dtype,
                                            device=device)
          b_shape = (max_loras, 1, self.base_layer.output_size, rank)
          self.lora_b_stacked = torch.zeros(b_shape,
                                            dtype=lora_config.lora_dtype,
                                            device=device)
          self.indices: Optional[torch.Tensor] = None
          self.indices_len: Optional[List[int]] = None

      def reset_lora(self, index: int):
          """Zero the LoRA A/B buffers of slot ``index``."""
          self.lora_a_stacked[index] = 0
          self.lora_b_stacked[index] = 0

      def set_lora(
          self,
          index: int,
          lora_a: torch.Tensor,
          lora_b: torch.Tensor,
          embeddings_tensor: Optional[torch.Tensor],
      ):
          """Load one adapter into slot ``index`` (A and B stored transposed)."""
          self.reset_lora(index)
          if self.base_layer.tp_size > 1:
              # NOTE(review): reads ``weight`` directly, although
              # create_lora_weights goes through _get_lora_device to support
              # layers without ``weight`` -- confirm this path is covered.
              per_rank_rows = self.base_layer.weight.shape[1]
              first_row = get_tensor_model_parallel_rank() * per_rank_rows
              lora_a = lora_a[first_row:first_row + per_rank_rows, :]
          a_in, a_rank = lora_a.shape[0], lora_a.shape[1]
          self.lora_a_stacked[index, 0, :a_rank, :a_in].copy_(
              lora_a.T, non_blocking=True)
          b_rank, b_out = lora_b.shape[0], lora_b.shape[1]
          self.lora_b_stacked[index, 0, :b_out, :b_rank].copy_(
              lora_b.T, non_blocking=True)

      def set_mapping(
          self,
          base_indices: torch.Tensor,
          sampler_indices: torch.Tensor,
          sampler_indices_padded: torch.Tensor,
          embeddings_indices: torch.Tensor,
          indices_len: List[int],
      ):
          """Store the token->slot indices used by apply_weights()."""
          self.indices = base_indices
          self.indices_len = indices_len

      def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
          """Base (bias-free) projection of ``x`` plus the LoRA delta."""
          base_out = self.base_layer.linear_method.apply_weights(
              self.base_layer.linear_weights, x)
          active = self.indices[:self.indices_len[0]]
          _apply_lora(x, self.lora_a_stacked, self.lora_b_stacked, active,
                      base_out)
          return base_out

      def forward(self, input_):
          """Forward of RowParallelLinear

          Args:
              input_: tensor whose last dimension is `input_size`. If
                      `input_is_parallel` is set, then the last dimension
                      is `input_size // tp_size`.

          Returns:
              - output
              - bias
          """
          base = self.base_layer
          if base.input_is_parallel:
              input_parallel = input_
          else:
              # TODO: simplify code below
              tp_rank = get_tensor_model_parallel_rank()
              pieces = split_tensor_along_last_dim(
                  input_, num_partitions=base.tp_size)
              input_parallel = pieces[tp_rank].contiguous()

          # Matrix multiply.
          output_parallel = self.apply_weights(input_parallel)
          must_reduce = base.reduce_results and base.tp_size > 1
          output_ = (tensor_model_parallel_all_reduce(output_parallel)
                     if must_reduce else output_parallel)

          if base.skip_bias_add:
              # Caller adds the bias itself.
              return output_, base.bias
          if base.bias is None:
              return output_, None
          return output_ + base.bias, None

      @property
      def weight(self):
          return self.base_layer.weight
  class SamplerWithLoRA(BaseLayerWithLoRA):
      """LoRA wrapper for :class:`Sampler`.

      Fills the logit columns from ``org_vocab_size`` onward with scores from
      each adapter's extra-token embeddings and adds the per-row LoRA delta
      to the logits.
      """

      def __init__(
          self,
          base_layer: Sampler,
          hidden_size: int,
          dtype: torch.dtype,
          device: torch.device,
      ) -> None:
          super().__init__()
          self.base_layer = base_layer
          self.hidden_size = hidden_size
          self.dtype = dtype
          self.device = device

      @property
      def vocab_size(self):
          return self.base_layer.vocab_size

      @property
      def org_vocab_size(self):
          return self.base_layer.org_vocab_size

      @property
      def include_gpu_probs_tensor(self):
          return self.base_layer.include_gpu_probs_tensor

      def create_lora_weights(
          self,
          max_loras: int,
          lora_config: LoRAConfig,
          model_config: Optional[PretrainedConfig] = None,
      ) -> None:
          """Allocate LoRA buffers for ``max_loras`` slots.

          Raises:
              ValueError: if the vocab size is outside [32000, 33024].
          """
          # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
          # The former check ``32000 < v > 33024`` is a chained comparison
          # that only rejected v > 33024; enforce both bounds.
          vocab_size = self.base_layer.vocab_size
          if not 32000 <= vocab_size <= 33024:
              raise ValueError(
                  "When using LoRA, vocab size must be 32000 <= vocab_size "
                  "<= 33024")
          self.lora_a_stacked = torch.zeros(
              (
                  max_loras,
                  1,
                  lora_config.max_lora_rank,
                  self.hidden_size,
              ),
              dtype=lora_config.lora_dtype,
              device=self.device,
          )
          self.lora_b_stacked = torch.zeros(
              (
                  max_loras,
                  1,
                  # Pad for kernel compatibility
                  math.ceil(vocab_size / lora_config.lora_vocab_padding_size) *
                  lora_config.lora_vocab_padding_size,
                  lora_config.max_lora_rank,
              ),
              dtype=lora_config.lora_dtype,
              device=self.device,
          )
          # -inf means "no extra token here" until an adapter fills the row.
          self.embeddings_tensors = torch.full(
              (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
              fill_value=float("-inf"),
              dtype=self.dtype,
              device=self.device,
          )
          self.indices = None
          self.indices_padded = None
          self.indices_len = None

      def reset_lora(self, index: int):
          """Clear slot ``index`` (LoRA buffers to 0, embeddings to -inf)."""
          self.lora_a_stacked[index] = 0
          self.lora_b_stacked[index] = 0
          self.embeddings_tensors[index] = float("-inf")

      def set_lora(
          self,
          index: int,
          lora_a: torch.Tensor,
          lora_b: torch.Tensor,
          embeddings_tensor: Optional[torch.Tensor],
      ):
          """Load one adapter (and optional extra-token embeddings)."""
          self.reset_lora(index)
          self.lora_a_stacked[index,
                              0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                  lora_a.T, non_blocking=True)
          self.lora_b_stacked[index,
                              0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                  lora_b.T, non_blocking=True)
          if embeddings_tensor is not None:
              self.embeddings_tensors[
                  index, :embeddings_tensor.shape[0], :embeddings_tensor.
                  shape[1], ] = embeddings_tensor

      def set_mapping(
          self,
          base_indices: torch.Tensor,
          sampler_indices: torch.Tensor,
          sampler_indices_padded: torch.Tensor,
          embeddings_indices: torch.Tensor,
          indices_len: List[int],
      ):
          """Store the per-sampled-row slot indices used by _get_logits()."""
          self.indices = sampler_indices
          self.indices_padded = sampler_indices_padded
          self.indices_len = indices_len

      def _get_logits(
          self,
          hidden_states: torch.Tensor,
          embedding: torch.Tensor,
          embedding_bias: Optional[torch.Tensor] = None,
      ) -> torch.Tensor:
          """Base logits, overlaid with LoRA extra-token and delta logits."""
          # Get the logits for the next tokens.
          logits = torch.matmul(hidden_states, embedding.t())
          if embedding_bias is not None:
              logits += embedding_bias
          logits = tensor_model_parallel_gather(logits)
          if logits is None:
              # NOTE(review): presumably the gather returns None on ranks
              # that do not receive the gathered logits -- confirm.
              return None

          # One extra all -inf slot is appended after the adapter slots
          # (presumably for rows with no adapter -- confirm).
          lora_logits = torch.empty(
              self.embeddings_tensors.shape[0] + 1,
              self.embeddings_tensors.shape[1],
              hidden_states.shape[0],
              dtype=self.embeddings_tensors.dtype,
              device=self.embeddings_tensors.device,
          )
          torch.matmul(self.embeddings_tensors,
                       hidden_states.T,
                       out=lora_logits[:-1])
          lora_logits[-1] = float("-inf")
          lora_logits = lora_logits.mT
          lora_logits = (lora_logits.reshape(
              lora_logits.shape[0] * lora_logits.shape[1],
              lora_logits.shape[2],
          ).index_select(0,
                         self.indices_padded[:self.indices_len[2]]).nan_to_num_(
                             nan=float("-inf"),
                             posinf=float("inf"),
                             neginf=float("-inf")))
          logits[:,
                 self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
                 lora_logits.shape[1]] = lora_logits

          _apply_lora(
              hidden_states,
              self.lora_a_stacked,
              self.lora_b_stacked,
              self.indices[:self.indices_len[1]],
              logits,
          )

          # Remove paddings in vocab (if any).
          logits = logits[:, :self.base_layer.vocab_size]

          return logits

      def forward(self, *args, **kwargs):
          # Run the base Sampler's forward with ``self`` bound, presumably so
          # that its internal ``self._get_logits`` call reaches the override
          # above.
          return type(self.base_layer).forward(self, *args, **kwargs)
  def from_layer(
          layer: nn.Module,
          max_loras: int,
          lora_config: LoRAConfig,
          model_config: Optional[PretrainedConfig] = None) -> BaseLayerWithLoRA:
      """Wrap ``layer`` in its LoRA-capable counterpart, if one exists.

      Matching uses the exact class of ``layer``, so subclasses of a listed
      type do not pick up that type's wrapper. The wrapper's LoRA buffers are
      allocated before it is returned. Layers of any other type are returned
      unchanged.
      """
      wrapper_cls = {
          VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
          ColumnParallelLinear: ColumnParallelLinearWithLoRA,
          QKVParallelLinear: QKVParallelLinearWithLora,
          MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
          RowParallelLinear: RowParallelLinearWithLoRA,
      }.get(type(layer))
      if wrapper_cls is None:
          return layer
      wrapped = wrapper_cls(layer)
      wrapped.create_lora_weights(max_loras, lora_config, model_config)
      return wrapped
  def from_layer_sampler(
      layer: Sampler,
      lm_head: ParallelLMHead,
      max_loras: int,
      lora_config: LoRAConfig,
      model_config: Optional[PretrainedConfig] = None,
  ) -> SamplerWithLoRA:
      """Wrap ``layer`` in a :class:`SamplerWithLoRA` sized after ``lm_head``.

      The wrapper takes its hidden size from ``lm_head.embedding_dim`` and its
      dtype/device from ``lm_head.weight``. Its LoRA buffers are allocated
      before it is returned.
      """
      head_weight = lm_head.weight
      wrapped = SamplerWithLoRA(
          layer,
          lm_head.embedding_dim,
          head_weight.dtype,
          head_weight.device,
      )
      wrapped.create_lora_weights(max_loras, lora_config, model_config)
      return wrapped