1
0

punica.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604
  1. """
  2. Based on:
  3. Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
  4. Punica: Multi-Tenant LoRA Serving.
  5. https://arxiv.org/abs/2310.18547
  6. """
  7. from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union
  8. import torch
  9. from aphrodite.triton_utils import HAS_TRITON
  10. if HAS_TRITON:
  11. from aphrodite.lora.ops.bgmv_expand import bgmv_expand
  12. from aphrodite.lora.ops.bgmv_expand_slice import bgmv_expand_slice
  13. from aphrodite.lora.ops.bgmv_shrink import bgmv_shrink
  14. from aphrodite.lora.ops.sgmv_expand import sgmv_expand
  15. from aphrodite.lora.ops.sgmv_expand_slice import sgmv_expand_slice
  16. from aphrodite.lora.ops.sgmv_shrink import sgmv_shrink
  17. if TYPE_CHECKING:
  18. # avoid circuit import
  19. from aphrodite.lora.layers import LoRAMapping
  20. from aphrodite.lora.models import LongContextLoRAContext
  21. def compute_meta(
  22. token_lora_tensor: torch.Tensor
  23. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, bool]:
  24. """
  25. Get the information required for the sgmv kernel. With the features:
  26. 1. If consecutive requests in the batch use the same LoRA, this function
  27. will combine them into a single request, improving sgmv kernel inference
  28. performance.
  29. 2. At the beginning of each prefill stage inference, recalculations are
  30. needed based on the input, but only once.
  31. """
  32. lora_indices_tensor, seq_length_tensor = torch.unique_consecutive(
  33. token_lora_tensor, return_counts=True)
  34. cum_result = torch.cumsum(seq_length_tensor, dim=0)
  35. b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
  36. b_seq_start_tensor[1:].copy_(cum_result[:-1])
  37. max_length = seq_length_tensor.max().item()
  38. batch_size = lora_indices_tensor.size(0)
  39. no_lora = False
  40. # -1 means no lora should be applied. Use `no_lora` to determine whether
  41. # the current step requires LoRA. If LoRA is not needed, the prefill stage
  42. # does not need to launch the triton kernel, which can improve performance
  43. if batch_size == 1 and lora_indices_tensor == -1:
  44. no_lora = True
  45. return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
  46. batch_size, max_length, no_lora)
  47. # TODO see if this can be vectorized
  48. def convert_mapping(
  49. mapping: "LoRAMapping",
  50. lora_index_to_id: List[Optional[int]],
  51. max_loras: int,
  52. vocab_size: int,
  53. extra_vocab_size: int,
  54. long_lora_context: Optional["LongContextLoRAContext"] = None,
  55. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
  56. Optional[torch.Tensor], List[int]]:
  57. """Converts LoRAMapping to index tensors.
  58. Args:
  59. mapping: LoRAMapping mapping rows in a batch to LoRA ids.
  60. lora_index_to_id: List mapping LoRA ids to LoRA indices.
  61. max_loras: Maximum number of LoRAs.
  62. vocab_size: Model vocab size.
  63. extra_vocab_size: Extra vocab size each LoRA can have.
  64. long_lora_context: Passed if there are long context lora in a batch.
  65. Returns:
  66. A tuple of tensors:
  67. base_indices: Tensor of shape [batch_size] mapping batch rows to
  68. LoRA indices.
  69. sampler_indices: Tensor of shape [batch_size] mapping requests to
  70. LoRA indices for sampler. For generation, this will be the
  71. same as base_indicies. For prefill, this will map requests
  72. to LoRA indices.
  73. sampler_indices_padded: Tensor of shape [batch_size] mapping
  74. requests to LoRA indices for sampler with padding.
  75. Same as sampler_indicies, but -1 is replaced with
  76. max_loras.
  77. embeddings_indices: Tensor of shape [2, batch_size] mapping
  78. requests to embedding indices. First row is for embeddings
  79. added by the LoRAs, second row is for the LoRA.lora_a
  80. embeddings.
  81. long_lora_indices: Tensor of shape [batch_size] mapping
  82. requests to RoPE offsets and rot dims for long LoRAs.
  83. None if long context lora doesn't exist.
  84. indices_len: List of lengths of the above tensors. It contains
  85. (base_indices, sampler_indices, sampler_indices_padded,
  86. embeddings_indices, long_lora_indices).
  87. """
  88. index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
  89. embedding_indices = index_mapping_indices.copy()
  90. lora_indices = index_mapping_indices.copy()
  91. long_lora_offsets: Optional[torch.Tensor] = None
  92. if long_lora_context:
  93. long_lora_offsets = torch.zeros(len(index_mapping_indices),
  94. device="cuda",
  95. dtype=torch.long)
  96. prompt_mapping: List[int] = [
  97. lora_index_to_id.index(x) if x > 0 else -1
  98. for x in mapping.prompt_mapping
  99. ]
  100. lora_idx = None
  101. for i in range(len(index_mapping_indices)):
  102. # TODO index can be slow. optimize
  103. lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
  104. if index_mapping_indices[i] > 0 else -1)
  105. embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
  106. lora_indices[i] = lora_idx
  107. if long_lora_context:
  108. assert long_lora_offsets is not None
  109. lora_offset: int = long_lora_context.offsets_by_lora_id.get(
  110. index_mapping_indices[i], 0)
  111. long_lora_offsets[i] = lora_offset
  112. indices_list: List[Union[List[int], torch.Tensor]] = [
  113. index_mapping_indices,
  114. lora_indices,
  115. embedding_indices,
  116. ]
  117. if long_lora_context:
  118. assert long_lora_offsets is not None
  119. indices_list.append(long_lora_offsets)
  120. indices = torch.tensor(indices_list, dtype=torch.long, device="cuda")
  121. prompt_mapping_tensor = torch.tensor(prompt_mapping,
  122. device="cuda",
  123. dtype=torch.long)
  124. embeddings_indices = torch.stack([
  125. indices[2] * extra_vocab_size,
  126. indices[2] * (vocab_size + extra_vocab_size),
  127. ])
  128. embeddings_indices[embeddings_indices == -1] = max_loras - 1
  129. base_indices = indices[1]
  130. sampler_indices = prompt_mapping_tensor
  131. sampler_indices_padded = sampler_indices.clone()
  132. sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
  133. sampler_indices_padded = torch.arange(
  134. 0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + (
  135. sampler_indices_padded * len(sampler_indices_padded))
  136. long_lora_indices = None
  137. long_lora_indices_len: Optional[int] = None
  138. if long_lora_context:
  139. long_lora_indices = indices[3]
  140. long_lora_indices_len = long_lora_indices.shape[-1]
  141. # Contain length of indices tensors. Used to index into each tensor.
  142. indices_len = [
  143. base_indices.shape[-1],
  144. sampler_indices.shape[-1],
  145. sampler_indices_padded.shape[-1],
  146. embeddings_indices.shape[-1],
  147. ]
  148. if long_lora_indices_len is not None:
  149. indices_len.append(long_lora_indices_len)
  150. else:
  151. # If long_lora doesn't exist,append None
  152. indices_len.append(None)
  153. return (
  154. base_indices,
  155. sampler_indices,
  156. sampler_indices_padded,
  157. embeddings_indices,
  158. long_lora_indices,
  159. indices_len,
  160. )
  161. class PunicaWrapper:
  162. """
  163. PunicaWrapper is designed to manage and provide metadata for the punica
  164. kernel. The main function is to maintain the state information for
  165. Multi-LoRA, and to provide the interface for the punica kernel.
  166. """
  167. def __init__(self, max_num_batched_tokens: int, max_batches: int,
  168. device: str):
  169. self._token_lora_indices = torch.empty(max_num_batched_tokens,
  170. dtype=torch.long,
  171. device=device)
  172. self._sampler_indices = torch.empty(max_num_batched_tokens,
  173. dtype=torch.long,
  174. device=device)
  175. self._sampler_indices_padded = torch.empty(max_num_batched_tokens,
  176. dtype=torch.long,
  177. device=device)
  178. self._embeddings_indices = torch.empty(2,
  179. max_num_batched_tokens,
  180. dtype=torch.long,
  181. device=device)
  182. self._long_lora_indices = torch.empty(max_num_batched_tokens,
  183. dtype=torch.long,
  184. device=device)
  185. # 5 is the number of indicies tensors.
  186. # base_indices, sampler_indices, sampler_indices_padded,
  187. # embeddings_indices,long_lora_indices
  188. self.indices_len: List[Optional[int]] = [None] * 5
  189. # these attributes are the information required for sgmv kernel
  190. self._seq_start_locs = torch.empty(max_batches,
  191. dtype=torch.long,
  192. device=device)
  193. self._seq_lengths = torch.empty(max_batches,
  194. dtype=torch.long,
  195. device=device)
  196. self._lora_indices_per_batch = torch.empty(max_batches,
  197. dtype=torch.long,
  198. device=device)
  199. self.max_length: int = 0
  200. self.batch_size: int = -1
  201. self.is_prefill = False
  202. self.no_lora = False
  203. def update_metadata(
  204. self,
  205. mapping: "LoRAMapping",
  206. lora_index_to_id: List[Optional[int]],
  207. max_loras: int,
  208. vocab_size: int,
  209. extra_vocab_size: int,
  210. long_lora_context: Optional["LongContextLoRAContext"] = None,
  211. ):
  212. self._update_base_metadata(mapping, lora_index_to_id, max_loras,
  213. vocab_size, extra_vocab_size,
  214. long_lora_context)
  215. if mapping.is_prefill:
  216. # Update metadata required for prefill-related operators.
  217. self._update_prefill_metada(self.token_lora_indices)
  218. self.is_prefill = True
  219. else:
  220. self.is_prefill = False
  221. def _update_base_metadata(
  222. self,
  223. mapping: "LoRAMapping",
  224. lora_index_to_id: List[Optional[int]],
  225. max_loras: int,
  226. vocab_size: int,
  227. extra_vocab_size: int,
  228. long_lora_context: Optional["LongContextLoRAContext"] = None,
  229. ):
  230. (
  231. base_indices,
  232. sampler_indices,
  233. sampler_indices_padded,
  234. embeddings_indices,
  235. long_lora_offsets_tensor,
  236. indices_len,
  237. ) = convert_mapping(
  238. mapping,
  239. lora_index_to_id,
  240. max_loras,
  241. vocab_size,
  242. extra_vocab_size,
  243. long_lora_context,
  244. )
  245. self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
  246. self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
  247. self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
  248. sampler_indices_padded)
  249. self._embeddings_indices[:embeddings_indices.
  250. shape[0], :embeddings_indices.shape[1]].copy_(
  251. embeddings_indices)
  252. if long_lora_offsets_tensor is not None:
  253. self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
  254. long_lora_offsets_tensor)
  255. else:
  256. self._long_lora_indices.zero_()
  257. self.indices_len[:] = indices_len
  258. def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None:
  259. (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
  260. batch_size, max_length, no_lora) = compute_meta(token_lora_tensor)
  261. self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_(
  262. b_seq_start_tensor)
  263. self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor)
  264. self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_(
  265. lora_indices_tensor)
  266. self.batch_size = batch_size
  267. self.max_length = max_length
  268. self.no_lora = no_lora
  269. @property
  270. def prefill_metadata(
  271. self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
  272. """
  273. This property provides a convenient way to access the necessary
  274. metadata for prefill-related kernel computations.
  275. 1. seq_start_locs: Tensor of sequence start positions
  276. 2. seq_lengths: Tensor of sequence lengths
  277. 3. lora_indices_per_batch: Tensor of lora indices, and an index of
  278. -1 means no lora should be applied.
  279. 4. batch_size: batch size after clustering identical lora indices
  280. 5. max_length: The maximum sequence length in the batch
  281. """
  282. return (self._seq_start_locs[:self.batch_size],
  283. self._seq_lengths[:self.batch_size],
  284. self._lora_indices_per_batch[:self.batch_size],
  285. self.batch_size, self.max_length)
  286. @property
  287. def token_lora_indices(self) -> torch.Tensor:
  288. """
  289. This property provides the lora indices corresponding to each token
  290. in the batch. An index of -1 means no lora should be applied.
  291. """
  292. token_lora_len = self.indices_len[0]
  293. return self._token_lora_indices[:token_lora_len]
  294. @property
  295. def sampler_indices(self) -> torch.Tensor:
  296. """
  297. This property is used to access the lora indices specifically for
  298. LogitsProcessorWithLoRA
  299. """
  300. sampler_indices_len = self.indices_len[1]
  301. return self._sampler_indices[:sampler_indices_len]
  302. @property
  303. def sampler_indices_padded(self) -> torch.Tensor:
  304. """
  305. This property provides access to padded sampler indices
  306. """
  307. indices_padded_len = self.indices_len[2]
  308. return self._sampler_indices_padded[:indices_padded_len]
  309. @property
  310. def embeddings_indices(self) -> torch.Tensor:
  311. """
  312. This property provides access to the indices used for lora embeddings,
  313. specifically for VocabParallelEmbeddingWithLoRA
  314. """
  315. embeddings_indices_len = self.indices_len[3]
  316. return self._embeddings_indices[:, :embeddings_indices_len]
  317. @property
  318. def long_lora_indices(self) -> torch.Tensor:
  319. """
  320. This property provides access to the indices used for long context
  321. lora, specifically for LinearScalingRotaryEmbeddingWithLora
  322. """
  323. long_lora_len = self.indices_len[4]
  324. return self._long_lora_indices[:long_lora_len]
  325. def shrink_prefill(
  326. self,
  327. y: torch.Tensor,
  328. x: torch.Tensor,
  329. w_t_all: torch.Tensor,
  330. scale: float,
  331. ):
  332. #No LoRA request, so return directly
  333. if self.no_lora:
  334. return
  335. sgmv_shrink(
  336. x,
  337. w_t_all,
  338. y,
  339. *self.prefill_metadata,
  340. scale,
  341. )
  342. def shrink_decode(
  343. self,
  344. y: torch.Tensor,
  345. x: torch.Tensor,
  346. w_t_all: torch.Tensor,
  347. scale: float,
  348. ):
  349. bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)
  350. def expand_prefill(
  351. self,
  352. y: torch.Tensor,
  353. x: torch.Tensor,
  354. w_t_all: torch.Tensor,
  355. add_input: bool,
  356. ):
  357. #No LoRA request, so return directly
  358. if self.no_lora:
  359. return
  360. sgmv_expand(
  361. x,
  362. w_t_all,
  363. y,
  364. *self.prefill_metadata,
  365. add_input,
  366. )
  367. def expand_decode(
  368. self,
  369. y: torch.Tensor,
  370. x: torch.Tensor,
  371. w_t_all: torch.Tensor,
  372. add_input: bool,
  373. ):
  374. bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input)
  375. def expand_slice_prefill(
  376. self,
  377. y: torch.Tensor,
  378. x: torch.Tensor,
  379. w_t_all: torch.Tensor,
  380. y_offset: Optional[int],
  381. y_slice_size: Optional[int],
  382. add_input: bool,
  383. ):
  384. #No LoRA request, so return directly
  385. if self.no_lora:
  386. return
  387. sgmv_expand_slice(
  388. x,
  389. w_t_all,
  390. y,
  391. *self.prefill_metadata,
  392. y_offset,
  393. y_slice_size,
  394. add_input,
  395. )
  396. def expand_slice_decode(
  397. self,
  398. y: torch.Tensor,
  399. x: torch.Tensor,
  400. w_t_all: torch.Tensor,
  401. y_offset: Optional[int],
  402. y_slice_size: Optional[int],
  403. add_input: bool,
  404. ):
  405. bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
  406. y_slice_size, add_input)
  407. def add_shrink(
  408. self,
  409. y: torch.Tensor,
  410. x: torch.Tensor,
  411. w_t_all: torch.Tensor,
  412. scale: float,
  413. ):
  414. """
  415. Perform the ` y+=x@w_t_all` computation, which is suitable for the
  416. GEMM of lora'a.
  417. When `is_prefill is` true, it indicates that it is currently the
  418. prefill stage, and the `shrink_prefill` function should be called.
  419. Otherwise, it is the decode stage, and the shrink_decode function
  420. should be called.
  421. """
  422. shrink_fun: Callable = (self.shrink_prefill
  423. if self.is_prefill else self.shrink_decode)
  424. shrink_fun(y, x, w_t_all, scale)
  425. def add_expand(
  426. self,
  427. y: torch.Tensor,
  428. x: torch.Tensor,
  429. w_t_all: torch.Tensor,
  430. add_input: bool = True,
  431. ):
  432. """
  433. Perform the ` y+=x@w_t_all` computation, which is suitable for the
  434. GEMM of lora'b.
  435. When `is_prefill` is true, it indicates that it is currently the
  436. prefill stage, and the `expand_prefill` function should be called.
  437. Otherwise, it is the decode stage, and the expand_decode function
  438. should be called.
  439. """
  440. expand_fun: Callable = (self.expand_prefill
  441. if self.is_prefill else self.expand_decode)
  442. expand_fun(y, x, w_t_all, add_input)
  443. def add_expand_slice(self,
  444. y: torch.Tensor,
  445. x: torch.Tensor,
  446. w_t_all: torch.Tensor,
  447. y_offset: Optional[int],
  448. y_slice_size: Optional[int],
  449. add_input: bool = True):
  450. """
  451. Similar to `add_expand`
  452. """
  453. expand_slice_fun: Callable = (self.expand_slice_prefill
  454. if self.is_prefill else
  455. self.expand_slice_decode)
  456. expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input)
  457. def add_lora(self,
  458. y: torch.Tensor,
  459. x: torch.Tensor,
  460. wa_t_all: torch.Tensor,
  461. wb_t_all: torch.Tensor,
  462. scale: float,
  463. y_offset: Optional[int] = None,
  464. y_slice_size: Optional[int] = None,
  465. *,
  466. buffer: Optional[torch.Tensor] = None) -> None:
  467. """
  468. Semantics:
  469. y[i] += (
  470. x[i].unsqueeze(0)
  471. @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
  472. @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
  473. * scale
  474. ).squeeze(0)
  475. Args:
  476. y (torch.Tensor): Output tensor. Will be changed in-place.
  477. x (torch.Tensor): Input tensor
  478. wa_t_all (torch.Tensor): lora_a's weight
  479. wb_t_all (torch.Tensor): lora_b's weight
  480. scale (float): Scaling factor.
  481. y_offset (Optional[int], optional): Offset to apply to the starting
  482. column of y.
  483. y_slice_size (Optional[int], optional): Size of the y column slice..
  484. buffer (Optional[torch.Tensor], optional): Defaults to None.
  485. """
  486. y_org = y
  487. y = y.view(-1, y.shape[-1])
  488. x = x.view(-1, x.shape[-1])
  489. r = wb_t_all.size(-1)
  490. if buffer is None:
  491. # We set the buffer to be float32 by default ,refer to:
  492. # https://github.com/triton-lang/triton/issues/1387
  493. buffer = torch.zeros((x.size(0), r),
  494. dtype=torch.float32,
  495. device=x.device)
  496. self.add_shrink(buffer, x, wa_t_all, scale)
  497. if y_offset is None and y_slice_size is None:
  498. self.add_expand(y, buffer, wb_t_all, add_input=True)
  499. else:
  500. self.add_expand_slice(y,
  501. buffer,
  502. wb_t_all,
  503. y_offset,
  504. y_slice_size,
  505. add_input=True)
  506. y = y.view_as(y_org)
  507. def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor,
  508. lora_a_stacked: Tuple[torch.Tensor,
  509. torch.Tensor,
  510. torch.Tensor],
  511. lora_b_stacked: Tuple[torch.Tensor,
  512. torch.Tensor,
  513. torch.Tensor],
  514. scale: float,
  515. output_slices: Tuple[int, ...]) -> None:
  516. """
  517. Applies lora to each input. Similar to add_lora, This method is
  518. used for layers that are composed of multiple sublayers
  519. (slices) packed together.
  520. """
  521. y_org = y
  522. x = x.view(-1, x.shape[-1])
  523. y = y.view(-1, y.shape[-1])
  524. offset_left = 0
  525. # TODO fuse these kernels
  526. for slice_idx in range(len(output_slices)):
  527. self.add_lora(y, x, lora_a_stacked[slice_idx],
  528. lora_b_stacked[slice_idx], scale, offset_left,
  529. output_slices[slice_idx])
  530. offset_left += output_slices[slice_idx]
  531. y = y.view_as(y_org)
  532. def add_lora_logits(self,
  533. y: torch.Tensor,
  534. x: torch.Tensor,
  535. wa_t_all: torch.Tensor,
  536. wb_t_all: torch.Tensor,
  537. scale,
  538. *,
  539. buffer: Optional[torch.Tensor] = None) -> None:
  540. """
  541. LogitsProcessorWithLoRA always using bgmv
  542. """
  543. y_org = y
  544. y = y.view(-1, y.shape[-1])
  545. x = x.view(-1, x.shape[-1])
  546. r = wb_t_all.size(-1)
  547. if buffer is None:
  548. # We set the buffer to be float32 by default ,refer to:
  549. # https://github.com/triton-lang/triton/issues/1387
  550. buffer = torch.zeros((x.size(0), r),
  551. dtype=torch.float32,
  552. device=x.device)
  553. bgmv_shrink(x, wa_t_all, buffer, self.sampler_indices, scale)
  554. bgmv_expand(buffer, wb_t_all, y, self.sampler_indices, add_inputs=True)
  555. y = y.view_as(y_org)