- """
- Based on:
- Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
- Punica: Multi-Tenant LoRA Serving.
- https://arxiv.org/abs/2310.18547
- """
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union

import torch

from aphrodite.triton_utils import HAS_TRITON

if HAS_TRITON:
    from aphrodite.lora.ops.bgmv_expand import bgmv_expand
    from aphrodite.lora.ops.bgmv_expand_slice import bgmv_expand_slice
    from aphrodite.lora.ops.bgmv_shrink import bgmv_shrink
    from aphrodite.lora.ops.sgmv_expand import sgmv_expand
    from aphrodite.lora.ops.sgmv_expand_slice import sgmv_expand_slice
    from aphrodite.lora.ops.sgmv_shrink import sgmv_shrink

if TYPE_CHECKING:
    # avoid circular import
    from aphrodite.lora.layers import LoRAMapping
    from aphrodite.lora.models import LongContextLoRAContext


def compute_meta(
    token_lora_tensor: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, bool]:
    """
    Get the information required for the sgmv kernel. Features:
    1. If consecutive requests in the batch use the same LoRA, this function
    combines them into a single request, improving sgmv kernel inference
    performance.
    2. At the beginning of each prefill stage, this metadata is recalculated
    from the input, but only once per prefill.
    """
    lora_indices_tensor, seq_length_tensor = torch.unique_consecutive(
        token_lora_tensor, return_counts=True)
    cum_result = torch.cumsum(seq_length_tensor, dim=0)
    b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
    b_seq_start_tensor[1:].copy_(cum_result[:-1])
    max_length = seq_length_tensor.max().item()
    batch_size = lora_indices_tensor.size(0)
    no_lora = False
    # -1 means no LoRA should be applied. Use `no_lora` to determine whether
    # the current step requires LoRA. If LoRA is not needed, the prefill stage
    # does not need to launch the triton kernel, which improves performance.
    if batch_size == 1 and lora_indices_tensor == -1:
        no_lora = True
    return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
            batch_size, max_length, no_lora)
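

# Illustrative sketch (comment only; the values are an assumed example, not
# output captured from this module): for a prefill batch whose per-token LoRA
# indices are
#     token_lora_tensor = [0, 0, 0, 1, 1, -1]
# compute_meta groups the consecutive runs and returns roughly
#     b_seq_start_tensor  = [0, 3, 5]    # start offset of each run
#     seq_length_tensor   = [3, 2, 1]    # length of each run
#     lora_indices_tensor = [0, 1, -1]   # LoRA slot per run (-1 = no LoRA)
#     batch_size = 3, max_length = 3, no_lora = False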


# TODO: see if this can be vectorized
def convert_mapping(
    mapping: "LoRAMapping",
    lora_index_to_id: List[Optional[int]],
    max_loras: int,
    vocab_size: int,
    extra_vocab_size: int,
    long_lora_context: Optional["LongContextLoRAContext"] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
           Optional[torch.Tensor], List[int]]:
- """Converts LoRAMapping to index tensors.
- Args:
- mapping: LoRAMapping mapping rows in a batch to LoRA ids.
- lora_index_to_id: List mapping LoRA ids to LoRA indices.
- max_loras: Maximum number of LoRAs.
- vocab_size: Model vocab size.
- extra_vocab_size: Extra vocab size each LoRA can have.
- long_lora_context: Passed if there are long context lora in a batch.
- Returns:
- A tuple of tensors:
- base_indices: Tensor of shape [batch_size] mapping batch rows to
- LoRA indices.
- sampler_indices: Tensor of shape [batch_size] mapping requests to
- LoRA indices for sampler. For generation, this will be the
- same as base_indicies. For prefill, this will map requests
- to LoRA indices.
- sampler_indices_padded: Tensor of shape [batch_size] mapping
- requests to LoRA indices for sampler with padding.
- Same as sampler_indicies, but -1 is replaced with
- max_loras.
- embeddings_indices: Tensor of shape [2, batch_size] mapping
- requests to embedding indices. First row is for embeddings
- added by the LoRAs, second row is for the LoRA.lora_a
- embeddings.
- long_lora_indices: Tensor of shape [batch_size] mapping
- requests to RoPE offsets and rot dims for long LoRAs.
- None if long context lora doesn't exist.
- indices_len: List of lengths of the above tensors. It contains
- (base_indices, sampler_indices, sampler_indices_padded,
- embeddings_indices, long_lora_indices).
- """
    index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
    embedding_indices = index_mapping_indices.copy()
    lora_indices = index_mapping_indices.copy()
    long_lora_offsets: Optional[torch.Tensor] = None
    if long_lora_context:
        long_lora_offsets = torch.zeros(len(index_mapping_indices),
                                        device="cuda",
                                        dtype=torch.long)
    prompt_mapping: List[int] = [
        lora_index_to_id.index(x) if x > 0 else -1
        for x in mapping.prompt_mapping
    ]
    lora_idx = None
    for i in range(len(index_mapping_indices)):
        # TODO: index() can be slow; optimize.
        lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
                    if index_mapping_indices[i] > 0 else -1)
        embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
        lora_indices[i] = lora_idx
        if long_lora_context:
            assert long_lora_offsets is not None
            lora_offset: int = long_lora_context.offsets_by_lora_id.get(
                index_mapping_indices[i], 0)
            long_lora_offsets[i] = lora_offset

    indices_list: List[Union[List[int], torch.Tensor]] = [
        index_mapping_indices,
        lora_indices,
        embedding_indices,
    ]
    if long_lora_context:
        assert long_lora_offsets is not None
        indices_list.append(long_lora_offsets)
    indices = torch.tensor(indices_list, dtype=torch.long, device="cuda")
    prompt_mapping_tensor = torch.tensor(prompt_mapping,
                                         device="cuda",
                                         dtype=torch.long)
    embeddings_indices = torch.stack([
        indices[2] * extra_vocab_size,
        indices[2] * (vocab_size + extra_vocab_size),
    ])
    embeddings_indices[embeddings_indices == -1] = max_loras - 1
    base_indices = indices[1]
    sampler_indices = prompt_mapping_tensor
    sampler_indices_padded = sampler_indices.clone()
    sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
    sampler_indices_padded = torch.arange(
        0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + (
            sampler_indices_padded * len(sampler_indices_padded))
    long_lora_indices = None
    long_lora_indices_len: Optional[int] = None
    if long_lora_context:
        long_lora_indices = indices[3]
        long_lora_indices_len = long_lora_indices.shape[-1]
    # Contains the length of each indices tensor. Used to index into each
    # tensor.
    indices_len = [
        base_indices.shape[-1],
        sampler_indices.shape[-1],
        sampler_indices_padded.shape[-1],
        embeddings_indices.shape[-1],
    ]
    if long_lora_indices_len is not None:
        indices_len.append(long_lora_indices_len)
    else:
        # If long_lora doesn't exist, append None.
        indices_len.append(None)

    return (
        base_indices,
        sampler_indices,
        sampler_indices_padded,
        embeddings_indices,
        long_lora_indices,
        indices_len,
    )
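

# Illustrative sketch (comment only; the concrete values are an assumed
# example, not output captured from this module). Suppose
# lora_index_to_id == [2, 1] (slot 0 holds LoRA id 2, slot 1 holds LoRA id 1),
# max_loras == 2, mapping.index_mapping == (1, 1, 0) (two tokens using LoRA
# id 1, one token with no LoRA) and mapping.prompt_mapping == (1, 0). Then,
# roughly:
#     base_indices           -> [1, 1, -1]
#     sampler_indices        -> [1, -1]
#     sampler_indices_padded -> [2, 3]   # arange(2) + (-1 -> max_loras-1) * 2
#     indices_len            -> [3, 2, 2, 3, None]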


class PunicaWrapper:
    """
    PunicaWrapper is designed to manage and provide metadata for the punica
    kernel. Its main function is to maintain the state information for
    Multi-LoRA and to provide the interface for the punica kernel.
    """

    def __init__(self, max_num_batched_tokens: int, max_batches: int,
                 device: str):
        self._token_lora_indices = torch.empty(max_num_batched_tokens,
                                               dtype=torch.long,
                                               device=device)
        self._sampler_indices = torch.empty(max_num_batched_tokens,
                                            dtype=torch.long,
                                            device=device)
        self._sampler_indices_padded = torch.empty(max_num_batched_tokens,
                                                   dtype=torch.long,
                                                   device=device)
        self._embeddings_indices = torch.empty(2,
                                               max_num_batched_tokens,
                                               dtype=torch.long,
                                               device=device)
        self._long_lora_indices = torch.empty(max_num_batched_tokens,
                                              dtype=torch.long,
                                              device=device)

        # 5 is the number of indices tensors:
        # base_indices, sampler_indices, sampler_indices_padded,
        # embeddings_indices, long_lora_indices
        self.indices_len: List[Optional[int]] = [None] * 5

        # These attributes are the information required for the sgmv kernel.
        self._seq_start_locs = torch.empty(max_batches,
                                           dtype=torch.long,
                                           device=device)
        self._seq_lengths = torch.empty(max_batches,
                                        dtype=torch.long,
                                        device=device)
        self._lora_indices_per_batch = torch.empty(max_batches,
                                                   dtype=torch.long,
                                                   device=device)
        self.max_length: int = 0
        self.batch_size: int = -1
        self.is_prefill = False
        self.no_lora = False

    def update_metadata(
        self,
        mapping: "LoRAMapping",
        lora_index_to_id: List[Optional[int]],
        max_loras: int,
        vocab_size: int,
        extra_vocab_size: int,
        long_lora_context: Optional["LongContextLoRAContext"] = None,
    ):
        self._update_base_metadata(mapping, lora_index_to_id, max_loras,
                                   vocab_size, extra_vocab_size,
                                   long_lora_context)
        if mapping.is_prefill:
            # Update metadata required for prefill-related operators.
            self._update_prefill_metadata(self.token_lora_indices)
            self.is_prefill = True
        else:
            self.is_prefill = False
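
    # Hypothetical usage sketch (comment only; the call sites, sizes, and
    # variable names below are assumptions, not part of this module). A model
    # runner typically constructs one wrapper up front and refreshes its
    # metadata once per scheduling step, before the LoRA layers execute:
    #
    #     punica_wrapper = PunicaWrapper(max_num_batched_tokens=8192,
    #                                    max_batches=256,
    #                                    device="cuda")
    #     ...
    #     punica_wrapper.update_metadata(lora_mapping, lora_index_to_id,
    #                                    max_loras, vocab_size,
    #                                    extra_vocab_size,
    #                                    long_lora_context=None)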

    def _update_base_metadata(
        self,
        mapping: "LoRAMapping",
        lora_index_to_id: List[Optional[int]],
        max_loras: int,
        vocab_size: int,
        extra_vocab_size: int,
        long_lora_context: Optional["LongContextLoRAContext"] = None,
    ):
        (
            base_indices,
            sampler_indices,
            sampler_indices_padded,
            embeddings_indices,
            long_lora_offsets_tensor,
            indices_len,
        ) = convert_mapping(
            mapping,
            lora_index_to_id,
            max_loras,
            vocab_size,
            extra_vocab_size,
            long_lora_context,
        )
        self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
        self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
        self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
            sampler_indices_padded)
        self._embeddings_indices[:embeddings_indices.
                                 shape[0], :embeddings_indices.shape[1]].copy_(
                                     embeddings_indices)
        if long_lora_offsets_tensor is not None:
            self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
                long_lora_offsets_tensor)
        else:
            self._long_lora_indices.zero_()
        self.indices_len[:] = indices_len

    def _update_prefill_metadata(self,
                                 token_lora_tensor: torch.Tensor) -> None:
        (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
         batch_size, max_length, no_lora) = compute_meta(token_lora_tensor)

        self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_(
            b_seq_start_tensor)
        self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor)
        self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_(
            lora_indices_tensor)
        self.batch_size = batch_size
        self.max_length = max_length
        self.no_lora = no_lora

    @property
    def prefill_metadata(
        self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
        """
        This property provides a convenient way to access the metadata
        required for prefill-related kernel computations:
        1. seq_start_locs: Tensor of sequence start positions.
        2. seq_lengths: Tensor of sequence lengths.
        3. lora_indices_per_batch: Tensor of LoRA indices; an index of -1
           means no LoRA should be applied.
        4. batch_size: Batch size after clustering identical LoRA indices.
        5. max_length: The maximum sequence length in the batch.
        """
        return (self._seq_start_locs[:self.batch_size],
                self._seq_lengths[:self.batch_size],
                self._lora_indices_per_batch[:self.batch_size],
                self.batch_size, self.max_length)

    @property
    def token_lora_indices(self) -> torch.Tensor:
        """
        This property provides the LoRA indices corresponding to each token
        in the batch. An index of -1 means no LoRA should be applied.
        """
        token_lora_len = self.indices_len[0]
        return self._token_lora_indices[:token_lora_len]

    @property
    def sampler_indices(self) -> torch.Tensor:
        """
        This property is used to access the LoRA indices specifically for
        LogitsProcessorWithLoRA.
        """
        sampler_indices_len = self.indices_len[1]
        return self._sampler_indices[:sampler_indices_len]

    @property
    def sampler_indices_padded(self) -> torch.Tensor:
        """
        This property provides access to the padded sampler indices.
        """
        indices_padded_len = self.indices_len[2]
        return self._sampler_indices_padded[:indices_padded_len]

    @property
    def embeddings_indices(self) -> torch.Tensor:
        """
        This property provides access to the indices used for LoRA
        embeddings, specifically for VocabParallelEmbeddingWithLoRA.
        """
        embeddings_indices_len = self.indices_len[3]
        return self._embeddings_indices[:, :embeddings_indices_len]

    @property
    def long_lora_indices(self) -> torch.Tensor:
        """
        This property provides access to the indices used for long-context
        LoRA, specifically for LinearScalingRotaryEmbeddingWithLora.
        """
        long_lora_len = self.indices_len[4]
        return self._long_lora_indices[:long_lora_len]

    def shrink_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        # No LoRA request, so return directly.
        if self.no_lora:
            return
        sgmv_shrink(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            scale,
        )

    def shrink_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)

    def expand_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_input: bool,
    ):
        # No LoRA request, so return directly.
        if self.no_lora:
            return
        sgmv_expand(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            add_input,
        )

    def expand_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_input: bool,
    ):
        bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input)

    def expand_slice_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: Optional[int],
        y_slice_size: Optional[int],
        add_input: bool,
    ):
        # No LoRA request, so return directly.
        if self.no_lora:
            return
        sgmv_expand_slice(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            y_offset,
            y_slice_size,
            add_input,
        )

    def expand_slice_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: Optional[int],
        y_slice_size: Optional[int],
        add_input: bool,
    ):
        bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
                          y_slice_size, add_input)

    def add_shrink(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        """
        Perform the `y += x @ w_t_all` computation, which is suitable for
        the GEMM of lora_a.
        When `is_prefill` is true, it indicates that it is currently the
        prefill stage, and the `shrink_prefill` function should be called.
        Otherwise, it is the decode stage, and the `shrink_decode` function
        should be called.
        """
        shrink_fun: Callable = (self.shrink_prefill
                                if self.is_prefill else self.shrink_decode)
        shrink_fun(y, x, w_t_all, scale)

    def add_expand(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_input: bool = True,
    ):
        """
        Perform the `y += x @ w_t_all` computation, which is suitable for
        the GEMM of lora_b.
        When `is_prefill` is true, it indicates that it is currently the
        prefill stage, and the `expand_prefill` function should be called.
        Otherwise, it is the decode stage, and the `expand_decode` function
        should be called.
        """
        expand_fun: Callable = (self.expand_prefill
                                if self.is_prefill else self.expand_decode)
        expand_fun(y, x, w_t_all, add_input)

    def add_expand_slice(self,
                         y: torch.Tensor,
                         x: torch.Tensor,
                         w_t_all: torch.Tensor,
                         y_offset: Optional[int],
                         y_slice_size: Optional[int],
                         add_input: bool = True):
        """
        Similar to `add_expand`, but writes into the column slice of `y`
        given by `y_offset` and `y_slice_size`.
        """
        expand_slice_fun: Callable = (self.expand_slice_prefill
                                      if self.is_prefill else
                                      self.expand_slice_decode)
        expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input)

    def add_lora(self,
                 y: torch.Tensor,
                 x: torch.Tensor,
                 wa_t_all: torch.Tensor,
                 wb_t_all: torch.Tensor,
                 scale: float,
                 y_offset: Optional[int] = None,
                 y_slice_size: Optional[int] = None,
                 *,
                 buffer: Optional[torch.Tensor] = None) -> None:
        """
        Semantics:
            y[i] += (
                x[i].unsqueeze(0)
                @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
                @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
                * scale
            ).squeeze(0)

        Args:
            y (torch.Tensor): Output tensor. Will be changed in-place.
            x (torch.Tensor): Input tensor.
            wa_t_all (torch.Tensor): lora_a's weight.
            wb_t_all (torch.Tensor): lora_b's weight.
            scale (float): Scaling factor.
            y_offset (Optional[int], optional): Offset to apply to the
                starting column of y.
            y_slice_size (Optional[int], optional): Size of the y column
                slice.
            buffer (Optional[torch.Tensor], optional): Defaults to None.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        x = x.view(-1, x.shape[-1])
        r = wb_t_all.size(-1)
        if buffer is None:
            # We set the buffer to be float32 by default; refer to:
            # https://github.com/triton-lang/triton/issues/1387
            buffer = torch.zeros((x.size(0), r),
                                 dtype=torch.float32,
                                 device=x.device)

        self.add_shrink(buffer, x, wa_t_all, scale)
        if y_offset is None and y_slice_size is None:
            self.add_expand(y, buffer, wb_t_all, add_input=True)
        else:
            self.add_expand_slice(y,
                                  buffer,
                                  wb_t_all,
                                  y_offset,
                                  y_slice_size,
                                  add_input=True)
        y = y.view_as(y_org)
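
    # A dense reference for the semantics of add_lora (comment-only sketch;
    # `idx` and the use of layer index 0 are illustrative assumptions about
    # the stacked weight layout, not code from this module):
    #
    #     a = wa_t_all[idx, 0].transpose(-1, -2)   # [hidden_size, rank]
    #     b = wb_t_all[idx, 0].transpose(-1, -2)   # [rank, out_size]
    #     y[i] += (x[i] @ a @ b) * scale
    #
    # add_shrink computes the `x @ a` part into the fp32 buffer, and
    # add_expand / add_expand_slice applies `@ b` and accumulates into y.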

    def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor,
                               lora_a_stacked: Tuple[torch.Tensor,
                                                     torch.Tensor,
                                                     torch.Tensor],
                               lora_b_stacked: Tuple[torch.Tensor,
                                                     torch.Tensor,
                                                     torch.Tensor],
                               scale: float,
                               output_slices: Tuple[int, ...]) -> None:
        """
        Applies LoRA to each input. Similar to add_lora, but this method is
        used for layers that are composed of multiple sublayers (slices)
        packed together.
        """
        y_org = y
        x = x.view(-1, x.shape[-1])
        y = y.view(-1, y.shape[-1])
        offset_left = 0
        # TODO: fuse these kernels
        for slice_idx in range(len(output_slices)):
            self.add_lora(y, x, lora_a_stacked[slice_idx],
                          lora_b_stacked[slice_idx], scale, offset_left,
                          output_slices[slice_idx])
            offset_left += output_slices[slice_idx]

        y = y.view_as(y_org)
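
    # Illustrative sketch (comment only; q_size and kv_size are assumed names
    # for a packed QKV projection, not identifiers from this module): with
    # output_slices == (q_size, kv_size, kv_size), the columns of y are packed
    # as [Q | K | V] and the loop above applies
    #     slice 0 -> y[:, 0:q_size]
    #     slice 1 -> y[:, q_size:q_size + kv_size]
    #     slice 2 -> y[:, q_size + kv_size:]
    # advancing offset_left by each slice's width.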

    def add_lora_logits(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        wa_t_all: torch.Tensor,
                        wb_t_all: torch.Tensor,
                        scale,
                        *,
                        buffer: Optional[torch.Tensor] = None) -> None:
        """
        LogitsProcessorWithLoRA always uses the bgmv kernels.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        x = x.view(-1, x.shape[-1])
        r = wb_t_all.size(-1)
        if buffer is None:
            # We set the buffer to be float32 by default; refer to:
            # https://github.com/triton-lang/triton/issues/1387
            buffer = torch.zeros((x.size(0), r),
                                 dtype=torch.float32,
                                 device=x.device)

        bgmv_shrink(x, wa_t_all, buffer, self.sampler_indices, scale)
        bgmv_expand(buffer, wb_t_all, y, self.sampler_indices, add_inputs=True)
        y = y.view_as(y_org)
|