123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210 |
- """
- Based on:
- Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
- Punica: Multi-Tenant LoRA Serving.
- https://arxiv.org/abs/2310.18547
- """
- import torch
- import triton
- import triton.language as tl
- from aphrodite.triton_utils import libentry
@libentry()
@triton.jit
def _sgmv_expand_slice_kernel(
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    b_seq_start_loc,
    seq_lens,
    lora_indices,
    xm_stride,
    xk_stride,  # 1
    l0_stride,  # hidden_size*max_rank
    lora_k_stride,
    lora_n_stride,
    cm_stride,
    cn_stride,
    slice_offset,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    ADD_INPUTS: tl.constexpr,
    CAST_TYPE: tl.constexpr,
):
    """
    Similar to the 'sgmv_expand' operator, but with an added parameter
    'slice_offset'. The reason for not reusing the 'sgmv_expand' operator
    might be that in the future, we could implement a fusion operator to
    achieve the current functionality instead of having to call it multiple
    times.

    Each program computes one (BLOCK_M, BLOCK_N) output tile for one batch
    entry: program axis 0 enumerates the tiles (row-tile index major,
    column-tile index minor) and program axis 1 is the batch index. The tile
    is the product of that sequence's input rows with the LoRA weight
    selected by `lora_indices`, and is stored (or added, when ADD_INPUTS is
    set) into output columns [slice_offset, slice_offset + N).

    Note on stride names: `lora_k_stride` is applied to the N (output
    column) index of the weight and `lora_n_stride` to the K (reduction)
    index, i.e. the reverse of what the names suggest. The parameter names
    are kept unchanged for the sake of existing callers.
    """
    pid = tl.program_id(axis=0)
    cur_batch = tl.program_id(axis=1)
    cta_n_num = tl.cdiv(N, BLOCK_N)
    pid_m = pid // cta_n_num
    pid_n = pid % cta_n_num

    # M is the length of the current sequence. A tile whose first row is at
    # or past M would have every store masked off, so exit early. Using `>=`
    # also keeps an empty sequence (M == 0) away from the `% M` below.
    M = tl.load(seq_lens + cur_batch)
    if pid_m * BLOCK_M >= M:
        return

    # An index of -1 marks a batch entry with no LoRA to apply.
    lora_index = tl.load(lora_indices + cur_batch)
    if lora_index == -1:
        return

    cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
    offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
    offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    offset_k = tl.arange(0, BLOCK_K)

    # Wrap the row/column indices into [0, M) / [0, N) so the unmasked loads
    # below stay inside the sequence / weight; rows and columns that wrapped
    # are discarded by the store mask at the end.
    ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)

    # (BLOCK_M, BLOCK_K) pointers into the input rows of this sequence.
    a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride +
             offset_k[None, :] * xk_stride)
    # (BLOCK_K, BLOCK_N) pointers into the selected LoRA weight.
    b_ptr = (lora_ptr + l0_stride * lora_index +
             offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride)

    # Accumulate in float32 regardless of the operand dtypes.
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(tl.cdiv(K, BLOCK_K)):
        if EVEN_K:
            tiled_a = tl.load(a_ptr)
            tiled_b = tl.load(b_ptr)
        else:
            # The last K-step may be partial: zero-fill past the end of K.
            tiled_a = tl.load(a_ptr,
                              mask=offset_k[None, :] < K - k * BLOCK_K,
                              other=0)
            tiled_b = tl.load(b_ptr,
                              mask=offset_k[:, None] < K - k * BLOCK_K,
                              other=0)
        if CAST_TYPE:
            # Bring the input tile to the weight's element type before dot.
            tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
        accumulator += tl.dot(
            tiled_a,
            tiled_b,
        )
        a_ptr += BLOCK_K * xk_stride
        b_ptr += BLOCK_K * lora_n_stride

    tiled_c = accumulator.to(lora_ptr.dtype.element_ty)

    # Output coordinates: rows are absolute (offset by the sequence start),
    # columns are shifted by slice_offset into the output tensor.
    offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_offset
    c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
             offset_cn[None, :] * cn_stride)
    # Only rows inside this sequence and columns inside the slice are valid.
    c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] <
                                                           (slice_offset + N))
    if ADD_INPUTS:
        tiled_out = tl.load(c_ptr, mask=c_mask)
        tiled_c += tiled_out
    tl.store(c_ptr, tiled_c, mask=c_mask)
@torch.inference_mode()
def _sgmv_expand_slice(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = False,
) -> None:
    """Validate the arguments and launch `_sgmv_expand_slice_kernel`.

    The kernel writes into (or, with `add_inputs`, adds to) the columns
    [slice_offset, slice_offset + slice_size) of `output_tensor`, in place.

    Args:
        inputs (torch.Tensor): input tensor. Must be contiguous, and its
            second dimension must equal the last dimension (rank) of
            `lora_b_weights`.
        lora_b_weights (torch.Tensor): LoRA B weights, of shape
            (lora_num, size, rank) or (lora_num, 1, size, rank).
        output_tensor (torch.Tensor): output tensor. Must be contiguous;
            it is modified in place.
        b_seq_start_loc (torch.Tensor): (batch_size,). The start offset of
            each sequence in the batch (exclusive cumulative sequence
            lengths), used to index into the token dimension. E.g., if the
            sequence lengths are [4, 6], it is [0, 4].
        seq_len_tensor (torch.Tensor): (batch_size,). The length of each
            sequence in the batch.
        lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
            corresponding to each batch. An index of -1 means no lora should
            be applied.
        batches (int): batch size
        max_seq_length (int): The max sequence length of the sequences
            in the batch
        slice_offset (int): column offset of the slice in output_tensor
        slice_size (int): width of the slice; must equal the `size`
            dimension of lora_b_weights
        add_inputs (bool, optional): Defaults to False. If True, the LoRA
            result is added to the existing contents of the output slice
            instead of overwriting it.
    """
    assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
    assert lora_b_weights.dtype in [
        torch.float16,
        torch.bfloat16,
    ]
    assert inputs.size(1) == lora_b_weights.size(-1)
    # The kernel reads one element of each of these per batch entry, so all
    # three must cover the whole batch.
    assert b_seq_start_loc.size(0) == batches
    assert seq_len_tensor.size(0) == batches
    assert lora_indices_tensor.size(0) == batches
    assert slice_size == lora_b_weights.size(-2)
    assert inputs.is_contiguous()
    assert output_tensor.is_contiguous()

    if lora_b_weights.ndim == 4:  # shape:(lora_num,1,size,rank)
        assert lora_b_weights.size(1) == 1
        lora_b_weights = lora_b_weights.squeeze(dim=1)
    else:
        assert lora_b_weights.ndim == 3  # shape:(lora_num,size,rank)

    assert lora_b_weights.is_contiguous()

    # TODO tuning this config
    N, K = lora_b_weights.shape[-2:]  # K= rank,N=hidden_size
    BLOCK_M = 32
    BLOCK_N = 32
    BLOCK_K = 16
    # When K is a multiple of BLOCK_K the kernel can skip the K-tail masks.
    EVEN_K = K % BLOCK_K == 0
    ADD_INPUTS = add_inputs
    # float32 inputs are cast inside the kernel to the half-precision weight
    # dtype before the dot product.
    CAST_TYPE = inputs.dtype == torch.float32 and lora_b_weights.dtype in [
        torch.float16,
        torch.bfloat16,
    ]
    # Axis 0: one program per (row tile, column tile) of the longest
    # sequence; axis 1: one slot per batch entry.
    grid = (
        triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
        batches,
    )
    _sgmv_expand_slice_kernel[grid](
        inputs,
        lora_b_weights,
        output_tensor,
        N,
        K,
        b_seq_start_loc,
        seq_len_tensor,
        lora_indices_tensor,
        inputs.stride(0),
        inputs.stride(1),
        lora_b_weights.stride(0),
        lora_b_weights.stride(1),
        lora_b_weights.stride(2),
        output_tensor.stride(0),
        output_tensor.stride(1),
        slice_offset,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        ADD_INPUTS,
        CAST_TYPE,
    )
# Public entry point: the launcher registered as the PyTorch custom op
# "lora::sgmv_expand_slice". `output_tensor` is declared as mutated because
# the op writes its result into that tensor in place and returns None.
sgmv_expand_slice = torch.library.custom_op(
    "lora::sgmv_expand_slice",
    _sgmv_expand_slice,
    mutates_args=["output_tensor"],
)
|