  1. """
  2. Based on:
  3. Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
  4. Punica: Multi-Tenant LoRA Serving.
  5. https://arxiv.org/abs/2310.18547
  6. """
  7. from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union
  8. import torch
  9. from aphrodite.triton_utils import HAS_TRITON
  10. if HAS_TRITON:
  11. from aphrodite.lora.ops.bgmv_embed import bgmv_embed
  12. from aphrodite.lora.ops.bgmv_expand import bgmv_expand
  13. from aphrodite.lora.ops.bgmv_expand_slice import bgmv_expand_slice
  14. from aphrodite.lora.ops.bgmv_sample import bgmv_sample
  15. from aphrodite.lora.ops.bgmv_shrink import bgmv_shrink
  16. from aphrodite.lora.ops.sgmv_expand import sgmv_expand
  17. from aphrodite.lora.ops.sgmv_expand_slice import sgmv_expand_slice
  18. from aphrodite.lora.ops.sgmv_shrink import sgmv_shrink
  19. if TYPE_CHECKING:
  20. # avoid circuit import
  21. from aphrodite.lora.layers import LoRAMapping
  22. from aphrodite.lora.models import LongContextLoRAContext


def compute_meta(
    token_lora_tensor: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, bool]:
    """
    Get the information required for the sgmv kernel. Notable behavior:
    1. If consecutive requests in the batch use the same LoRA, this function
       combines them into a single segment, improving sgmv kernel inference
       performance.
    2. At the beginning of each prefill stage, this metadata is recomputed
       from the input, but only once per prefill.
    (A worked example follows this function.)
    """
    lora_indices_tensor, seq_length_tensor = torch.unique_consecutive(
        token_lora_tensor, return_counts=True)
    cum_result = torch.cumsum(seq_length_tensor, dim=0)
    b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
    b_seq_start_tensor[1:].copy_(cum_result[:-1])
    max_length = seq_length_tensor.max().item()
    batch_size = lora_indices_tensor.size(0)
    no_lora = False
    # -1 means no LoRA should be applied. Use `no_lora` to determine whether
    # the current step requires LoRA. If LoRA is not needed, the prefill stage
    # does not need to launch the triton kernel, which can improve performance.
    if batch_size == 1 and lora_indices_tensor == -1:
        no_lora = True
    return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
            batch_size, max_length, no_lora)
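

# Illustrative example (not part of the original module): what compute_meta
# produces for a toy token-to-LoRA mapping. The values below follow directly
# from the logic above and were worked out by hand.
#
#   token_lora = torch.tensor([0, 0, 0, 1, 1, -1, -1, -1, -1])
#   (b_seq_start, seq_lens, lora_ids,
#    batch_size, max_len, no_lora) = compute_meta(token_lora)
#   # b_seq_start -> tensor([0, 3, 5])   start offset of each merged segment
#   # seq_lens    -> tensor([3, 2, 4])   length of each merged segment
#   # lora_ids    -> tensor([0, 1, -1])  LoRA index of each segment (-1: none)
#   # batch_size  -> 3, max_len -> 4, no_lora -> False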


# TODO see if this can be vectorized
def convert_mapping(
    mapping: "LoRAMapping",
    lora_index_to_id: List[Optional[int]],
    max_loras: int,
    vocab_size: int,
    extra_vocab_size: int,
    long_lora_context: Optional["LongContextLoRAContext"] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
           Optional[torch.Tensor], List[int]]:
    """Converts LoRAMapping to index tensors.

    Args:
        mapping: LoRAMapping mapping rows in a batch to LoRA ids.
        lora_index_to_id: List mapping LoRA ids to LoRA indices.
        max_loras: Maximum number of LoRAs.
        vocab_size: Model vocab size.
        extra_vocab_size: Extra vocab size each LoRA can have.
        long_lora_context: Passed if there are long context LoRAs in the batch.

    Returns:
        A tuple of tensors:
            base_indices: Tensor of shape [batch_size] mapping batch rows to
                LoRA indices.
            sampler_indices: Tensor of shape [batch_size] mapping requests to
                LoRA indices for the sampler. For generation, this will be
                the same as base_indices. For prefill, this will map requests
                to LoRA indices.
            sampler_indices_padded: Tensor of shape [batch_size] mapping
                requests to LoRA indices for the sampler with padding.
                Same as sampler_indices, but -1 is replaced with max_loras.
            embeddings_indices: Tensor of shape [2, batch_size] mapping
                requests to embedding indices. First row is for embeddings
                added by the LoRAs, second row is for the LoRA.lora_a
                embeddings.
            long_lora_indices: Tensor of shape [batch_size] mapping
                requests to RoPE offsets and rot dims for long LoRAs.
                None if long context LoRAs don't exist.
            indices_len: List of lengths of the above tensors. It contains
                (base_indices, sampler_indices, sampler_indices_padded,
                embeddings_indices, long_lora_indices).
    (A small worked example of the padded sampler indices follows this
    function.)
    """
    index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
    embedding_indices = index_mapping_indices.copy()
    lora_indices = index_mapping_indices.copy()
    long_lora_offsets: Optional[torch.Tensor] = None
    if long_lora_context:
        long_lora_offsets = torch.zeros(len(index_mapping_indices),
                                        device="cuda",
                                        dtype=torch.long)
    prompt_mapping: List[int] = [
        lora_index_to_id.index(x) if x > 0 else -1
        for x in mapping.prompt_mapping
    ]
    lora_idx = None
    for i in range(len(index_mapping_indices)):
        # TODO index can be slow. optimize
        lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
                    if index_mapping_indices[i] > 0 else -1)
        embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
        lora_indices[i] = lora_idx
        if long_lora_context:
            assert long_lora_offsets is not None
            lora_offset: int = long_lora_context.offsets_by_lora_id.get(
                index_mapping_indices[i], 0)
            long_lora_offsets[i] = lora_offset
    indices_list: List[Union[List[int], torch.Tensor]] = [
        index_mapping_indices,
        lora_indices,
        embedding_indices,
    ]
    if long_lora_context:
        assert long_lora_offsets is not None
        indices_list.append(long_lora_offsets)
    indices = torch.tensor(indices_list, dtype=torch.long, device="cuda")
    prompt_mapping_tensor = torch.tensor(prompt_mapping,
                                         device="cuda",
                                         dtype=torch.long)
    embeddings_indices = torch.stack([
        indices[2] * extra_vocab_size,
        indices[2] * (vocab_size + extra_vocab_size),
    ])
    embeddings_indices[embeddings_indices == -1] = max_loras - 1
    base_indices = indices[1]
    sampler_indices = prompt_mapping_tensor
    sampler_indices_padded = sampler_indices.clone()
    sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
    sampler_indices_padded = torch.arange(
        0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + (
            sampler_indices_padded * len(sampler_indices_padded))
    long_lora_indices = None
    long_lora_indices_len: Optional[int] = None
    if long_lora_context:
        long_lora_indices = indices[3]
        long_lora_indices_len = long_lora_indices.shape[-1]
    # Contains the length of each indices tensor. Used to index into each
    # tensor.
    indices_len = [
        base_indices.shape[-1],
        sampler_indices.shape[-1],
        sampler_indices_padded.shape[-1],
        embeddings_indices.shape[-1],
    ]
    if long_lora_indices_len is not None:
        indices_len.append(long_lora_indices_len)
    else:
        # If long_lora doesn't exist, append None.
        indices_len.append(None)
    return (
        base_indices,
        sampler_indices,
        sampler_indices_padded,
        embeddings_indices,
        long_lora_indices,
        indices_len,
    )
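

# A small worked example (hypothetical values, not part of the original
# module) of how convert_mapping flattens the padded sampler indices above.
# With three prompts whose LoRA slots resolve to [0, -1, 2] and max_loras = 4:
#
#   sampler_indices        -> tensor([0, -1, 2])
#   # -1 is first replaced with max_loras - 1 = 3, giving [0, 3, 2]; each
#   # entry is then offset by its row: arange(3) + [0, 3, 2] * 3
#   sampler_indices_padded -> tensor([0, 10, 8])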


class PunicaWrapper:
    """
    PunicaWrapper is designed to manage and provide metadata for the punica
    kernels. Its main job is to maintain the state information for
    Multi-LoRA and to provide the interface for the punica kernels.
    (An illustrative usage sketch follows the `update_metadata` method.)
    """

    def __init__(self, max_num_batched_tokens: int, max_batches: int,
                 device: str):
        self._token_lora_indices = torch.empty(max_num_batched_tokens,
                                               dtype=torch.long,
                                               device=device)
        self._sampler_indices = torch.empty(max_num_batched_tokens,
                                            dtype=torch.long,
                                            device=device)
        self._sampler_indices_padded = torch.empty(max_num_batched_tokens,
                                                   dtype=torch.long,
                                                   device=device)
        self._embeddings_indices = torch.empty(2,
                                               max_num_batched_tokens,
                                               dtype=torch.long,
                                               device=device)
        self._long_lora_indices = torch.empty(max_num_batched_tokens,
                                              dtype=torch.long,
                                              device=device)
        # 5 is the number of indices tensors:
        # base_indices, sampler_indices, sampler_indices_padded,
        # embeddings_indices, long_lora_indices
        self.indices_len: List[Optional[int]] = [None] * 5
        # These attributes hold the information required for the sgmv kernel.
        self._seq_start_locs = torch.empty(max_batches,
                                           dtype=torch.long,
                                           device=device)
        self._seq_lengths = torch.empty(max_batches,
                                        dtype=torch.long,
                                        device=device)
        self._lora_indices_per_batch = torch.empty(max_batches,
                                                   dtype=torch.long,
                                                   device=device)
        self.max_length: int = 0
        self.batch_size: int = -1
        self.is_prefill = False
        self.no_lora = False

    def update_metadata(
        self,
        mapping: "LoRAMapping",
        lora_index_to_id: List[Optional[int]],
        max_loras: int,
        vocab_size: int,
        extra_vocab_size: int,
        long_lora_context: Optional["LongContextLoRAContext"] = None,
    ):
        self._update_base_metadata(mapping, lora_index_to_id, max_loras,
                                   vocab_size, extra_vocab_size,
                                   long_lora_context)
        if mapping.is_prefill:
            # Update metadata required for prefill-related operators.
            self._update_prefill_metadata(self.token_lora_indices)
            self.is_prefill = True
        else:
            self.is_prefill = False
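
    # Illustrative usage sketch (assumed call order; the surrounding engine
    # code is not shown here and the variable names below are placeholders).
    # A forward pass first refreshes the metadata from the scheduler's
    # LoRAMapping, then the LoRA layers call the add_* helpers:
    #
    #   punica_wrapper = PunicaWrapper(max_num_batched_tokens=8192,
    #                                  max_batches=256, device="cuda")
    #   punica_wrapper.update_metadata(lora_mapping, lora_index_to_id,
    #                                  max_loras, vocab_size, extra_vocab_size)
    #   # inside a LoRA layer's forward:
    #   punica_wrapper.add_lora(output, hidden_states, lora_a_stacked,
    #                           lora_b_stacked, scale=1.0)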

    def _update_base_metadata(
        self,
        mapping: "LoRAMapping",
        lora_index_to_id: List[Optional[int]],
        max_loras: int,
        vocab_size: int,
        extra_vocab_size: int,
        long_lora_context: Optional["LongContextLoRAContext"] = None,
    ):
        (
            base_indices,
            sampler_indices,
            sampler_indices_padded,
            embeddings_indices,
            long_lora_offsets_tensor,
            indices_len,
        ) = convert_mapping(
            mapping,
            lora_index_to_id,
            max_loras,
            vocab_size,
            extra_vocab_size,
            long_lora_context,
        )
        self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
        self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
        self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
            sampler_indices_padded)
        self._embeddings_indices[:embeddings_indices.
                                 shape[0], :embeddings_indices.shape[1]].copy_(
                                     embeddings_indices)
        if long_lora_offsets_tensor is not None:
            self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
                long_lora_offsets_tensor)
        else:
            self._long_lora_indices.zero_()
        self.indices_len[:] = indices_len

    def _update_prefill_metadata(self,
                                 token_lora_tensor: torch.Tensor) -> None:
        (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
         batch_size, max_length, no_lora) = compute_meta(token_lora_tensor)
        self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_(
            b_seq_start_tensor)
        self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor)
        self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_(
            lora_indices_tensor)
        self.batch_size = batch_size
        self.max_length = max_length
        self.no_lora = no_lora

    @property
    def prefill_metadata(
        self
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
        """
        This property provides a convenient way to access the necessary
        metadata for prefill-related kernel computations:
        1. seq_start_locs: Tensor of sequence start positions.
        2. seq_lengths: Tensor of sequence lengths.
        3. lora_indices_per_batch: Tensor of LoRA indices; an index of
           -1 means no LoRA should be applied.
        4. batch_size: Batch size after clustering identical LoRA indices.
        5. max_length: The maximum sequence length in the batch.
        """
        return (self._seq_start_locs[:self.batch_size],
                self._seq_lengths[:self.batch_size],
                self._lora_indices_per_batch[:self.batch_size],
                self.batch_size, self.max_length)

    @property
    def token_lora_indices(self) -> torch.Tensor:
        """
        This property provides the LoRA indices corresponding to each token
        in the batch. An index of -1 means no LoRA should be applied.
        """
        token_lora_len = self.indices_len[0]
        return self._token_lora_indices[:token_lora_len]

    @property
    def sampler_indices(self) -> torch.Tensor:
        """
        This property is used to access the LoRA indices specifically for
        LogitsProcessorWithLoRA.
        """
        sampler_indices_len = self.indices_len[1]
        return self._sampler_indices[:sampler_indices_len]

    @property
    def sampler_indices_padded(self) -> torch.Tensor:
        """
        This property provides access to the padded sampler indices.
        """
        indices_padded_len = self.indices_len[2]
        return self._sampler_indices_padded[:indices_padded_len]

    @property
    def embeddings_indices(self) -> torch.Tensor:
        """
        This property provides access to the indices used for LoRA embeddings,
        specifically for VocabParallelEmbeddingWithLoRA.
        """
        embeddings_indices_len = self.indices_len[3]
        return self._embeddings_indices[:, :embeddings_indices_len]

    @property
    def long_lora_indices(self) -> torch.Tensor:
        """
        This property provides access to the indices used for long context
        LoRA, specifically for LinearScalingRotaryEmbeddingWithLora.
        """
        long_lora_len = self.indices_len[4]
        return self._long_lora_indices[:long_lora_len]

    def shrink_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        # No LoRA request, so return directly.
        if self.no_lora:
            return
        sgmv_shrink(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            scale,
        )

    def shrink_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)

    def expand_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_input: bool,
    ):
        # No LoRA request, so return directly.
        if self.no_lora:
            return
        sgmv_expand(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            add_input,
        )

    def expand_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_input: bool,
    ):
        bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input)

    def expand_slice_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: Optional[int],
        y_slice_size: Optional[int],
        add_input: bool,
    ):
        # No LoRA request, so return directly.
        if self.no_lora:
            return
        sgmv_expand_slice(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            y_offset,
            y_slice_size,
            add_input,
        )

    def expand_slice_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: Optional[int],
        y_slice_size: Optional[int],
        add_input: bool,
    ):
        bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
                          y_slice_size, add_input)

    def add_shrink(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        """
        Perform the `y += x @ w_t_all` computation, which is suitable for
        the GEMM of lora_a.
        When `is_prefill` is true, it indicates that it is currently the
        prefill stage, and the `shrink_prefill` function should be called.
        Otherwise, it is the decode stage, and the `shrink_decode` function
        should be called.
        """
        shrink_fun: Callable = (self.shrink_prefill
                                if self.is_prefill else self.shrink_decode)
        shrink_fun(y, x, w_t_all, scale)

    def add_expand(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_input: bool = True,
    ):
        """
        Perform the `y += x @ w_t_all` computation, which is suitable for
        the GEMM of lora_b.
        When `is_prefill` is true, it indicates that it is currently the
        prefill stage, and the `expand_prefill` function should be called.
        Otherwise, it is the decode stage, and the `expand_decode` function
        should be called.
        """
        expand_fun: Callable = (self.expand_prefill
                                if self.is_prefill else self.expand_decode)
        expand_fun(y, x, w_t_all, add_input)

    def add_expand_slice(self,
                         y: torch.Tensor,
                         x: torch.Tensor,
                         w_t_all: torch.Tensor,
                         y_offset: Optional[int],
                         y_slice_size: Optional[int],
                         add_input: bool = True):
        """
        Similar to `add_expand`, but only writes to the
        [y_offset, y_offset + y_slice_size) column slice of `y`.
        """
        expand_slice_fun: Callable = (self.expand_slice_prefill
                                      if self.is_prefill else
                                      self.expand_slice_decode)
        expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input)

    def add_lora(self,
                 y: torch.Tensor,
                 x: torch.Tensor,
                 wa_t_all: torch.Tensor,
                 wb_t_all: torch.Tensor,
                 scale: float,
                 y_offset: Optional[int] = None,
                 y_slice_size: Optional[int] = None,
                 *,
                 buffer: Optional[torch.Tensor] = None) -> None:
        """
        Semantics:
            y[i] += (
                x[i].unsqueeze(0)
                @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
                @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
                * scale
            ).squeeze(0)
        (A plain-PyTorch sketch of this loop follows the method body.)

        Args:
            y (torch.Tensor): Output tensor. Will be changed in-place.
            x (torch.Tensor): Input tensor.
            wa_t_all (torch.Tensor): lora_a's weight.
            wb_t_all (torch.Tensor): lora_b's weight.
            scale (float): Scaling factor.
            y_offset (Optional[int], optional): Offset to apply to the
                starting column of y.
            y_slice_size (Optional[int], optional): Size of the y column
                slice.
            buffer (Optional[torch.Tensor], optional): Defaults to None.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        x = x.view(-1, x.shape[-1])
        r = wb_t_all.size(-1)
        if buffer is None:
            # We set the buffer to be float32 by default, refer to:
            # https://github.com/triton-lang/triton/issues/1387
            buffer = torch.zeros((x.size(0), r),
                                 dtype=torch.float32,
                                 device=x.device)
        self.add_shrink(buffer, x, wa_t_all, scale)
        if y_offset is None and y_slice_size is None:
            self.add_expand(y, buffer, wb_t_all, add_input=True)
        else:
            self.add_expand_slice(y,
                                  buffer,
                                  wb_t_all,
                                  y_offset,
                                  y_slice_size,
                                  add_input=True)
        y = y.view_as(y_org)
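
    # Plain-PyTorch sketch of add_lora's per-token semantics (illustrative
    # only; the Triton kernels above compute the same thing batched).
    # `indices` stands for self.token_lora_indices, and layer_idx is assumed
    # to be 0 for the stacked weights used here:
    #
    #   for i in range(x.size(0)):
    #       if indices[i] == -1:
    #           continue  # -1 means no LoRA is applied to this token
    #       a = wa_t_all[indices[i], 0].transpose(-1, -2)  # [hidden, r]
    #       b = wb_t_all[indices[i], 0].transpose(-1, -2)  # [r, out]
    #       y[i] += (x[i].unsqueeze(0) @ a @ b).squeeze(0) * scale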

    def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor,
                               lora_a_stacked: Tuple[torch.Tensor,
                                                     torch.Tensor,
                                                     torch.Tensor],
                               lora_b_stacked: Tuple[torch.Tensor,
                                                     torch.Tensor,
                                                     torch.Tensor],
                               scale: float,
                               output_slices: Tuple[int, ...]) -> None:
        """
        Applies LoRA to each input. Similar to `add_lora`, but this method is
        used for layers that are composed of multiple sublayers (slices)
        packed together.
        """
        y_org = y
        x = x.view(-1, x.shape[-1])
        y = y.view(-1, y.shape[-1])
        offset_left = 0
        # TODO fuse these kernels
        for slice_idx in range(len(output_slices)):
            self.add_lora(y, x, lora_a_stacked[slice_idx],
                          lora_b_stacked[slice_idx], scale, offset_left,
                          output_slices[slice_idx])
            offset_left += output_slices[slice_idx]
        y = y.view_as(y_org)
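
    # Illustrative example (hypothetical sizes): for a QKV projection packed
    # into one output tensor with output_slices = (4096, 1024, 1024), the loop
    # above applies each slice's LoRA to its own column range of y:
    #
    #   slice 0 -> y[:, 0:4096]
    #   slice 1 -> y[:, 4096:5120]
    #   slice 2 -> y[:, 5120:6144]
    # `offset_left` advances by the width of each slice in turn.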

    def add_lora_logits(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        wa_t_all: torch.Tensor,
                        wb_t_all: torch.Tensor,
                        scale,
                        *,
                        buffer: Optional[torch.Tensor] = None) -> None:
        """
        LogitsProcessorWithLoRA always uses bgmv.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        x = x.view(-1, x.shape[-1])
        r = wb_t_all.size(-1)
        if buffer is None:
            # We set the buffer to be float32 by default, refer to:
            # https://github.com/triton-lang/triton/issues/1387
            buffer = torch.zeros((x.size(0), r),
                                 dtype=torch.float32,
                                 device=x.device)
        bgmv_shrink(x, wa_t_all, buffer, self.sampler_indices, scale)
        bgmv_expand(buffer, wb_t_all, y, self.sampler_indices, add_inputs=True)
        y = y.view_as(y_org)

    def bgmv_sample(self, hidden_states: torch.Tensor,
                    lm_heads_all: torch.Tensor, lm_head_base: torch.Tensor):
        '''
        hidden_states - [num_tokens, hidden_dim]
        lm_heads_all - [num_loras, vocab_size, hidden_dim]

        Equivalent to:
            vocab_size = lm_heads_all.shape[-2]
            num_tokens = hidden_states.size(0)
            logits = torch.zeros((num_tokens, vocab_size),
                                 dtype=torch.float32,
                                 device=hidden_states.device)
            for i in range(num_tokens):
                if indices[i] == -1:
                    logits[i] = lm_head_base @ hidden_states[i]
                else:
                    logits[i] = lm_heads_all[indices[i]] @ hidden_states[i]
        '''
        indices = self.sampler_indices
        logits = bgmv_sample(hidden_states, lm_heads_all, lm_head_base,
                             indices)
        return logits

    def bgmv_embedding(self, tokens: torch.LongTensor,
                       embed_tokens_all: torch.Tensor,
                       embed_tokens_base: torch.Tensor) -> torch.Tensor:
        '''
        tokens - [num_tokens]
        embed_tokens_all - [num_loras, vocab_size, hidden_dim]
            (modules_to_save embeddings)
        embed_tokens_base - [vocab_size, hidden_dim]
            (base-layer embeddings, applied to tokens with index == -1)

        Returns:
            embeddings: [num_tokens, hidden_dim]
        '''
        embeddings = bgmv_embed(tokens,
                                embed_tokens_all,
                                embed_tokens_base,
                                token_indices=self.token_lora_indices.long())
        return embeddings
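
    # Reference semantics of bgmv_embedding as a plain loop (illustrative
    # only, based on the shapes documented above; mirrors the style of the
    # bgmv_sample docstring):
    #
    #   indices = self.token_lora_indices
    #   embeddings = torch.empty((tokens.size(0), embed_tokens_base.size(-1)),
    #                            dtype=embed_tokens_base.dtype,
    #                            device=tokens.device)
    #   for i in range(tokens.size(0)):
    #       if indices[i] == -1:
    #           embeddings[i] = embed_tokens_base[tokens[i]]
    #       else:
    #           embeddings[i] = embed_tokens_all[indices[i], tokens[i]]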