layers.py

# pylint: disable=unused-argument
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

from aphrodite.common.config import LoRAConfig
from aphrodite.lora.punica import add_lora, add_lora_slice, bgmv
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.megatron.communication_op import (
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
    tensor_model_parallel_gather,
)
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              RowParallelLinear,
                                              QKVParallelLinear,
                                              MergedColumnParallelLinear)
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
from aphrodite.modeling.megatron.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from aphrodite.modeling.megatron.utils import split_tensor_along_last_dim

if TYPE_CHECKING:
    pass


def _apply_lora(
    x: torch.Tensor,
    lora_a_stacked: torch.Tensor,
    lora_b_stacked: torch.Tensor,
    indices: torch.Tensor,
    output: torch.Tensor,
):
    """Applies lora to each input.

    This method applies all loras to each input. It uses the
    indices vector to determine which lora yields the
    correct output. An index of -1 means no lora should be
    applied. This method adds the final lora results to the
    output.

    Input shapes:
        x:              (batch_size, hidden_dim)
        lora_a_stacked: (num_loras, lora_rank, hidden_dim)
        lora_b_stacked: (num_loras, output_dim, lora_rank)
        indices:        (batch_size)
        output:         (batch_size, output_dim)
    """
    org_output = output
    x = x.view(-1, x.shape[-1])
    output = output.view(-1, output.shape[-1])
    indices = indices.view(-1)
    add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0)
    return output.view_as(org_output)


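# Illustrative sketch only (not used by this module): a naive PyTorch
# equivalent of the fused Punica `add_lora` call above, assuming the 3-D
# shapes documented in `_apply_lora` and matching dtypes. It shows the
# per-row computation output[i] += B[idx] @ (A[idx] @ x[i]), skipping rows
# whose index is -1 (no LoRA). The name `_apply_lora_reference` is not part
# of the original API.
def _apply_lora_reference(
    x: torch.Tensor,  # (batch_size, hidden_dim)
    lora_a_stacked: torch.Tensor,  # (num_loras, lora_rank, hidden_dim)
    lora_b_stacked: torch.Tensor,  # (num_loras, output_dim, lora_rank)
    indices: torch.Tensor,  # (batch_size,)
    output: torch.Tensor,  # (batch_size, output_dim), modified in place
) -> torch.Tensor:
    for row, lora_idx in enumerate(indices.tolist()):
        if lora_idx == -1:
            # No LoRA for this row; leave the base output untouched.
            continue
        lora_a = lora_a_stacked[lora_idx]
        lora_b = lora_b_stacked[lora_idx]
        output[row] += lora_b @ (lora_a @ x[row])
    return output

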
def _apply_lora_packed_nslice(
    x: torch.Tensor,
    lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    indices: torch.Tensor,
    output: torch.Tensor,
    output_slices: Tuple[int, ...],
):
    """Applies lora to each input.

    This method applies all loras to each input. It uses the
    indices vector to determine which lora yields the
    correct output. An index of -1 means no lora should be
    applied. This method adds the final lora results to the
    output.

    This method is used for layers that are composed of multiple sublayers
    (slices) packed together.

    Input shapes:
        x:              (batch_size, hidden_dim)
        lora_a_stacked: tuple of (num_loras, lora_rank, hidden_dim),
                        one element per slice
        lora_b_stacked: tuple of (num_loras, slice_size, lora_rank),
                        one element per slice
        indices:        (batch_size)
        output:         (batch_size, sum(output_slices))
        output_slices:  n element tuple of (slice_size...), where n is the
                        number of slices
    """
    org_output = output
    x = x.view(-1, x.shape[-1])
    output = output.view(-1, output.shape[-1])
    indices = indices.view(-1)
    offset_left = 0
    for slice_idx in range(len(output_slices)):
        add_lora_slice(output, x, lora_a_stacked[slice_idx],
                       lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left,
                       output_slices[slice_idx])
        offset_left += output_slices[slice_idx]
    return output.view_as(org_output)


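# Example of the packed layout (hypothetical sizes, for illustration only):
# for a qkv_proj shard with q_proj_shard_size=4096 and kv_proj_shard_size=1024,
# output_slices == (4096, 1024, 1024) and the three LoRA slices write into
# output columns [0:4096), [4096:5120) and [5120:6144) respectively, i.e.
# offset_left advances by each slice's size.

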
@dataclass
class LoRAMapping:
    # Per every token in input_ids:
    index_mapping: Tuple[int, ...]
    # Per sampled token:
    prompt_mapping: Tuple[int, ...]

    def __post_init__(self):
        self.index_mapping = tuple(self.index_mapping)
        self.prompt_mapping = tuple(self.prompt_mapping)


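# Example (hypothetical ids, for illustration only): for a batch with a
# 3-token prompt served by LoRA 1 and a 2-token prompt with no LoRA, a
# mapping could look like
#     LoRAMapping(index_mapping=(1, 1, 1, 0, 0), prompt_mapping=(1, 0))
# index_mapping has one entry per input token, prompt_mapping one entry per
# sequence that will be sampled from; the exact id convention (e.g. 0 or -1
# for "no LoRA") is defined by the caller that builds the mapping.

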
class BaseLayerWithLoRA(nn.Module):

    def create_lora_weights(self, max_loras: int, lora_config: LoRAConfig,
                            model_config: PretrainedConfig) -> None:
        """Initializes lora matrices."""
        ...

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Overwrites lora tensors at index."""
        ...

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        """Sets the mapping indices."""
        ...


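# Typical lifecycle of a BaseLayerWithLoRA subclass (as used below):
# create_lora_weights() allocates the stacked LoRA buffers once for up to
# `max_loras` adapter slots; set_lora()/reset_lora() fill or clear a single
# slot; set_mapping() is refreshed per batch with the token-to-LoRA index
# tensors that the apply/forward paths slice via `indices_len`.

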
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        lora_vocab_start_idx = self.base_layer.org_vocab_size
        weights_idx = None
        if self.base_layer.vocab_end_index > lora_vocab_start_idx:
            # We can start adding lora weights
            weights_idx = max(
                lora_vocab_start_idx - self.base_layer.vocab_start_index, 0)
            self.embeddings_slice = (self.base_layer.vocab_start_index -
                                     self.base_layer.org_vocab_size +
                                     weights_idx,
                                     self.base_layer.vocab_end_index -
                                     self.base_layer.org_vocab_size)
            self.embeddings_weights = self.base_layer.weight.data[weights_idx:]
            self.embeddings_weights.fill_(0)
        else:
            self.embeddings_slice = None
            self.embeddings_weights = None
        self.embeddings_tensors = torch.zeros(
            (
                max_loras,
                lora_config.lora_extra_vocab_size,
                self.base_layer.embedding_dim,
            ),
            dtype=self.base_layer.weight.dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                self.base_layer.org_vocab_size +
                lora_config.lora_extra_vocab_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.embedding_dim,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked_2d = self.lora_a_stacked.view(
            self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
            self.lora_a_stacked.shape[2],
        )
        self.indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None
        self.embeddings_indices = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
            lora_a, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1]].copy_(embeddings_tensor, non_blocking=True)
            if self.embeddings_slice is not None:
                # TODO(yard1): Optimize this copy, we don't need to copy
                # everything, just the modified part
                embeddings = self.embeddings_tensors.view(
                    self.embeddings_tensors.shape[0] *
                    self.embeddings_tensors.shape[1],
                    self.embeddings_tensors.shape[2]
                )[self.embeddings_slice[0]:self.embeddings_slice[1]]
                self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.embeddings_indices = embeddings_indices
        self.indices_len = indices_len

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        added_tokens_mask = x > self.base_layer.org_vocab_size - 1
        indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x)
        full_lora_a_embeddings = F.embedding(
            x + indices,
            self.lora_a_stacked_2d,
        )
        indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x)
        full_output = self.base_layer.forward(
            x.add_(indices * added_tokens_mask))

        full_output_org = full_output
        if full_output.ndim == 3:
            full_output = full_output.view(
                full_output.shape[0] * full_output.shape[1], -1)
        if full_lora_a_embeddings.ndim == 3:
            full_lora_a_embeddings = full_lora_a_embeddings.view(
                full_lora_a_embeddings.shape[0] *
                full_lora_a_embeddings.shape[1], -1)
        bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        return full_output.view_as(full_output_org)


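# Note on the forward pass above: `embeddings_indices[1]` holds per-token
# offsets that select the owning LoRA's block inside the flattened table
# `lora_a_stacked_2d`, and `embeddings_indices[0]` shifts extra-vocab token
# ids (ids >= org_vocab_size) into the region of the base embedding weight
# where that LoRA's extra-vocab rows were copied by set_lora(). The bgmv
# kernel then adds lora_b @ lora_a_embedding per token via `self.indices`.

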
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.lora_a_stacked = torch.zeros(
            max_loras,
            1,
            lora_config.max_lora_rank,
            self.base_layer.weight.shape[1],
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            max_loras,
            1,
            self.base_layer.weight.shape[0],
            lora_config.max_lora_rank,
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None
        self.output_dim = self.lora_b_stacked.shape[2]

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.indices_len = indices_len

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer.linear_weights, x, bias)
        _apply_lora(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
        )
        return output

    def forward(self, input_):
        """Forward of ColumnParallelLinear

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)

        # Matrix multiply.
        output_parallel = self.apply_weights(input_, bias)
        if self.base_layer.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @property
    def linear_weights(self):
        return self.base_layer.linear_weights


class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
    packed together (eg. gate_proj + up_proj -> gate_up_proj).

    This means we have 2 LoRAs, each applied to one half of the layer.

    Both slices must have the same size.
    """

    def __init__(self, base_layer: MergedColumnParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        n_slices = 2
        if not (len(self.base_layer.output_sizes) == n_slices
                and self.base_layer.output_sizes[0]
                == self.base_layer.output_sizes[1]):
            raise ValueError(
                "LoRAColumnParallelLinear2Slice requires 2 slices with "
                "the same size.")
        self.tp_size = get_tensor_model_parallel_world_size()

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.weight.shape[1],
                dtype=lora_config.lora_dtype,
                device=self.base_layer.weight.device,
            ) for _ in range(n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                self.base_layer.weight.shape[0] // 2,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.base_layer.weight.device,
            ) for _ in range(n_slices))

        self.indices: Optional[torch.Tensor] = None
        self.output_dim = self.lora_b_stacked[0].shape[2]

    def reset_lora(self, index: int):
        self.lora_a_stacked[0][index] = 0
        self.lora_a_stacked[1][index] = 0
        self.lora_b_stacked[0][index] = 0
        self.lora_b_stacked[1][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)

        if self.tp_size > 1:
            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
            shard_size = self.output_dim
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            lora_b = (lora_b[0][:, start_idx:end_idx],
                      lora_b[1][:, start_idx:end_idx])

        if lora_a[0] is not None:
            self.lora_a_stacked[0][
                index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
                    lora_a[0].T, non_blocking=True)
            self.lora_b_stacked[0][
                index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
                    lora_b[0].T, non_blocking=True)
        if lora_a[1] is not None:
            self.lora_a_stacked[1][
                index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
                    lora_a[1].T, non_blocking=True)
            self.lora_b_stacked[1][
                index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
                    lora_b[1].T, non_blocking=True)

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer.linear_weights, x, bias)
        _apply_lora_packed_nslice(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
            (self.output_dim, self.output_dim),
        )
        return output


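# Note: for the packed layers above (and QKVParallelLinearWithLora below),
# the `lora_a` and `lora_b` arguments of set_lora() are per-slice sequences
# (indexed [0]/[1]/...), produced by the LoRA packing logic, even though the
# base signature annotates them as single tensors; a missing slice is None.

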
class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 3 sublayers (slices)
    packed together in qkv proj fashion
    (q_proj + k_proj + v_proj -> qkv_proj).

    This means we have 3 LoRAs, each applied to one slice of the layer.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.q_shard_id = tp_rank
        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas

        # q, k, v
        self.lora_a_stacked = (
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.weight.shape[1],
                dtype=lora_config.lora_dtype,
                device=self.base_layer.weight.device,
            ),
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.weight.shape[1],
                dtype=lora_config.lora_dtype,
                device=self.base_layer.weight.device,
            ),
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.weight.shape[1],
                dtype=lora_config.lora_dtype,
                device=self.base_layer.weight.device,
            ),
        )
        self.lora_b_stacked = (
            torch.zeros(
                max_loras,
                1,
                self.q_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.base_layer.weight.device,
            ),
            torch.zeros(
                max_loras,
                1,
                self.kv_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.base_layer.weight.device,
            ),
            torch.zeros(
                max_loras,
                1,
                self.kv_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.base_layer.weight.device,
            ),
        )

        self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size,
                              self.kv_proj_shard_size)
        self.packed_indices: Optional[torch.Tensor] = None
        self.standard_indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[0][index] = 0
        self.lora_b_stacked[0][index] = 0
        self.lora_a_stacked[1][index] = 0
        self.lora_b_stacked[1][index] = 0
        self.lora_a_stacked[2][index] = 0
        self.lora_b_stacked[2][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)

        if self.tp_size > 1:
            if lora_b[0] is not None:
                lora_b_q = lora_b[0][:, self.q_proj_shard_size *
                                     self.q_shard_id:self.q_proj_shard_size *
                                     (self.q_shard_id + 1)]
                self.lora_b_stacked[0][
                    index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
                        lora_b_q.T, non_blocking=True)
            if lora_b[1] is not None:
                lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
                                     self.kv_shard_id:self.kv_proj_shard_size *
                                     (self.kv_shard_id + 1)]
                self.lora_b_stacked[1][
                    index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
                        lora_b_k.T, non_blocking=True)
            if lora_b[2] is not None:
                lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
                                     self.kv_shard_id:self.kv_proj_shard_size *
                                     (self.kv_shard_id + 1)]
                self.lora_b_stacked[2][
                    index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
                        lora_b_v.T, non_blocking=True)
        else:
            if lora_b[0] is not None:
                self.lora_b_stacked[0][
                    index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
                        lora_b[0].T, non_blocking=True)
            if lora_b[1] is not None:
                self.lora_b_stacked[1][
                    index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
                        lora_b[1].T, non_blocking=True)
            if lora_b[2] is not None:
                self.lora_b_stacked[2][
                    index, 0, :lora_b[2].shape[1], :lora_b[2].shape[0]].copy_(
                        lora_b[2].T, non_blocking=True)

        if lora_a[0] is not None:
            self.lora_a_stacked[0][
                index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
                    lora_a[0].T, non_blocking=True)
        if lora_a[1] is not None:
            self.lora_a_stacked[1][
                index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
                    lora_a[1].T, non_blocking=True)
        if lora_a[2] is not None:
            self.lora_a_stacked[2][
                index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
                    lora_a[2].T, non_blocking=True)

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer.linear_weights, x, bias)
        _apply_lora_packed_nslice(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
            self.output_slices,
        )
        return output


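# Example of the per-rank shard arithmetic above (hypothetical sizes, for
# illustration only): with head_size=128, 32 query heads and 8 KV heads split
# across tp_size=4, each rank holds num_heads=8 and num_kv_heads=2, so
# q_proj_shard_size=1024, kv_proj_shard_size=256 and
# output_slices=(1024, 256, 256); set_lora() then copies only the matching
# column block of each full lora_b into the stacked buffers for that rank.

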
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: RowParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.weight.shape[1],
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.weight.shape[0],
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.base_layer.tp_size > 1:
            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
            shard_size = self.base_layer.weight.shape[1]
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            lora_a = lora_a[start_idx:end_idx, :]

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.indices_len = indices_len

    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer.linear_weights, x)
        _apply_lora(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
        )
        return output

    def forward(self, input_):
        """Forward of RowParallelLinear

        Args:
            input_: tensor whose last dimension is `input_size`. If
                    `input_is_parallel` is set, then the last dimension
                    is `input_size // tp_size`.

        Returns:
            - output
            - bias
        """
        # Set up backprop all-reduce.
        if self.base_layer.input_is_parallel:
            input_parallel = input_
        else:
            # TODO: simplify code below
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.base_layer.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        output_parallel = self.apply_weights(input_parallel)
        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.base_layer.skip_bias_add:
            output = (output_ + self.base_layer.bias
                      if self.base_layer.bias is not None else output_)
            output_bias = None
        else:
            output = output_
            output_bias = self.base_layer.bias
        return output, output_bias

    @property
    def weight(self):
        return self.base_layer.weight


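# Note on tensor parallelism above: RowParallelLinear shards its weight along
# the input dimension, so set_lora() slices lora_a by rows (the adapter's
# input dim) to match this rank's input shard, while lora_b stays complete;
# each rank's partial LoRA contribution is then combined by the same
# all-reduce that merges the base layer's output.

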
class SamplerWithLoRA(BaseLayerWithLoRA):

    def __init__(
        self,
        base_layer: Sampler,
        hidden_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device

    @property
    def vocab_size(self):
        return self.base_layer.vocab_size

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        return self.base_layer.include_gpu_probs_tensor

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
        if not (32000 <= self.base_layer.vocab_size <= 33024):
            raise ValueError(
                "When using LoRA, vocab size must satisfy "
                "32000 <= vocab_size <= 33024")
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                # Pad for kernel compatibility
                math.ceil(self.base_layer.vocab_size /
                          lora_config.lora_vocab_padding_size) *
                lora_config.lora_vocab_padding_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.embeddings_tensors = torch.full(
            (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
            fill_value=float("-inf"),
            dtype=self.dtype,
            device=self.device,
        )
        self.indices = None
        self.indices_padded = None
        self.indices_len = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = float("-inf")

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1]] = embeddings_tensor

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = sampler_indices
        self.indices_padded = sampler_indices_padded
        self.indices_len = indices_len

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        embedding: torch.Tensor,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Get the logits for the next tokens.
        logits = torch.matmul(hidden_states, embedding.t())
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        if logits is None:
            return None

        lora_logits = torch.empty(
            self.embeddings_tensors.shape[0] + 1,
            self.embeddings_tensors.shape[1],
            hidden_states.shape[0],
            dtype=self.embeddings_tensors.dtype,
            device=self.embeddings_tensors.device,
        )
        torch.matmul(self.embeddings_tensors,
                     hidden_states.T,
                     out=lora_logits[:-1])
        lora_logits[-1] = float("-inf")
        lora_logits = lora_logits.mT
        lora_logits = (lora_logits.reshape(
            lora_logits.shape[0] * lora_logits.shape[1],
            lora_logits.shape[2],
        ).index_select(0,
                       self.indices_padded[:self.indices_len[2]]).nan_to_num_(
                           nan=float("-inf"),
                           posinf=float("inf"),
                           neginf=float("-inf")))
        logits[:,
               self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
               lora_logits.shape[1]] = lora_logits

        _apply_lora(
            hidden_states,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[1]],
            logits,
        )

        # Remove paddings in vocab (if any).
        logits = logits[:, :self.base_layer.vocab_size]
        return logits

    def forward(self, *args, **kwargs):
        return type(self.base_layer).forward(self, *args, **kwargs)


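# Note on _get_logits above: `lora_logits` scores each LoRA's extra-vocab
# embeddings against the hidden states; the extra trailing block initialized
# to -inf acts as the slot selected (via the padded sampler indices) for
# tokens that use no LoRA, so their extra-vocab logits stay at -inf. The
# regular LoRA delta on the LM head is then added with _apply_lora using the
# sampler indices.

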
def from_layer(
        layer: nn.Module,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None) -> BaseLayerWithLoRA:
    supported_layer_types = {
        VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
        ColumnParallelLinear: ColumnParallelLinearWithLoRA,
        QKVParallelLinear: QKVParallelLinearWithLora,
        MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
        RowParallelLinear: RowParallelLinearWithLoRA,
    }
    for src_layer_type, lora_layer_type in supported_layer_types.items():
        if type(layer) is src_layer_type:  # pylint: disable=unidiomatic-typecheck
            ret = lora_layer_type(layer)
            ret.create_lora_weights(max_loras, lora_config, model_config)
            return ret
    return layer


def from_layer_sampler(
    layer: Sampler,
    lm_head: ParallelLMHead,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> SamplerWithLoRA:
    ret = SamplerWithLoRA(layer, lm_head.embedding_dim, lm_head.weight.dtype,
                          lm_head.weight.device)
    ret.create_lora_weights(max_loras, lora_config, model_config)
    return ret
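

# Minimal usage sketch (hypothetical model/config objects, for illustration
# only): when preparing a model for LoRA serving, each supported submodule is
# replaced with its LoRA-aware counterpart, e.g.
#
#     module = from_layer(module, max_loras=8, lora_config=lora_config,
#                         model_config=model_config)
#     sampler = from_layer_sampler(sampler, lm_head, 8, lora_config,
#                                  model_config)
#
# from_layer() returns the original layer unchanged when its type is not in
# `supported_layer_types`.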