# pylint: disable=unused-argument
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

from aphrodite.common.config import LoRAConfig
from aphrodite.lora.punica import add_lora, add_lora_slice, bgmv
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.megatron.communication_op import (
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
    tensor_model_parallel_gather,
)
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              RowParallelLinear,
                                              QKVParallelLinear,
                                              MergedColumnParallelLinear)
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
from aphrodite.modeling.megatron.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from aphrodite.modeling.megatron.utils import split_tensor_along_last_dim

if TYPE_CHECKING:
    pass


def _get_lora_device(base_layer: nn.Module) -> torch.device:
    """Identify the device for positioning the LoRA tensors."""
    device = None
    try:
        device = base_layer.weight.device
    except AttributeError:
        try:
            linear_weights = base_layer.linear_weights
            if isinstance(linear_weights, dict):
                tensor_values = [
                    v for v in linear_weights.values()
                    if isinstance(v, torch.Tensor)
                ]
                if tensor_values:
                    device = tensor_values[0].device
        except AttributeError:
            pass
    if device is None:
        raise ValueError(f"Base layer not supported: {base_layer}")
    return device


def _apply_lora(
    x: torch.Tensor,
    lora_a_stacked: torch.Tensor,
    lora_b_stacked: torch.Tensor,
    indices: torch.Tensor,
    output: torch.Tensor,
):
    """Applies lora to each input.

    This method applies all loras to each input. It uses the
    indices vector to determine which lora yields the
    correct output. An index of -1 means no lora should be
    applied. This method adds the final lora results to the
    output.

    Input shapes:
        x:               (batch_size, hidden_dim)
        lora_a_stacked:  (num_loras, lora_rank, hidden_dim)
        lora_b_stacked:  (num_loras, output_dim, lora_rank)
        indices:         (batch_size)
        output:          (batch_size, output_dim)
    """
    org_output = output
    x = x.view(-1, x.shape[-1])
    output = output.view(-1, output.shape[-1])
    indices = indices.view(-1)
    add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0)
    return output.view_as(org_output)
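

# For reference, a minimal pure-PyTorch sketch of what the fused `add_lora`
# call above computes. Illustrative only: it assumes the shapes documented in
# `_apply_lora` and ignores the layer-index/scale arguments (passed as 0 and
# 1.0 above); the punica kernel performs this batched on the GPU. This helper
# is not used anywhere in this module.
def _apply_lora_reference(
    x: torch.Tensor,
    lora_a_stacked: torch.Tensor,
    lora_b_stacked: torch.Tensor,
    indices: torch.Tensor,
    output: torch.Tensor,
) -> torch.Tensor:
    for i, idx in enumerate(indices.tolist()):
        if idx < 0:
            # An index of -1 means no LoRA is applied to this row.
            continue
        a = lora_a_stacked[idx]  # (lora_rank, hidden_dim)
        b = lora_b_stacked[idx]  # (output_dim, lora_rank)
        output[i] += x[i] @ a.T @ b.T
    return output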


def _apply_lora_packed_nslice(
    x: torch.Tensor,
    lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    indices: torch.Tensor,
    output: torch.Tensor,
    output_slices: Tuple[int, ...],
):
    """Applies lora to each input.

    This method applies all loras to each input. It uses the
    indices vector to determine which lora yields the
    correct output. An index of -1 means no lora should be
    applied. This method adds the final lora results to the
    output.

    This method is used for layers that are composed of multiple sublayers
    (slices) packed together.

    Input shapes:
        x:                 (batch_size, hidden_dim)
        lora_a_stacked:    3 element tuple of (num_loras, lora_rank, hidden_dim)
        lora_b_stacked:    3 element tuple of (num_loras, output_dim, lora_rank)
        indices:           (batch_size)
        output:            (batch_size, q_slice_size + 2*kv_slice_size)
        output_slices:     n element tuple of (slice_size...), where n is
                           the number of slices
  96. """
  97. org_output = output
  98. x = x.view(-1, x.shape[-1])
  99. output = output.view(-1, output.shape[-1])
  100. indices = indices.view(-1)
  101. offset_left = 0
  102. for slice_idx in range(len(output_slices)):
  103. add_lora_slice(output, x, lora_a_stacked[slice_idx],
  104. lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left,
  105. output_slices[slice_idx])
  106. offset_left += output_slices[slice_idx]
  107. return output.view_as(org_output)
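

# Note on `output_slices` (grounded in the callers below): for a packed QKV
# projection it is (q_proj_shard_size, kv_proj_shard_size, kv_proj_shard_size)
# and for a merged gate/up projection it is (output_dim, output_dim), so each
# slice's LoRA delta is added at its running column offset in the packed
# output.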


@dataclass
class LoRAMapping:
    # Per every token in input_ids:
    index_mapping: Tuple[int, ...]
    # Per sampled token:
    prompt_mapping: Tuple[int, ...]

    def __post_init__(self):
        self.index_mapping = tuple(self.index_mapping)
        self.prompt_mapping = tuple(self.prompt_mapping)
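
    # Illustrative example (hedged; adapter ids are made up): for a batch
    # where the first prompt's 3 tokens are served by adapter 1 and the
    # second prompt's 2 tokens by adapter 2, the mapping would be
    #
    #     LoRAMapping(index_mapping=(1, 1, 1, 2, 2),  # per input token
    #                 prompt_mapping=(1, 2))          # per sampled token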


class BaseLayerWithLoRA(nn.Module):

    def create_lora_weights(self, max_loras: int, lora_config: LoRAConfig,
                            model_config: PretrainedConfig) -> None:
        """Initializes lora matrices."""
        ...

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Overwrites lora tensors at index."""
        ...

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        """Sets the mapping indices."""
        ...
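
    # Hedged note: as used throughout this file, `indices_len` holds the
    # active lengths of [base_indices, sampler_indices,
    # sampler_indices_padded, embeddings_indices], in that order.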


class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        lora_vocab_start_idx = self.base_layer.org_vocab_size
        weights_idx = None
        if self.base_layer.vocab_end_index > lora_vocab_start_idx:
            # We can start adding lora weights
            weights_idx = max(
                lora_vocab_start_idx - self.base_layer.vocab_start_index, 0)
            self.embeddings_slice = (self.base_layer.vocab_start_index -
                                     self.base_layer.org_vocab_size +
                                     weights_idx,
                                     self.base_layer.vocab_end_index -
                                     self.base_layer.org_vocab_size)
            self.embeddings_weights = self.base_layer.weight.data[weights_idx:]
            self.embeddings_weights.fill_(0)
        else:
            self.embeddings_slice = None
            self.embeddings_weights = None
        self.embeddings_tensors = torch.zeros(
            (
                max_loras,
                lora_config.lora_extra_vocab_size,
                self.base_layer.embedding_dim,
            ),
            dtype=self.base_layer.weight.dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                self.base_layer.org_vocab_size +
                lora_config.lora_extra_vocab_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.embedding_dim,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked_2d = self.lora_a_stacked.view(
            self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
            self.lora_a_stacked.shape[2],
        )
        self.indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None
        self.embeddings_indices = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
            lora_a, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1]].copy_(embeddings_tensor, non_blocking=True)
            if self.embeddings_slice is not None:
                # TODO(yard1): Optimize this copy, we don't need to copy
                # everything, just the modified part
                embeddings = self.embeddings_tensors.view(
                    self.embeddings_tensors.shape[0] *
                    self.embeddings_tensors.shape[1],
                    self.embeddings_tensors.shape[2]
                )[self.embeddings_slice[0]:self.embeddings_slice[1]]
                self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.embeddings_indices = embeddings_indices
        self.indices_len = indices_len

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        added_tokens_mask = x > self.base_layer.org_vocab_size - 1
        indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x)
        full_lora_a_embeddings = F.embedding(
            x + indices,
            self.lora_a_stacked_2d,
        )
        indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x)
        full_output = self.base_layer.forward(
            x.add_(indices * added_tokens_mask))
        full_output_org = full_output
        if full_output.ndim == 3:
            full_output = full_output.view(
                full_output.shape[0] * full_output.shape[1], -1)
        if full_lora_a_embeddings.ndim == 3:
            full_lora_a_embeddings = full_lora_a_embeddings.view(
                full_lora_a_embeddings.shape[0] *
                full_lora_a_embeddings.shape[1], -1)
        bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        return full_output.view_as(full_output_org)
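
    # Hedged summary of the index plumbing in forward above:
    # `lora_a_stacked_2d` is `lora_a_stacked` flattened to
    # (num_loras * (org_vocab + extra_vocab), rank), so adding
    # `embeddings_indices[1]` to the token ids gathers the correct adapter's
    # LoRA-A rows in a single F.embedding call, while `embeddings_indices[0]`
    # shifts added (extra-vocab) token ids so the base embedding reads that
    # adapter's appended rows.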


class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.lora_a_stacked = torch.zeros(
            max_loras,
            1,
            lora_config.max_lora_rank,
            self.base_layer.weight.shape[1],
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            max_loras,
            1,
            self.base_layer.weight.shape[0],
            lora_config.max_lora_rank,
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None
        self.output_dim = self.lora_b_stacked.shape[1]

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.indices_len = indices_len

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer.linear_weights, x, bias)
        _apply_lora(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
        )
        return output

    def forward(self, input_):
        """Forward of ColumnParallelLinear

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)

        # Matrix multiply.
        output_parallel = self.apply_weights(input_, bias)
        if self.base_layer.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @property
    def linear_weights(self):
        return self.base_layer.linear_weights
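
    # Hedged note on orientation: the transposed copies in set_lora above
    # imply that lora_a arrives as (input_dim, rank) and lora_b as
    # (rank, output_dim); they are stored as (rank, input_dim) and
    # (output_dim, rank), presumably the layout the punica kernels expect.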


class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
    packed together (eg. gate_proj + up_proj -> gate_up_proj).

    This means we have 2 LoRAs, each applied to one half of the layer.

    Both slices must have the same size.
    """

    def __init__(self, base_layer: MergedColumnParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        n_slices = 2
        if not (len(self.base_layer.output_sizes) == n_slices
                and self.base_layer.output_sizes[0]
                == self.base_layer.output_sizes[1]):
            raise ValueError(
                "LoRAColumnParallelLinear2Slice requires 2 slices with "
                "the same size.")
        self.tp_size = get_tensor_model_parallel_world_size()
        device = _get_lora_device(self.base_layer)

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.input_size,
                dtype=lora_config.lora_dtype,
                device=device,
            ) for _ in range(n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                self.base_layer.output_size // 2,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=device,
            ) for _ in range(n_slices))

        self.indices: Optional[torch.Tensor] = None
        self.output_dim = self.lora_b_stacked[0].shape[2]

    def reset_lora(self, index: int):
        self.lora_a_stacked[0][index] = 0
        self.lora_a_stacked[1][index] = 0
        self.lora_b_stacked[0][index] = 0
        self.lora_b_stacked[1][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)

        if self.tp_size > 1:
            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
            shard_size = self.output_dim
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            lora_b = lora_b[0][:,
                               start_idx:end_idx], lora_b[1][:,
                                                             start_idx:end_idx]

        if lora_a[0] is not None:
            self.lora_a_stacked[0][
                index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
                    lora_a[0].T, non_blocking=True)
            self.lora_b_stacked[0][
                index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
                    lora_b[0].T, non_blocking=True)
        if lora_a[1] is not None:
            self.lora_a_stacked[1][
                index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
                    lora_a[1].T, non_blocking=True)
            self.lora_b_stacked[1][
                index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
                    lora_b[1].T, non_blocking=True)

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer.linear_weights, x, bias)
        _apply_lora_packed_nslice(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
            (self.output_dim, self.output_dim),
        )
        return output


class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 3 sublayers (slices)
    packed together in qkv proj fashion
    (q_proj + k_proj + v_proj -> qkv_proj).

    This means we have 3 LoRAs, each applied to one slice of the layer.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.q_shard_id = tp_rank
        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
        device = _get_lora_device(self.base_layer)

        # q, k, v
        self.lora_a_stacked = (
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.input_size,
                dtype=lora_config.lora_dtype,
                device=device,
            ),
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.input_size,
                dtype=lora_config.lora_dtype,
                device=device,
            ),
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.input_size,
                dtype=lora_config.lora_dtype,
                device=device,
            ),
        )
        self.lora_b_stacked = (
            torch.zeros(
                max_loras,
                1,
                self.q_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=device,
            ),
            torch.zeros(
                max_loras,
                1,
                self.kv_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=device,
            ),
            torch.zeros(
                max_loras,
                1,
                self.kv_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=device,
            ),
        )

        self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size,
                              self.kv_proj_shard_size)
        self.packed_indices: Optional[torch.Tensor] = None
        self.standard_indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[0][index] = 0
        self.lora_b_stacked[0][index] = 0
        self.lora_a_stacked[1][index] = 0
        self.lora_b_stacked[1][index] = 0
        self.lora_a_stacked[2][index] = 0
        self.lora_b_stacked[2][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)

        if self.tp_size > 1:
            if lora_b[0] is not None:
                lora_b_q = lora_b[0][:, self.q_proj_shard_size *
                                     self.q_shard_id:self.q_proj_shard_size *
                                     (self.q_shard_id + 1)]
                self.lora_b_stacked[0][
                    index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
                        lora_b_q.T, non_blocking=True)
            if lora_b[1] is not None:
                lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
                                     self.kv_shard_id:self.kv_proj_shard_size *
                                     (self.kv_shard_id + 1)]
                self.lora_b_stacked[1][
                    index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
                        lora_b_k.T, non_blocking=True)
            if lora_b[2] is not None:
                lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
                                     self.kv_shard_id:self.kv_proj_shard_size *
                                     (self.kv_shard_id + 1)]
                self.lora_b_stacked[2][
                    index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
                        lora_b_v.T, non_blocking=True)
        else:
            if lora_b[0] is not None:
                self.lora_b_stacked[0][
                    index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
                        lora_b[0].T, non_blocking=True)
            if lora_b[1] is not None:
                self.lora_b_stacked[1][
                    index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
                        lora_b[1].T, non_blocking=True)
            if lora_b[2] is not None:
                self.lora_b_stacked[2][
                    index, 0, :lora_b[2].shape[1], :lora_b[2].shape[0]].copy_(
                        lora_b[2].T, non_blocking=True)

        if lora_a[0] is not None:
            self.lora_a_stacked[0][
                index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
                    lora_a[0].T, non_blocking=True)
        if lora_a[1] is not None:
            self.lora_a_stacked[1][
                index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
                    lora_a[1].T, non_blocking=True)
        if lora_a[2] is not None:
            self.lora_a_stacked[2][
                index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
                    lora_a[2].T, non_blocking=True)

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer.linear_weights, x, bias)
        _apply_lora_packed_nslice(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
            self.output_slices,
        )
        return output
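
    # Hedged GQA example for the shard ids computed in create_lora_weights:
    # with tp_size == 8 and 2 total KV heads, QKVParallelLinear typically
    # replicates each KV head across 4 ranks (num_kv_head_replicas == 4), so
    # ranks 0-3 use kv_shard_id 0 and ranks 4-7 use kv_shard_id 1, while
    # q_shard_id is simply the tensor-parallel rank. Numbers are illustrative.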


class RowParallelLinearWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: RowParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        device = _get_lora_device(self.base_layer)

        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.input_size,
            ),
            dtype=lora_config.lora_dtype,
            device=device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.output_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=device,
        )
        self.indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.base_layer.tp_size > 1:
            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
            shard_size = self.base_layer.weight.shape[1]
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            lora_a = lora_a[start_idx:end_idx, :]

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
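
    # Hedged note: lora_a is sliced along its input dimension above to match
    # this rank's input shard, while lora_b is kept whole; each rank's
    # partial LoRA product is then summed by the all-reduce in forward,
    # just like the base layer's partial matmul.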

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.indices_len = indices_len

    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer.linear_weights, x)
        _apply_lora(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
        )
        return output

    def forward(self, input_):
        """Forward of RowParallelLinear

        Args:
            input_: tensor whose last dimension is `input_size`. If
                    `input_is_parallel` is set, then the last dimension
                    is `input_size // tp_size`.

        Returns:
            - output
            - bias
        """
        # Set up backprop all-reduce.
        if self.base_layer.input_is_parallel:
            input_parallel = input_
        else:
            # TODO: simplify code below
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.base_layer.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        output_parallel = self.apply_weights(input_parallel)
        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.base_layer.skip_bias_add:
            output = (output_ + self.base_layer.bias
                      if self.base_layer.bias is not None else output_)
            output_bias = None
        else:
            output = output_
            output_bias = self.base_layer.bias
        return output, output_bias

    @property
    def weight(self):
        return self.base_layer.weight


class SamplerWithLoRA(BaseLayerWithLoRA):

    def __init__(
        self,
        base_layer: Sampler,
        hidden_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device

    @property
    def vocab_size(self):
        return self.base_layer.vocab_size

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        return self.base_layer.include_gpu_probs_tensor

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
        if not (32000 <= self.base_layer.vocab_size <= 33024):
            raise ValueError(
                "When using LoRA, vocab size must satisfy "
                "32000 <= vocab_size <= 33024")
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                # Pad for kernel compatibility
                math.ceil(self.base_layer.vocab_size /
                          lora_config.lora_vocab_padding_size) *
                lora_config.lora_vocab_padding_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.embeddings_tensors = torch.full(
            (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
            fill_value=float("-inf"),
            dtype=self.dtype,
            device=self.device,
        )
        self.indices = None
        self.indices_padded = None
        self.indices_len = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = float("-inf")

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1], ] = embeddings_tensor

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = sampler_indices
        self.indices_padded = sampler_indices_padded
        self.indices_len = indices_len

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        embedding: torch.Tensor,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Get the logits for the next tokens.
        logits = torch.matmul(hidden_states, embedding.t())
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        if logits is None:
            return None

        lora_logits = torch.empty(
            self.embeddings_tensors.shape[0] + 1,
            self.embeddings_tensors.shape[1],
            hidden_states.shape[0],
            dtype=self.embeddings_tensors.dtype,
            device=self.embeddings_tensors.device,
        )
        torch.matmul(self.embeddings_tensors,
                     hidden_states.T,
                     out=lora_logits[:-1])
        lora_logits[-1] = float("-inf")
        lora_logits = lora_logits.mT
        lora_logits = (lora_logits.reshape(
            lora_logits.shape[0] * lora_logits.shape[1],
            lora_logits.shape[2],
        ).index_select(0,
                       self.indices_padded[:self.indices_len[2]]).nan_to_num_(
                           nan=float("-inf"),
                           posinf=float("inf"),
                           neginf=float("-inf")))
        logits[:,
               self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
               lora_logits.shape[1]] = lora_logits

        _apply_lora(
            hidden_states,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[1]],
            logits,
        )

        # Remove paddings in vocab (if any).
        logits = logits[:, :self.base_layer.vocab_size]
        return logits

    def forward(self, *args, **kwargs):
        return type(self.base_layer).forward(self, *args, **kwargs)


def from_layer(
        layer: nn.Module,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None) -> BaseLayerWithLoRA:
    supported_layer_types = {
        VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
        ColumnParallelLinear: ColumnParallelLinearWithLoRA,
        QKVParallelLinear: QKVParallelLinearWithLora,
        MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
        RowParallelLinear: RowParallelLinearWithLoRA,
    }
    for src_layer_type, lora_layer_type in supported_layer_types.items():
        if type(layer) is src_layer_type:  # pylint: disable=unidiomatic-typecheck
            ret = lora_layer_type(layer)
            ret.create_lora_weights(max_loras, lora_config, model_config)
            return ret
    return layer
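

# Hedged usage sketch for from_layer (illustrative only; `linear`,
# `max_loras`, `lora_config` and `model_config` are assumed to come from the
# surrounding model setup):
#
#     lora_layer = from_layer(linear, max_loras, lora_config, model_config)
#     # Layers of unsupported types are returned unchanged.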


def from_layer_sampler(
    layer: Sampler,
    lm_head: ParallelLMHead,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> SamplerWithLoRA:
    ret = SamplerWithLoRA(layer, lm_head.embedding_dim, lm_head.weight.dtype,
                          lm_head.weight.device)
    ret.create_lora_weights(max_loras, lora_config, model_config)
    return ret
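

# Hedged usage sketch for from_layer_sampler (illustrative only; `sampler`
# and `lm_head` are assumed to be the model's existing Sampler and
# ParallelLMHead):
#
#     lora_sampler = from_layer_sampler(sampler, lm_head, max_loras,
#                                       lora_config, model_config)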