layers.py

# pylint: disable=unused-argument
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

from aphrodite.common.config import LoRAConfig
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   split_tensor_along_last_dim,
                                   tensor_model_parallel_all_gather,
                                   tensor_model_parallel_all_reduce,
                                   tensor_model_parallel_gather)
from aphrodite.lora.punica import add_lora, add_lora_slice, bgmv
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              RowParallelLinear,
                                              QKVParallelLinear,
                                              MergedColumnParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)

if TYPE_CHECKING:
    pass


def _get_lora_device(base_layer: nn.Module) -> torch.device:
    """Identify the device for positioning the LoRA tensors."""
    device = None
    try:
        device = base_layer.weight.device
    except AttributeError:
        try:
            linear_weights = base_layer.linear_weights
            if isinstance(linear_weights, dict):
                tensor_values = [
                    v for v in linear_weights.values()
                    if isinstance(v, torch.Tensor)
                ]
                if tensor_values:
                    device = tensor_values[0].device
        except AttributeError:
            pass
    if device is None:
        raise ValueError(f"Base layer not supported: {base_layer}")
    return device


def _apply_lora(
    x: torch.Tensor,
    lora_a_stacked: torch.Tensor,
    lora_b_stacked: torch.Tensor,
    indices: torch.Tensor,
    output: torch.Tensor,
):
    """Applies lora to each input.

    This method applies all loras to each input. It uses the
    indices vector to determine which lora yields the
    correct output. An index of -1 means no lora should be
    applied. This method adds the final lora results to the
    output.

    Input shapes:
        x:              (batch_size, hidden_dim)
        lora_a_stacked: (num_loras, lora_rank, hidden_dim)
        lora_b_stacked: (num_loras, output_dim, lora_rank)
        indices:        (batch_size)
        output:         (batch_size, output_dim)
    """
    org_output = output
    x = x.view(-1, x.shape[-1])
    output = output.view(-1, output.shape[-1])
    indices = indices.view(-1)
    add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0)
    return output.view_as(org_output)
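

# Illustrative reference (not used by this module): a plain PyTorch sketch of
# the computation the punica `add_lora` kernel performs in `_apply_lora`
# above, using the shapes given in that docstring. The stacked buffers
# created by the layers below carry an extra singleton dim that would need to
# be squeezed before passing them here.
def _apply_lora_reference(
    x: torch.Tensor,
    lora_a_stacked: torch.Tensor,
    lora_b_stacked: torch.Tensor,
    indices: torch.Tensor,
    output: torch.Tensor,
) -> torch.Tensor:
    for i, lora_idx in enumerate(indices.tolist()):
        if lora_idx < 0:
            # An index of -1 means this token has no LoRA applied.
            continue
        lora_a = lora_a_stacked[lora_idx]  # (lora_rank, hidden_dim)
        lora_b = lora_b_stacked[lora_idx]  # (output_dim, lora_rank)
        output[i] += lora_b @ (lora_a @ x[i])
    return output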


def _apply_lora_packed_nslice(
    x: torch.Tensor,
    lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    indices: torch.Tensor,
    output: torch.Tensor,
    output_slices: Tuple[int, ...],
):
    """Applies lora to each input.

    This method applies all loras to each input. It uses the
    indices vector to determine which lora yields the
    correct output. An index of -1 means no lora should be
    applied. This method adds the final lora results to the
    output.

    This method is used for layers that are composed of multiple sublayers
    (slices) packed together.

    Input shapes:
        x:              (batch_size, hidden_dim)
        lora_a_stacked: 3 element tuple of (num_loras, lora_rank, hidden_dim)
        lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank)
        indices:        (batch_size)
        output:         (batch_size, q_slice_size + 2*kv_slice_size)
        output_slices:  n element tuple of (slice_size...), where n is the
                        number of slices
    """
    org_output = output
    x = x.view(-1, x.shape[-1])
    output = output.view(-1, output.shape[-1])
    indices = indices.view(-1)
    offset_left = 0
    for slice_idx in range(len(output_slices)):
        add_lora_slice(output, x, lora_a_stacked[slice_idx],
                       lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left,
                       output_slices[slice_idx])
        offset_left += output_slices[slice_idx]
    return output.view_as(org_output)
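

# For the packed layers below, `output_slices` describes how the packed
# output is partitioned: MergedColumnParallelLinearWithLoRA passes
# (output_dim, output_dim) for its two equal gate/up slices, and
# QKVParallelLinearWithLora passes (q_proj_shard_size, kv_proj_shard_size,
# kv_proj_shard_size), so each slice's LoRA delta lands in its own contiguous
# column range of the packed output.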


@dataclass
class LoRAMapping:
    # Per every token in input_ids:
    index_mapping: Tuple[int, ...]
    # Per sampled token:
    prompt_mapping: Tuple[int, ...]

    def __post_init__(self):
        self.index_mapping = tuple(self.index_mapping)
        self.prompt_mapping = tuple(self.prompt_mapping)
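

# Illustrative example (hypothetical ids): index_mapping holds one entry per
# input token and prompt_mapping one entry per sampled token, e.g.
#     LoRAMapping(index_mapping=(1, 1, 1, 0, 0), prompt_mapping=(1, 0))
# The conversion of these per-token ids into the index tensors consumed by
# _apply_lora / set_mapping happens outside this module.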


class BaseLayerWithLoRA(nn.Module):

    def create_lora_weights(self, max_loras: int, lora_config: LoRAConfig,
                            model_config: PretrainedConfig) -> None:
        """Initializes lora matrices."""
        ...

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Overwrites lora tensors at index."""
        ...

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        """Sets the mapping indices."""
        ...
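

# Subclasses of BaseLayerWithLoRA are driven in three phases by the
# surrounding LoRA machinery: create_lora_weights() allocates zeroed stacked
# buffers once, set_lora()/reset_lora() swap individual adapters in and out
# of a slot by index, and set_mapping() points the layer at the shared
# per-batch index tensors used to route each token to its adapter.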


class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        lora_vocab_start_idx = self.base_layer.org_vocab_size
        weights_idx = None
        if self.base_layer.vocab_end_index > lora_vocab_start_idx:
            # We can start adding lora weights
            weights_idx = max(
                lora_vocab_start_idx - self.base_layer.vocab_start_index, 0)
            self.embeddings_slice = (self.base_layer.vocab_start_index -
                                     self.base_layer.org_vocab_size +
                                     weights_idx,
                                     self.base_layer.vocab_end_index -
                                     self.base_layer.org_vocab_size)
            self.embeddings_weights = self.base_layer.weight.data[weights_idx:]
            self.embeddings_weights.fill_(0)
        else:
            self.embeddings_slice = None
            self.embeddings_weights = None
        self.embeddings_tensors = torch.zeros(
            (
                max_loras,
                lora_config.lora_extra_vocab_size,
                self.base_layer.embedding_dim,
            ),
            dtype=self.base_layer.weight.dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                self.base_layer.org_vocab_size +
                lora_config.lora_extra_vocab_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.embedding_dim,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked_2d = self.lora_a_stacked.view(
            self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
            self.lora_a_stacked.shape[2],
        )
        self.indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None
        self.embeddings_indices = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
            lora_a, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1]].copy_(embeddings_tensor, non_blocking=True)
            if self.embeddings_slice is not None:
                # TODO(yard1): Optimize this copy, we don't need to copy
                # everything, just the modified part
                embeddings = self.embeddings_tensors.view(
                    self.embeddings_tensors.shape[0] *
                    self.embeddings_tensors.shape[1],
                    self.embeddings_tensors.shape[2]
                )[self.embeddings_slice[0]:self.embeddings_slice[1]]
                self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.embeddings_indices = embeddings_indices
        self.indices_len = indices_len

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        added_tokens_mask = x > self.base_layer.org_vocab_size - 1
        embedding_len = self.indices_len[3]
        indices = self.embeddings_indices[1][:embedding_len].view_as(x)
        full_lora_a_embeddings = F.embedding(
            x + indices,
            self.lora_a_stacked_2d,
        )
        indices = self.embeddings_indices[0][:embedding_len].view_as(x)
        full_output = self.base_layer.forward(
            x.add_(indices * added_tokens_mask))

        full_output_org = full_output
        if full_output.ndim == 3:
            full_output = full_output.view(
                full_output.shape[0] * full_output.shape[1], -1)
        if full_lora_a_embeddings.ndim == 3:
            full_lora_a_embeddings = full_lora_a_embeddings.view(
                full_lora_a_embeddings.shape[0] *
                full_lora_a_embeddings.shape[1], -1)
        bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        return full_output.view_as(full_output_org)
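

# Note on the embedding lookup above: lora_a_stacked is flattened into a 2-D
# table of shape (max_loras * (org_vocab_size + lora_extra_vocab_size),
# max_lora_rank), so adding the per-token offsets from embeddings_indices to
# the token ids selects the row belonging to that token's adapter in a single
# F.embedding call; bgmv then multiplies the result by the matching
# lora_b_stacked slot and adds it to the base embedding output.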


class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.lora_a_stacked = torch.zeros(
            max_loras,
            1,
            lora_config.max_lora_rank,
            self.base_layer.weight.shape[1],
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            max_loras,
            1,
            self.base_layer.weight.shape[0],
            lora_config.max_lora_rank,
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None
        self.output_dim = self.lora_b_stacked.shape[1]

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.indices_len = indices_len

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer.linear_weights, x, bias)
        _apply_lora(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
        )
        return output

    def forward(self, input_):
        """Forward of ColumnParallelLinear

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)

        # Matrix multiply.
        output_parallel = self.apply_weights(input_, bias)
        if self.base_layer.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @property
    def linear_weights(self):
        return self.base_layer.linear_weights
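

# Note on set_lora above: lora_a is expected as (input_size, rank) and lora_b
# as (rank, output_size); both are transposed into the zero-initialized
# stacked buffers, so adapters with a rank smaller than max_lora_rank simply
# leave the remaining rows/columns at zero and contribute nothing extra.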


class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
    packed together (e.g. gate_proj + up_proj -> gate_up_proj).

    This means we have 2 LoRAs, each applied to one half of the layer.

    Both slices must have the same size.
    """

    def __init__(self, base_layer: MergedColumnParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        n_slices = 2
        if not (len(self.base_layer.output_sizes) == n_slices
                and self.base_layer.output_sizes[0]
                == self.base_layer.output_sizes[1]):
            raise ValueError(
                "LoRAColumnParallelLinear2Slice requires 2 slices with "
                "the same size.")
        self.tp_size = get_tensor_model_parallel_world_size()
        device = _get_lora_device(self.base_layer)

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.input_size,
                dtype=lora_config.lora_dtype,
                device=device,
            ) for _ in range(n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                self.base_layer.output_size // 2,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=device,
            ) for _ in range(n_slices))

        self.indices: Optional[torch.Tensor] = None
        self.output_dim = self.lora_b_stacked[0].shape[2]

    def reset_lora(self, index: int):
        self.lora_a_stacked[0][index] = 0
        self.lora_a_stacked[1][index] = 0
        self.lora_b_stacked[0][index] = 0
        self.lora_b_stacked[1][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)

        if self.tp_size > 1:
            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
            shard_size = self.output_dim
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            lora_b = lora_b[0][:,
                               start_idx:end_idx], lora_b[1][:,
                                                             start_idx:end_idx]

        if lora_a[0] is not None:
            self.lora_a_stacked[0][
                index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
                    lora_a[0].T, non_blocking=True)
            self.lora_b_stacked[0][
                index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
                    lora_b[0].T, non_blocking=True)
        if lora_a[1] is not None:
            self.lora_a_stacked[1][
                index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
                    lora_a[1].T, non_blocking=True)
            self.lora_b_stacked[1][
                index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
                    lora_b[1].T, non_blocking=True)

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer.linear_weights, x, bias)
        _apply_lora_packed_nslice(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
            (self.output_dim, self.output_dim),
        )
        return output
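

# Note on tensor parallelism above: only lora_b is sharded. Each rank keeps
# the full lora_a (the column-parallel base layer sees the full input on
# every rank) and slices out its own `output_dim`-wide column block of each
# lora_b slice, so the LoRA delta is produced directly in the rank's shard of
# the packed gate_up_proj output.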


class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 3 sublayers (slices)
    packed together in qkv proj fashion
    (q_proj + k_proj + v_proj -> qkv_proj).

    This means we have 3 LoRAs, each applied to one slice of the layer.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.q_shard_id = tp_rank
        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
        device = _get_lora_device(self.base_layer)

        # q, k, v
        self.lora_a_stacked = (
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.input_size,
                dtype=lora_config.lora_dtype,
                device=device,
            ),
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.input_size,
                dtype=lora_config.lora_dtype,
                device=device,
            ),
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.input_size,
                dtype=lora_config.lora_dtype,
                device=device,
            ),
        )
        self.lora_b_stacked = (
            torch.zeros(
                max_loras,
                1,
                self.q_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=device,
            ),
            torch.zeros(
                max_loras,
                1,
                self.kv_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=device,
            ),
            torch.zeros(
                max_loras,
                1,
                self.kv_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=device,
            ),
        )

        self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size,
                              self.kv_proj_shard_size)
        self.packed_indices: Optional[torch.Tensor] = None
        self.standard_indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[0][index] = 0
        self.lora_b_stacked[0][index] = 0
        self.lora_a_stacked[1][index] = 0
        self.lora_b_stacked[1][index] = 0
        self.lora_a_stacked[2][index] = 0
        self.lora_b_stacked[2][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)

        if self.tp_size > 1:
            if lora_b[0] is not None:
                lora_b_q = lora_b[0][:, self.q_proj_shard_size *
                                     self.q_shard_id:self.q_proj_shard_size *
                                     (self.q_shard_id + 1)]
                self.lora_b_stacked[0][
                    index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
                        lora_b_q.T, non_blocking=True)
            if lora_b[1] is not None:
                lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
                                     self.kv_shard_id:self.kv_proj_shard_size *
                                     (self.kv_shard_id + 1)]
                self.lora_b_stacked[1][
                    index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
                        lora_b_k.T, non_blocking=True)
            if lora_b[2] is not None:
                lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
                                     self.kv_shard_id:self.kv_proj_shard_size *
                                     (self.kv_shard_id + 1)]
                self.lora_b_stacked[2][
                    index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
                        lora_b_v.T, non_blocking=True)
        else:
            if lora_b[0] is not None:
                self.lora_b_stacked[0][
                    index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
                        lora_b[0].T, non_blocking=True)
            if lora_b[1] is not None:
                self.lora_b_stacked[1][
                    index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
                        lora_b[1].T, non_blocking=True)
            if lora_b[2] is not None:
                self.lora_b_stacked[2][
                    index, 0, :lora_b[2].shape[1], :lora_b[2].shape[0]].copy_(
                        lora_b[2].T, non_blocking=True)

        if lora_a[0] is not None:
            self.lora_a_stacked[0][
                index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
                    lora_a[0].T, non_blocking=True)
        if lora_a[1] is not None:
            self.lora_a_stacked[1][
                index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
                    lora_a[1].T, non_blocking=True)
        if lora_a[2] is not None:
            self.lora_a_stacked[2][
                index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
                    lora_a[2].T, non_blocking=True)

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer.linear_weights, x, bias)
        _apply_lora_packed_nslice(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
            self.output_slices,
        )
        return output
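

# Note on the QKV sharding above: output_slices mirrors how the base
# QKVParallelLinear packs its per-rank output, (q_proj_shard_size,
# kv_proj_shard_size, kv_proj_shard_size). q_shard_id is simply the TP rank,
# while kv_shard_id divides by num_kv_head_replicas so that ranks sharing the
# same replicated KV heads (grouped-query attention) select the same K/V
# columns of lora_b.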


class RowParallelLinearWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: RowParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        device = _get_lora_device(self.base_layer)
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.input_size,
            ),
            dtype=lora_config.lora_dtype,
            device=device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.output_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=device,
        )
        self.indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.base_layer.tp_size > 1:
            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
            shard_size = self.base_layer.weight.shape[1]
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            lora_a = lora_a[start_idx:end_idx, :]

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.indices_len = indices_len

    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer.linear_weights, x)
        _apply_lora(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
        )
        return output

    def forward(self, input_):
        """Forward of RowParallelLinear

        Args:
            input_: tensor whose last dimension is `input_size`. If
                    `input_is_parallel` is set, then the last dimension
                    is `input_size // tp_size`.

        Returns:
            - output
            - bias
        """
        # Set up backprop all-reduce.
        if self.base_layer.input_is_parallel:
            input_parallel = input_
        else:
            # TODO: simplify code below
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.base_layer.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        output_parallel = self.apply_weights(input_parallel)
        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.base_layer.skip_bias_add:
            output = (output_ + self.base_layer.bias
                      if self.base_layer.bias is not None else output_)
            output_bias = None
        else:
            output = output_
            output_bias = self.base_layer.bias
        return output, output_bias

    @property
    def weight(self):
        return self.base_layer.weight
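

# Note on tensor parallelism above: RowParallelLinear shards the *input*
# dimension, so set_lora slices the matching rows of lora_a for this rank and
# keeps lora_b whole. Each rank's partial LoRA delta is added to its partial
# base output inside apply_weights and is then folded into the same
# tensor_model_parallel_all_reduce as the base result.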


class LogitsProcessorWithLoRA(BaseLayerWithLoRA):

    def __init__(
        self,
        base_layer: LogitsProcessor,
        hidden_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device

    @property
    def logits_as_input(self):
        return self.base_layer.logits_as_input

    @property
    def vocab_size(self):
        return self.base_layer.vocab_size

    @property
    def scale(self):
        return self.base_layer.scale

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        return self.base_layer.include_gpu_probs_tensor

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
        if not (32000 <= self.base_layer.vocab_size <= 33024):
            raise ValueError(
                "When using LoRA, vocab size must be "
                "32000 <= vocab_size <= 33024")
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                # Pad for kernel compatibility
                math.ceil(self.base_layer.vocab_size /
                          lora_config.lora_vocab_padding_size) *
                lora_config.lora_vocab_padding_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.embeddings_tensors = torch.full(
            (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
            fill_value=float("-inf"),
            dtype=self.dtype,
            device=self.device,
        )
        self.indices = None
        self.indices_padded = None
        self.indices_len = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = float("-inf")

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1], ] = embeddings_tensor

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = sampler_indices
        self.indices_padded = sampler_indices_padded
        self.indices_len = indices_len

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        embedding: torch.Tensor,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Get the logits for the next tokens.
        logits = torch.matmul(hidden_states, embedding.t())
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        if logits is None:
            return None

        lora_logits = torch.empty(
            self.embeddings_tensors.shape[0] + 1,
            self.embeddings_tensors.shape[1],
            hidden_states.shape[0],
            dtype=self.embeddings_tensors.dtype,
            device=self.embeddings_tensors.device,
        )
        torch.matmul(self.embeddings_tensors,
                     hidden_states.T,
                     out=lora_logits[:-1])
        lora_logits[-1] = float("-inf")
        lora_logits = lora_logits.mT
        lora_logits = (lora_logits.reshape(
            lora_logits.shape[0] * lora_logits.shape[1],
            lora_logits.shape[2],
        ).index_select(0,
                       self.indices_padded[:self.indices_len[2]]).nan_to_num_(
                           nan=float("-inf"),
                           posinf=float("inf"),
                           neginf=float("-inf")))
        logits[:,
               self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
               lora_logits.shape[1]] = lora_logits

        _apply_lora(
            hidden_states,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[1]],
            logits,
        )

        # Remove paddings in vocab (if any).
        logits = logits[:, :self.base_layer.vocab_size]
        return logits

    def forward(self, *args, **kwargs):
        return type(self.base_layer).forward(self, *args, **kwargs)
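

# Note on _get_logits above: logits for the LoRA-added vocabulary come from
# embeddings_tensors (padded with -inf so unused slots can never be sampled)
# and are written into the columns starting at org_vocab_size; the regular
# LoRA delta on the LM head is then applied using the sampler indices, and
# any kernel-compatibility padding beyond vocab_size is stripped before
# returning.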


def from_layer(
        layer: nn.Module,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None) -> BaseLayerWithLoRA:
    supported_layer_types = {
        VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
        ColumnParallelLinear: ColumnParallelLinearWithLoRA,
        QKVParallelLinear: QKVParallelLinearWithLora,
        MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
        RowParallelLinear: RowParallelLinearWithLoRA,
    }
    for src_layer_type, lora_layer_type in supported_layer_types.items():
        if type(layer) is src_layer_type:  # pylint: disable=unidiomatic-typecheck
            ret = lora_layer_type(layer)
            ret.create_lora_weights(max_loras, lora_config, model_config)
            return ret
    return layer
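

# Illustrative sketch (not part of this module's public API): one way a
# caller could use from_layer to wrap every supported submodule of a model in
# place. The actual wiring lives in the LoRA model manager, which also keeps
# track of the wrapped layers so set_lora/set_mapping can reach them. The
# helper name below is hypothetical.
def _wrap_supported_layers(model: nn.Module, max_loras: int,
                           lora_config: LoRAConfig) -> None:
    for name, module in model.named_children():
        wrapped = from_layer(module, max_loras, lora_config)
        if wrapped is not module:
            # A supported layer type was replaced by its LoRA wrapper.
            setattr(model, name, wrapped)
        else:
            # Unsupported or container module: recurse into its children.
            _wrap_supported_layers(module, max_loras, lora_config)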


def from_layer_logits_processor(
    layer: LogitsProcessor,
    lm_head: ParallelLMHead,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA:
    ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
                                  lm_head.weight.dtype, lm_head.weight.device)
    ret.create_lora_weights(max_loras, lora_config, model_config)
    return ret
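

# Illustrative usage (hypothetical attribute names): the logits processor is
# wrapped separately from from_layer because it also needs the LM head, whose
# weight supplies the dtype/device for the LoRA buffers, e.g.
#     lora_logits_processor = from_layer_logits_processor(
#         model.logits_processor, model.lm_head, max_loras, lora_config)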