  1. """
  2. Based on:
  3. Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
  4. Punica: Multi-Tenant LoRA Serving.
  5. https://arxiv.org/abs/2310.18547
  6. """
  7. from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union
  8. import torch
  9. from aphrodite.triton_utils import HAS_TRITON
  10. if HAS_TRITON:
  11. from aphrodite.lora.ops.bgmv_embed import bgmv_embed
  12. from aphrodite.lora.ops.bgmv_expand import bgmv_expand
  13. from aphrodite.lora.ops.bgmv_expand_slice import bgmv_expand_slice
  14. from aphrodite.lora.ops.bgmv_sample import bgmv_sample
  15. from aphrodite.lora.ops.bgmv_shrink import bgmv_shrink
  16. from aphrodite.lora.ops.sgmv_expand import sgmv_expand
  17. from aphrodite.lora.ops.sgmv_expand_slice import sgmv_expand_slice
  18. from aphrodite.lora.ops.sgmv_shrink import sgmv_shrink
  19. if TYPE_CHECKING:
  20. # avoid circuit import
  21. from aphrodite.lora.layers import LoRAMapping
  22. from aphrodite.lora.models import LongContextLoRAContext


def compute_meta(
    token_lora_tensor: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]:
    """
    Get the information required for the sgmv kernel. With the features:
    1. If consecutive requests in the batch use the same LoRA, this function
       will combine them into a single request, improving sgmv kernel
       inference performance.
    2. At the beginning of each prefill stage inference, recalculations are
       needed based on the input, but only once.
    """

    lora_indices_tensor, seq_length_tensor = torch.unique_consecutive(
        token_lora_tensor, return_counts=True)
    cum_result = torch.cumsum(seq_length_tensor, dim=0)
    b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
    b_seq_start_tensor[1:].copy_(cum_result[:-1])
    max_length = seq_length_tensor.max().item()
    token_nums = seq_length_tensor.sum().item()
    batch_size = lora_indices_tensor.size(0)
    no_lora = False
    # -1 means no LoRA should be applied. Use `no_lora` to determine whether
    # the current step requires LoRA. If LoRA is not needed, the prefill stage
    # does not need to launch the triton kernel, which can improve performance.
    if batch_size == 1 and lora_indices_tensor == -1:
        no_lora = True
    return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
            batch_size, max_length, token_nums, no_lora)
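

# Illustrative sketch (not part of the original file): the grouping that
# `compute_meta` performs is simply `torch.unique_consecutive` with counts.
# For a batch whose per-token LoRA ids are [0, 0, 0, 1, 1, -1], the metadata
# would be:
#
#   token_lora = torch.tensor([0, 0, 0, 1, 1, -1], device="cuda")
#   (starts, lengths, lora_ids, bs, max_len, n_tok,
#    no_lora) = compute_meta(token_lora)
#   # starts   -> tensor([0, 3, 5])   start offset of each merged request
#   # lengths  -> tensor([3, 2, 1])   tokens per merged request
#   # lora_ids -> tensor([0, 1, -1])  one entry per merged request
#   # bs == 3, max_len == 3, n_tok == 6, no_lora == False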


# TODO see if this can be vectorized
def convert_mapping(
    mapping: "LoRAMapping",
    lora_index_to_id: List[Optional[int]],
    max_loras: int,
    vocab_size: int,
    extra_vocab_size: int,
    long_lora_context: Optional["LongContextLoRAContext"] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
           Optional[torch.Tensor], List[int]]:
    """Converts LoRAMapping to index tensors.

    Args:
        mapping: LoRAMapping mapping rows in a batch to LoRA ids.
        lora_index_to_id: List mapping LoRA ids to LoRA indices.
        max_loras: Maximum number of LoRAs.
        vocab_size: Model vocab size.
        extra_vocab_size: Extra vocab size each LoRA can have.
        long_lora_context: Passed if there are long context LoRAs in a batch.

    Returns:
        A tuple of tensors:
            base_indices: Tensor of shape [batch_size] mapping batch rows to
                LoRA indices.
            sampler_indices: Tensor of shape [batch_size] mapping requests to
                LoRA indices for sampler. For generation, this will be the
                same as base_indices. For prefill, this will map requests
                to LoRA indices.
            sampler_indices_padded: Tensor of shape [batch_size] mapping
                requests to LoRA indices for sampler with padding.
                Same as sampler_indices, but -1 is replaced with
                max_loras.
            embeddings_indices: Tensor of shape [2, batch_size] mapping
                requests to embedding indices. First row is for embeddings
                added by the LoRAs, second row is for the LoRA.lora_a
                embeddings.
            long_lora_indices: Tensor of shape [batch_size] mapping
                requests to RoPE offsets and rot dims for long LoRAs.
                None if long context lora doesn't exist.
            indices_len: List of lengths of the above tensors. It contains
                (base_indices, sampler_indices, sampler_indices_padded,
                embeddings_indices, long_lora_indices).
    """
    index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
    embedding_indices = index_mapping_indices.copy()
    lora_indices = index_mapping_indices.copy()
    long_lora_offsets: Optional[torch.Tensor] = None
    if long_lora_context:
        long_lora_offsets = torch.zeros(len(index_mapping_indices),
                                        device="cuda",
                                        dtype=torch.long)
    prompt_mapping: List[int] = [
        lora_index_to_id.index(x) if x > 0 else -1
        for x in mapping.prompt_mapping
    ]
    lora_idx = None
    for i in range(len(index_mapping_indices)):
        # TODO index can be slow. optimize
        lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
                    if index_mapping_indices[i] > 0 else -1)
        embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
        lora_indices[i] = lora_idx
        if long_lora_context:
            assert long_lora_offsets is not None
            lora_offset: int = long_lora_context.offsets_by_lora_id.get(
                index_mapping_indices[i], 0)
            long_lora_offsets[i] = lora_offset

    indices_list: List[Union[List[int], torch.Tensor]] = [
        index_mapping_indices,
        lora_indices,
        embedding_indices,
    ]
    if long_lora_context:
        assert long_lora_offsets is not None
        indices_list.append(long_lora_offsets)
    indices = torch.tensor(indices_list, dtype=torch.long, device="cuda")
    prompt_mapping_tensor = torch.tensor(prompt_mapping,
                                         device="cuda",
                                         dtype=torch.long)
    embeddings_indices = torch.stack([
        indices[2] * extra_vocab_size,
        indices[2] * (vocab_size + extra_vocab_size),
    ])
    embeddings_indices[embeddings_indices == -1] = max_loras - 1
    base_indices = indices[1]
    sampler_indices = prompt_mapping_tensor
    sampler_indices_padded = sampler_indices.clone()
    sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
    sampler_indices_padded = torch.arange(
        0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + (
            sampler_indices_padded * len(sampler_indices_padded))
    long_lora_indices = None
    long_lora_indices_len: Optional[int] = None
    if long_lora_context:
        long_lora_indices = indices[3]
        long_lora_indices_len = long_lora_indices.shape[-1]
    # Contain length of indices tensors. Used to index into each tensor.
    indices_len = [
        base_indices.shape[-1],
        sampler_indices.shape[-1],
        sampler_indices_padded.shape[-1],
        embeddings_indices.shape[-1],
    ]
    if long_lora_indices_len is not None:
        indices_len.append(long_lora_indices_len)
    else:
        # If long_lora doesn't exist, append None
        indices_len.append(None)

    return (
        base_indices,
        sampler_indices,
        sampler_indices_padded,
        embeddings_indices,
        long_lora_indices,
        indices_len,
    )
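

# Illustrative sketch (not from the original source): the padded sampler
# indices computed above turn a (lora_slot, request) pair into a single flat
# index. With 4 requests mapped to LoRA slots [0, 1, -1, 1] and max_loras=8:
#
#   padded = torch.tensor([0, 1, 7, 1])   # -1 replaced by max_loras - 1
#   flat = torch.arange(4) + padded * 4   # -> tensor([ 0,  5, 30,  7])
#
# i.e. entry i is the row-major flat index of element (padded[i], i) in a
# [*, num_requests] layout, and the -1 -> max_loras - 1 substitution keeps
# every index in range.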


class PunicaWrapper:
    """
    PunicaWrapper is designed to manage and provide metadata for the punica
    kernel. The main function is to maintain the state information for
    Multi-LoRA, and to provide the interface for the punica kernel.
    """

    def __init__(self, max_num_batched_tokens: int, max_batches: int,
                 device: str):
        self._token_lora_indices = torch.empty(max_num_batched_tokens,
                                               dtype=torch.long,
                                               device=device)
        self._sampler_indices = torch.empty(max_num_batched_tokens,
                                            dtype=torch.long,
                                            device=device)
        self._sampler_indices_padded = torch.empty(max_num_batched_tokens,
                                                   dtype=torch.long,
                                                   device=device)
        self._embeddings_indices = torch.empty(2,
                                               max_num_batched_tokens,
                                               dtype=torch.long,
                                               device=device)
        self._long_lora_indices = torch.empty(max_num_batched_tokens,
                                              dtype=torch.long,
                                              device=device)

        # 5 is the number of indices tensors:
        # base_indices, sampler_indices, sampler_indices_padded,
        # embeddings_indices, long_lora_indices
        self.indices_len: List[Optional[int]] = [None] * 5
        # these attributes are the information required for sgmv kernel
        self._seq_start_locs = torch.empty(max_batches,
                                           dtype=torch.long,
                                           device=device)
        self._seq_lengths = torch.empty(max_batches,
                                        dtype=torch.long,
                                        device=device)
        self._lora_indices_per_batch = torch.empty(max_batches,
                                                   dtype=torch.long,
                                                   device=device)
        self.max_length: int = 0
        self.token_nums: int = 0
        self.batch_size: int = -1
        self.is_prefill = False
        self.no_lora = False
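
    # Illustrative usage sketch (not part of the original class), assuming a
    # `LoRAMapping` named `mapping` and a slot table `lora_index_to_id` are
    # provided by the model runner:
    #
    #   wrapper = PunicaWrapper(max_num_batched_tokens=8192,
    #                           max_batches=256,
    #                           device="cuda")
    #   wrapper.update_metadata(mapping, lora_index_to_id, max_loras=8,
    #                           vocab_size=32000, extra_vocab_size=256)
    #   # afterwards wrapper.add_lora(...) / wrapper.add_lora_logits(...)
    #   # consume the cached index tensors for the current batch.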

    def update_metadata(
        self,
        mapping: "LoRAMapping",
        lora_index_to_id: List[Optional[int]],
        max_loras: int,
        vocab_size: int,
        extra_vocab_size: int,
        long_lora_context: Optional["LongContextLoRAContext"] = None,
    ):
        self._update_base_metadata(mapping, lora_index_to_id, max_loras,
                                   vocab_size, extra_vocab_size,
                                   long_lora_context)
        if mapping.is_prefill:
            # Update metadata required for prefill-related operators.
            self._update_prefill_metadata(self.token_lora_indices)
            self.is_prefill = True
        else:
            self.is_prefill = False

    def _update_base_metadata(
        self,
        mapping: "LoRAMapping",
        lora_index_to_id: List[Optional[int]],
        max_loras: int,
        vocab_size: int,
        extra_vocab_size: int,
        long_lora_context: Optional["LongContextLoRAContext"] = None,
    ):
        (
            base_indices,
            sampler_indices,
            sampler_indices_padded,
            embeddings_indices,
            long_lora_offsets_tensor,
            indices_len,
        ) = convert_mapping(
            mapping,
            lora_index_to_id,
            max_loras,
            vocab_size,
            extra_vocab_size,
            long_lora_context,
        )
        self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
        self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
        self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
            sampler_indices_padded)
        self._embeddings_indices[:embeddings_indices.
                                 shape[0], :embeddings_indices.shape[1]].copy_(
                                     embeddings_indices)
        if long_lora_offsets_tensor is not None:
            self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
                long_lora_offsets_tensor)
        else:
            self._long_lora_indices.zero_()

        self.indices_len[:] = indices_len

    def _update_prefill_metadata(self,
                                 token_lora_tensor: torch.Tensor) -> None:
        (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
         batch_size, max_length, token_nums,
         no_lora) = compute_meta(token_lora_tensor)

        self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_(
            b_seq_start_tensor)
        self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor)
        self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_(
            lora_indices_tensor)
        self.batch_size = batch_size
        self.max_length = max_length
        self.token_nums = token_nums
        self.no_lora = no_lora

    @property
    def prefill_metadata(
        self
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
        """
        This property provides a convenient way to access the necessary
        metadata for prefill-related kernel computations.
            1. seq_start_locs: Tensor of sequence start positions.
            2. seq_lengths: Tensor of sequence lengths.
            3. lora_indices_per_batch: Tensor of lora indices, and an index of
               -1 means no lora should be applied.
            4. batch_size: Batch size after clustering identical lora indices.
            5. max_length: The maximum sequence length in the batch.
            6. token_nums: The token numbers in the batch.
        """
        return (self._seq_start_locs[:self.batch_size],
                self._seq_lengths[:self.batch_size],
                self._lora_indices_per_batch[:self.batch_size],
                self.batch_size, self.max_length, self.token_nums)

    @property
    def token_lora_indices(self) -> torch.Tensor:
        """
        This property provides the lora indices corresponding to each token
        in the batch. An index of -1 means no lora should be applied.
        """
        token_lora_len = self.indices_len[0]
        return self._token_lora_indices[:token_lora_len]

    @property
    def sampler_indices(self) -> torch.Tensor:
        """
        This property is used to access the lora indices specifically for
        LogitsProcessorWithLoRA.
        """
        sampler_indices_len = self.indices_len[1]
        return self._sampler_indices[:sampler_indices_len]

    @property
    def sampler_indices_padded(self) -> torch.Tensor:
        """
        This property provides access to padded sampler indices.
        """
        indices_padded_len = self.indices_len[2]
        return self._sampler_indices_padded[:indices_padded_len]

    @property
    def embeddings_indices(self) -> torch.Tensor:
        """
        This property provides access to the indices used for lora embeddings,
        specifically for VocabParallelEmbeddingWithLoRA.
        """
        embeddings_indices_len = self.indices_len[3]
        return self._embeddings_indices[:, :embeddings_indices_len]

    @property
    def long_lora_indices(self) -> torch.Tensor:
        """
        This property provides access to the indices used for long context
        lora, specifically for LinearScalingRotaryEmbeddingWithLora.
        """
        long_lora_len = self.indices_len[4]
        return self._long_lora_indices[:long_lora_len]

    def shrink_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        # No LoRA request, so return directly
        if self.no_lora:
            return
        sgmv_shrink(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            scale,
        )

    def shrink_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)

    def expand_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_input: bool,
    ):
        # No LoRA request, so return directly
        if self.no_lora:
            return
        sgmv_expand(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            add_input,
        )

    def expand_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_input: bool,
    ):
        bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input)

    def expand_slice_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: Optional[int],
        y_slice_size: Optional[int],
        add_input: bool,
    ):
        # No LoRA request, so return directly
        if self.no_lora:
            return
        sgmv_expand_slice(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            y_offset,
            y_slice_size,
            add_input,
        )

    def expand_slice_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: Optional[int],
        y_slice_size: Optional[int],
        add_input: bool,
    ):
        bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
                          y_slice_size, add_input)

    def add_shrink(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        """
        Perform the `y += x @ w_t_all` computation, which is suitable for the
        GEMM of lora_a.
        When `is_prefill` is true, it indicates that it is currently the
        prefill stage, and the `shrink_prefill` function should be called.
        Otherwise, it is the decode stage, and the `shrink_decode` function
        should be called.
        """
        shrink_fun: Callable = (self.shrink_prefill
                                if self.is_prefill else self.shrink_decode)
        shrink_fun(y, x, w_t_all, scale)
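
    # Reference sketch (illustrative only, not the Triton kernel), assuming
    # `w_t_all` is laid out as [max_loras, 1, rank, hidden] (the stacked,
    # transposed lora_a weights): the decode-path shrink for token i with
    # slot idx = token_lora_indices[i] corresponds roughly to
    #
    #   for i in range(x.size(0)):
    #       idx = token_lora_indices[i]
    #       if idx != -1:
    #           y[i] += scale * (x[i] @ w_t_all[idx, 0].T)   # [hidden] -> [rank]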

    def add_expand(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_input: bool = True,
    ):
        """
        Perform the `y += x @ w_t_all` computation, which is suitable for the
        GEMM of lora_b.
        When `is_prefill` is true, it indicates that it is currently the
        prefill stage, and the `expand_prefill` function should be called.
        Otherwise, it is the decode stage, and the `expand_decode` function
        should be called.
        """
        expand_fun: Callable = (self.expand_prefill
                                if self.is_prefill else self.expand_decode)
        expand_fun(y, x, w_t_all, add_input)
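
    # Reference sketch (illustrative only), assuming `w_t_all` is laid out as
    # [max_loras, 1, out_dim, rank] (the stacked, transposed lora_b weights):
    # the expand step mirrors the shrink step, projecting from rank r back to
    # the output dimension. For token i with slot idx = token_lora_indices[i]:
    #
    #   for i in range(x.size(0)):
    #       idx = token_lora_indices[i]
    #       if idx != -1:
    #           base = y[i] if add_input else 0
    #           y[i] = base + x[i] @ w_t_all[idx, 0].T   # [rank] -> [out_dim]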

    def add_expand_slice(self,
                         y: torch.Tensor,
                         x: torch.Tensor,
                         w_t_all: torch.Tensor,
                         y_offset: Optional[int],
                         y_slice_size: Optional[int],
                         add_input: bool = True):
        """
        Similar to `add_expand`, but only the column slice
        y[:, y_offset:y_offset + y_slice_size] is updated.
        """
        expand_slice_fun: Callable = (self.expand_slice_prefill
                                      if self.is_prefill else
                                      self.expand_slice_decode)
        expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input)

    def add_lora(self,
                 y: torch.Tensor,
                 x: torch.Tensor,
                 wa_t_all: torch.Tensor,
                 wb_t_all: torch.Tensor,
                 scale: float,
                 y_offset: Optional[int] = None,
                 y_slice_size: Optional[int] = None,
                 *,
                 buffer: Optional[torch.Tensor] = None) -> None:
        """
        Semantics:
            y[i] += (
                x[i].unsqueeze(0)
                @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
                @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
                * scale
            ).squeeze(0)

        Args:
            y (torch.Tensor): Output tensor. Will be changed in-place.
            x (torch.Tensor): Input tensor.
            wa_t_all (torch.Tensor): lora_a's weight.
            wb_t_all (torch.Tensor): lora_b's weight.
            scale (float): Scaling factor.
            y_offset (Optional[int], optional): Offset to apply to the starting
                column of y.
            y_slice_size (Optional[int], optional): Size of the y column slice.
            buffer (Optional[torch.Tensor], optional): Defaults to None.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        x = x.view(-1, x.shape[-1])
        r = wb_t_all.size(-1)
        if buffer is None:
            # We set the buffer to be float32 by default, refer to:
            # https://github.com/triton-lang/triton/issues/1387
            buffer = torch.zeros((x.size(0), r),
                                 dtype=torch.float32,
                                 device=x.device)

        self.add_shrink(buffer, x, wa_t_all, scale)
        if y_offset is None and y_slice_size is None:
            self.add_expand(y, buffer, wb_t_all, add_input=True)
        else:
            self.add_expand_slice(y,
                                  buffer,
                                  wb_t_all,
                                  y_offset,
                                  y_slice_size,
                                  add_input=True)
        y = y.view_as(y_org)
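
    # Reference sketch (illustrative only): a plain PyTorch version of the
    # documented semantics for the token-level path, assuming the stacked
    # weight layouts described above (wa_t_all[idx, 0] is [rank, hidden],
    # wb_t_all[idx, 0] is [out_dim, rank]):
    #
    #   for i in range(x.size(0)):
    #       idx = token_lora_indices[i]
    #       if idx == -1:
    #           continue
    #       delta = (x[i] @ wa_t_all[idx, 0].T) @ wb_t_all[idx, 0].T * scale
    #       if y_offset is None:
    #           y[i] += delta
    #       else:
    #           y[i, y_offset:y_offset + y_slice_size] += delta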

    def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor,
                               lora_a_stacked: Tuple[torch.Tensor,
                                                     torch.Tensor,
                                                     torch.Tensor],
                               lora_b_stacked: Tuple[torch.Tensor,
                                                     torch.Tensor,
                                                     torch.Tensor],
                               scale: float,
                               output_slices: Tuple[int, ...]) -> None:
        """
        Applies lora to each input. Similar to `add_lora`, this method is
        used for layers that are composed of multiple sublayers
        (slices) packed together.
        """
        y_org = y
        x = x.view(-1, x.shape[-1])
        y = y.view(-1, y.shape[-1])
        offset_left = 0
        # TODO fuse these kernels
        for slice_idx in range(len(output_slices)):
            self.add_lora(y, x, lora_a_stacked[slice_idx],
                          lora_b_stacked[slice_idx], scale, offset_left,
                          output_slices[slice_idx])
            offset_left += output_slices[slice_idx]

        y = y.view_as(y_org)
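
    # Illustrative example (assumed names and shapes, not from the original
    # file): for a packed QKV projection, each element of `lora_b_stacked`
    # is the stacked lora_b for one sub-projection and `output_slices` gives
    # each slice's width in `y`, e.g.
    #
    #   output_slices = (q_size, kv_size, kv_size)
    #
    # so the loop above writes the Q delta into y[:, :q_size], the K delta
    # into y[:, q_size:q_size + kv_size], and so on.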

    def add_lora_logits(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        wa_t_all: torch.Tensor,
                        wb_t_all: torch.Tensor,
                        scale,
                        *,
                        buffer: Optional[torch.Tensor] = None) -> None:
        """
        LogitsProcessorWithLoRA always uses bgmv.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        x = x.view(-1, x.shape[-1])
        r = wb_t_all.size(-1)
        if buffer is None:
            # We set the buffer to be float32 by default, refer to:
            # https://github.com/triton-lang/triton/issues/1387
            buffer = torch.zeros((x.size(0), r),
                                 dtype=torch.float32,
                                 device=x.device)

        bgmv_shrink(x, wa_t_all, buffer, self.sampler_indices, scale)
        bgmv_expand(buffer, wb_t_all, y, self.sampler_indices, add_inputs=True)
        y = y.view_as(y_org)

    def bgmv_sample(self, hidden_states: torch.Tensor,
                    lm_heads_all: torch.Tensor, lm_head_base: torch.Tensor):
        '''
        hidden_states - [num_tokens, hidden_dim]
        lm_heads_all - [num_loras, vocab_size, hidden_dim]

        Equivalent to:

            vocab_size = lm_heads_all.shape[-2]
            num_tokens = hidden_states.size(0)
            logits = torch.zeros((num_tokens, vocab_size),
                                 dtype=torch.float32,
                                 device=hidden_states.device)
            for i in range(len(hidden_states)):
                if indices[i] == -1:
                    logits[i] = lm_head_base @ hidden_states[i]
                else:
                    logits[i] = lm_heads_all[indices[i]] @ hidden_states[i]
        '''
        indices = self.sampler_indices
        logits = bgmv_sample(hidden_states, lm_heads_all, lm_head_base,
                             indices)
        return logits

    def bgmv_embedding(self, tokens: torch.LongTensor,
                       embed_tokens_all: torch.Tensor,
                       embed_tokens_base: torch.Tensor) -> torch.Tensor:
        '''
        tokens - [num_tokens]
        embed_tokens_all - [num_loras, vocab_size, hidden_dim]
            modules_to_save embeddings
        embed_tokens_base - [vocab_size, hidden_dim]
            base layer embeddings, applied to tokens with index=-1

        returns:
            embeddings: [num_tokens, hidden_dim]
        '''
        embeddings = bgmv_embed(tokens,
                                embed_tokens_all,
                                embed_tokens_base,
                                token_indices=self.token_lora_indices.long())
        return embeddings
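
    # Reference sketch (illustrative only): the embedding lookup above is
    # equivalent to selecting, per token, either the base embedding table or
    # the LoRA-specific (modules_to_save) table:
    #
    #   for i in range(tokens.size(0)):
    #       idx = token_lora_indices[i]
    #       table = embed_tokens_base if idx == -1 else embed_tokens_all[idx]
    #       embeddings[i] = table[tokens[i]]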