123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181 |
- """
- Based on:
- Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
- Punica: Multi-Tenant LoRA Serving.
- https://arxiv.org/abs/2310.18547
- """
- import torch
- import triton
- import triton.language as tl
- from .utils import get_lora_op_configs
@triton.jit
def _bgmv_expand_slice_kernel(
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    lora_indices,
    xm_stride,
    xk_stride,
    l0_stride,
    lora_k_stride,
    lora_n_stride,
    cm_stride,
    cn_stride,
    slice_offset,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SPLIT_N: tl.constexpr,
    EVEN_K: tl.constexpr,
    ADD_INPUTS: tl.constexpr,
    CAST_TYPE: tl.constexpr,
):
    """
    GroupGEMV "expand" kernel writing into a column slice of the output.

    Each program handles one batch row (program axis 1) and one of SPLIT_N
    consecutive chunks of the N output columns (program axis 0). For every
    column n in its chunk it computes
    ``sum_k input[row, k] * lora[lora_index, n, k]`` and stores it to
    ``out[row, slice_offset + n]``. With ADD_INPUTS the result is added to
    the value already there instead of overwriting it. Splitting N across
    SPLIT_N programs lets large hidden sizes use more parallelism.

    Rows whose entry in ``lora_indices`` is -1 are skipped and their output
    is left untouched.

    Stride naming: ``lora_k_stride`` is the weight stride along the N axis.
    ``lora_n_stride`` is the stride along the K axis. The wrapper passes
    ``stride(1)`` and ``stride(2)`` of a (lora_num, N, K) weight.
    """
    pid_sn = tl.program_id(axis=0)
    cur_batch = tl.program_id(axis=1)
    lora_index = tl.load(lora_indices + cur_batch)
    if lora_index == -1:
        return
    offset_k = tl.arange(0, BLOCK_K)
    offset_n = tl.arange(0, BLOCK_N)
    a_ptrs = input_ptr + cur_batch * xm_stride + offset_k * xk_stride
    if EVEN_K:
        tiled_a = tl.load(a_ptrs)  # [BLOCK_K]
    else:
        tiled_a = tl.load(a_ptrs, mask=offset_k < K, other=0)  # [BLOCK_K]
    if CAST_TYPE:
        tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
    # This program owns columns [n_start, n_start + split_n_length) of the
    # slice. When SPLIT_N does not divide N, the last chunk is shorter than
    # split_n_length. For that reason every mask below is also bounded by
    # the global N, so no program reads or writes past column N.
    split_n_length = tl.cdiv(N, SPLIT_N)
    n_start = pid_sn * split_n_length
    # Slide to this program's first weight row.
    b_ptr = lora_ptr + l0_stride * lora_index + n_start * lora_k_stride
    # Both the slice offset and the chunk start are column counts, so both
    # are scaled by the output's column stride.
    c_ptr = (out_ptr + cur_batch * cm_stride +
             (slice_offset + n_start) * cn_stride)
    for n in range(0, split_n_length, BLOCK_N):
        current_n = n + offset_n
        c_mask = (current_n < split_n_length) & (n_start + current_n < N)
        b_ptr_mask = ((current_n[:, None] < split_n_length)
                      & (n_start + current_n[:, None] < N)
                      & (offset_k[None, :] < K))
        tiled_b = tl.load(
            b_ptr + current_n[:, None] * lora_k_stride +
            offset_k[None, :] * lora_n_stride,
            mask=b_ptr_mask,
            other=0.0,
        )  # [BLOCK_N, BLOCK_K]
        if ADD_INPUTS:
            tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask)
            accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out
        else:
            accumulator = tl.sum(tiled_a * tiled_b, 1)
        tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask)
@torch.inference_mode()
def _bgmv_expand_slice(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = True,
) -> None:
    """
    Apply each row's LoRA-B ("expand") weight and write the result into a
    column slice of ``output_tensor``.

    For every row b with a LoRA index other than -1, this writes
    ``inputs[b] @ lora_b_weights[idx].T`` into
    ``output_tensor[b, slice_offset:slice_offset + slice_size]``.

    Args:
        inputs (torch.Tensor): 2-D input whose second dimension equals the
            LoRA rank (``lora_b_weights.size(-1)``). Must be float16,
            bfloat16 or float32, and contiguous.
        lora_b_weights (torch.Tensor): LoRA-B weights of shape
            (lora_num, size, rank) or (lora_num, 1, size, rank), in float16
            or bfloat16.
        output_tensor (torch.Tensor): contiguous 2-D output, modified in
            place.
        lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
            for each row. An index of -1 means no LoRA is applied and that
            output row is left untouched.
        slice_offset (int): first column of ``output_tensor`` written by
            this call.
        slice_size (int): number of columns written. Must equal
            ``lora_b_weights.size(-2)``.
        add_inputs (bool, optional): if True, add the result to the existing
            contents of the slice instead of overwriting it.
            Defaults to True.
    """
    assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
    assert lora_b_weights.dtype in [
        torch.float16,
        torch.bfloat16,
    ]
    assert inputs.size(1) == lora_b_weights.size(-1)
    assert slice_size == lora_b_weights.size(-2)
    assert inputs.is_contiguous()
    assert output_tensor.is_contiguous()
    # The kernel does not bounds-check the slice against the output's width.
    # An oversized slice would spill into the next row (or past the end).
    assert slice_offset >= 0
    assert slice_offset + slice_size <= output_tensor.size(1)

    if lora_b_weights.ndim == 4:  # shape:(lora_num,1,size,rank)
        assert lora_b_weights.size(1) == 1
        lora_b_weights = lora_b_weights.squeeze(dim=1)
    else:
        assert lora_b_weights.ndim == 3  # shape:(lora_num,size,rank)
    assert lora_b_weights.is_contiguous()

    # TODO tuning this config
    N, K = lora_b_weights.shape[-2:]  # K = rank, N = hidden_size (slice)
    BLOCK_K = triton.next_power_of_2(K)
    # True only when K is already a power of two, so the K loads need no mask.
    EVEN_K = K % BLOCK_K == 0
    ADD_INPUTS = add_inputs
    # A float32 input is cast down to the half-precision weight dtype inside
    # the kernel before the multiply.
    CAST_TYPE = inputs.dtype == torch.float32 and lora_b_weights.dtype in (
        torch.float16,
        torch.bfloat16,
    )

    batches = lora_indices_tensor.size(0)
    config = get_lora_op_configs("expand", batches, N)

    def grid(meta):
        # Axis 0: chunks of N (SPLIT_N comes from the tuned config).
        # Axis 1: one program per row.
        return (meta["SPLIT_N"], batches)

    _bgmv_expand_slice_kernel[grid](
        inputs,
        lora_b_weights,
        output_tensor,
        N,
        K,
        lora_indices_tensor,
        inputs.stride(0),
        inputs.stride(1),
        lora_b_weights.stride(0),
        lora_b_weights.stride(1),
        lora_b_weights.stride(2),
        output_tensor.stride(0),
        output_tensor.stride(1),
        slice_offset,
        BLOCK_K=BLOCK_K,
        EVEN_K=EVEN_K,
        ADD_INPUTS=ADD_INPUTS,
        CAST_TYPE=CAST_TYPE,
        **config,
    )
# Register the op with torch.library when the running PyTorch provides
# torch.library.custom_op (added in PyTorch 2.4). Otherwise expose the plain
# Python function. Checking for the attribute directly, rather than wrapping
# the whole registration call in try/except AttributeError, keeps a genuine
# AttributeError raised during registration from being silently swallowed.
if hasattr(torch.library, "custom_op"):
    bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice",
                                                _bgmv_expand_slice,
                                                mutates_args=["output_tensor"])
else:
    bgmv_expand_slice = _bgmv_expand_slice
|