123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166 |
- # Based on code from https://github.com/punica-ai/punica
- from typing import Optional
- import torch
- def _raise_import_error(e):
- if torch.cuda.get_device_capability() < (8, 0):
- raise ImportError(
- "punica LoRA kernels require compute capability >= 8.0") from e
- else:
- raise ImportError(
- "punica LoRA kernels could not be imported. If you built Aphrodite "
- "from source, make sure APHRODITE_INSTALL_PUNICA_KERNELS=1 env var "
- "was set.") from e
def bgmv(
    y: torch.Tensor,
    x: torch.Tensor,
    w_t_all: torch.Tensor,
    indicies: torch.LongTensor,
    layer_idx: int,
    scale: float,
):
    """Batched gather matrix-vector product, accumulated into ``y`` in place.

    For every row ``i`` of the batch this computes::

        y[i] += (
            x[i].unsqueeze(0)
            @ w_t_all[indicies[i], layer_idx, :, :].transpose(-1, -2)
            * scale
        ).squeeze(0)

    Args:
        y: Shape ``[B, H2]``. Output vectors, modified in place.
        x: Shape ``[B, H1]``. Input vectors.
        w_t_all: Shape ``[None, L, H2, H1]``. All of the transposed weight
            matrices.
        indicies: Shape ``[B]``. Per-row index into the first dimension of
            ``w_t_all``.
        layer_idx: Index into the second (layer) dimension of ``w_t_all``.
        scale: Scaling factor applied to each product.

    Raises:
        ImportError: If the compiled punica kernels cannot be imported.
    """
    try:
        import aphrodite._punica_C as _ops
    except ImportError as err:
        _raise_import_error(err)
    kernel_args = (y, x, w_t_all, indicies, layer_idx, scale)
    _ops.dispatch_bgmv(*kernel_args)
def add_lora(y: torch.Tensor,
             x: torch.Tensor,
             wa_t_all: torch.Tensor,
             wb_t_all: torch.Tensor,
             indicies: torch.LongTensor,
             layer_idx: int,
             scale: float,
             *,
             buffer: Optional[torch.Tensor] = None):
    """Apply a per-row LoRA update (A then B projection) to ``y`` in place.

    Semantics::

        y[i] += (
            x[i].unsqueeze(0)
            @ wa_t_all[indicies[i], layer_idx, :, :].transpose(-1, -2)
            @ wb_t_all[indicies[i], layer_idx, :, :].transpose(-1, -2)
            * scale
        ).squeeze(0)

    The update runs as two ``dispatch_bgmv`` calls. The first applies the
    A projection into ``buffer`` with scale 1.0. The second applies the B
    projection from ``buffer`` into ``y`` with ``scale``.

    Args:
        y: Shape ``[B, H2]``. Output vectors, modified in place.
        x: Shape ``[B, H1]``. Input vectors.
        wa_t_all: Shape ``[None, L, R, H1]``. All of the transposed
            LoRA A matrices.
        wb_t_all: Shape ``[None, L, H2, R]``. All of the transposed
            LoRA B matrices.
        indicies: Shape ``[B]``. Per-row index of the LoRA weights to use.
        layer_idx: Layer index of the LoRA weights.
        scale: Scaling factor applied to the final product.
        buffer: Optional. Shape ``[B, R]``. Temporary buffer that holds the
            intermediate A projection. When omitted, a zero-filled float32
            tensor is allocated on ``x.device``. NOTE(review): per the
            ``bgmv`` semantics, ``dispatch_bgmv`` accumulates (``+=``) into
            its output, so a caller-supplied buffer presumably must be
            zero-filled. Confirm against the kernel.

    Raises:
        ImportError: If the compiled punica kernels cannot be imported.
    """
    try:
        import aphrodite._punica_C as _ops
    except ImportError as err:
        _raise_import_error(err)

    rank = wb_t_all.size(-1)
    if buffer is None:
        # A float32 intermediate avoids the precision loss a downcast between
        # the two projections would otherwise introduce.
        buffer = torch.zeros((x.size(0), rank),
                             dtype=torch.float32,
                             device=x.device)

    # A projection: buffer += x @ A^T (unscaled).
    _ops.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0)
    # B projection: y += buffer @ B^T * scale.
    _ops.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx, scale)
def add_lora_slice(y: torch.Tensor,
                   x: torch.Tensor,
                   wa_t_all: torch.Tensor,
                   wb_t_all: torch.Tensor,
                   indicies: torch.LongTensor,
                   layer_idx: int,
                   scale: float,
                   y_offset: int,
                   y_slice_size: int,
                   *,
                   buffer: Optional[torch.Tensor] = None):
    """Same as ``add_lora`` but only updates a column slice of ``y``.

    Pass the whole ``y`` and select the target columns with ``y_offset`` and
    ``y_slice_size``. Semantics::

        y[i, y_offset:y_offset + y_slice_size] += (
            x[i].unsqueeze(0)
            @ wa_t_all[indicies[i], layer_idx, :, :].transpose(-1, -2)
            @ wb_t_all[indicies[i], layer_idx, :, :].transpose(-1, -2)
            * scale
        ).squeeze(0)

    Args:
        y: Shape ``[B, H2]``. Output vectors, modified in place (only the
            selected column slice is written).
        x: Shape ``[B, H1]``. Input vectors.
        wa_t_all: Shape ``[None, L, R, H1]``. All of the transposed
            LoRA A matrices.
        wb_t_all: Shape ``[None, L, H_out, R]``. All of the transposed
            LoRA B matrices. NOTE(review): H_out presumably equals
            ``y_slice_size``, since that is the output width handed to the
            kernel. Confirm.
        indicies: Shape ``[B]``. Per-row index of the LoRA weights to use.
        layer_idx: Layer index of the LoRA weights.
        scale: Scaling factor applied to the final product.
        y_offset: Offset of the first column of ``y`` to update.
        y_slice_size: Number of columns of ``y`` to update.
        buffer: Optional. Shape ``[B, R]``. Temporary buffer that holds the
            intermediate A projection. When omitted, a zero-filled float32
            tensor is allocated on ``x.device``. A caller-supplied buffer
            should presumably be zero-filled, because the first kernel call
            accumulates into it.

    Raises:
        ValueError: If the requested column slice does not fit inside ``y``.
        ImportError: If the compiled punica kernels cannot be imported.
    """
    # The low-level kernel receives y_offset / y_slice_size as raw integers,
    # and nothing here otherwise checks them against y's width. An
    # out-of-range slice could therefore write outside the intended columns,
    # so reject it up front.
    num_cols = y.size(-1)
    if (y_offset < 0 or y_slice_size < 0
            or y_offset + y_slice_size > num_cols):
        raise ValueError(
            f"y slice [{y_offset}, {y_offset + y_slice_size}) is out of "
            f"bounds for y with {num_cols} columns")

    try:
        import aphrodite._punica_C as punica_kernels
    except ImportError as e:
        _raise_import_error(e)

    r = wb_t_all.size(-1)
    if buffer is None:
        # We set the buffer to be float32 by default to avoid
        # numerical inaccuracies that would otherwise happen
        # due to downcasting.
        buffer = torch.zeros((x.size(0), r),
                             dtype=torch.float32,
                             device=x.device)
    # A projection into the full buffer. The trailing ints appear to be
    # (input width, output width, output column offset).
    punica_kernels.dispatch_bgmv_low_level(
        buffer,
        x,
        wa_t_all,
        indicies,
        layer_idx,
        1.0,
        x.size(1),
        buffer.size(1),
        0,
    )
    # B projection into columns [y_offset, y_offset + y_slice_size) of y.
    punica_kernels.dispatch_bgmv_low_level(
        y,
        buffer,
        wb_t_all,
        indicies,
        layer_idx,
        scale,
        buffer.size(1),
        y_slice_size,
        y_offset,
    )
|