""" Based on: Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). Punica: Multi-Tenant LoRA Serving. https://arxiv.org/abs/2310.18547 """ import torch import triton import triton.language as tl from .utils import get_lora_op_configs @triton.jit def _bgmv_expand_kernel( input_ptr, lora_ptr, out_ptr, N, K, lora_indices, xm_stride, xk_stride, l0_stride, lora_k_stride, lora_n_stride, cm_stride, cn_stride, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, SPLIT_N: tl.constexpr, EVEN_K: tl.constexpr, ADD_INPUTS: tl.constexpr, CAST_TYPE: tl.constexpr, ): """ GroupGEMV, additionally, introducing SPLIT_N can improve large hidden_size's performance """ pid_sn = tl.program_id(axis=0) cur_batch = tl.program_id(axis=1) lora_index = tl.load(lora_indices + cur_batch) if lora_index == -1: return offset_k = tl.arange(0, BLOCK_K) offset_n = tl.arange(0, BLOCK_N) if EVEN_K: tiled_a = tl.load(input_ptr + cur_batch * xm_stride + offset_k * xk_stride, ) # [BLOCK_K] else: tiled_a = tl.load( input_ptr + cur_batch * xm_stride + offset_k * xk_stride, mask=offset_k < K, other=0, ) # [BLOCK_K] # N must be divisible by SPLIT_N split_n_length = tl.cdiv(N, SPLIT_N) if CAST_TYPE: tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) # sliding to next row-block b_ptr = (lora_ptr + l0_stride * lora_index + pid_sn * split_n_length * lora_k_stride) c_ptr = out_ptr + cur_batch * cm_stride + pid_sn * split_n_length for n in range(0, split_n_length, BLOCK_N): current_n = n + offset_n current_n_c = tl.max_contiguous(current_n, BLOCK_N) b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :] < K) c_mask = current_n < split_n_length tiled_b = tl.load( b_ptr + current_n_c[:, None] * lora_k_stride + offset_k[None, :] * lora_n_stride, mask=b_ptr_mask, other=0.0, ) # [BLOCK_N,BLOCK_K] if ADD_INPUTS: tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask) accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out else: accumulator = tl.sum(tiled_a * tiled_b, 1) tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask) @torch.inference_mode() def _bgmv_expand( inputs: torch.Tensor, lora_b_weights: torch.Tensor, output_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, add_inputs: bool = True, ) -> None: """ Args: inputs (torch.Tensor): input tensor lora_b_weights (torch.Tensor): lora'a weight output_tensor (torch.Tensor): output tensor lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index corresponding to each batch, An index of -1 means no lora should be applied. batches (int): batch size add_inputs (bool, optional): Defaults to False. adds the final lora results to the output. """ assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] assert lora_b_weights.dtype in [ torch.float16, torch.bfloat16, ] assert inputs.size(1) == lora_b_weights.size(-1) assert inputs.is_contiguous() assert output_tensor.is_contiguous() if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) assert lora_b_weights.size(1) == 1 lora_b_weights = lora_b_weights.squeeze(dim=1) else: assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) assert lora_b_weights.is_contiguous() # TODO tuning this config N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size BLOCK_K = triton.next_power_of_2(K) EVEN_K = K % BLOCK_K == 0 ADD_INPUTS = add_inputs CAST_TYPE = False if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ torch.float16, torch.bfloat16, ]: CAST_TYPE = True batches = lora_indices_tensor.size(0) config = get_lora_op_configs("expand", batches, N) grid = lambda META: ( META["SPLIT_N"], batches, ) _bgmv_expand_kernel[grid]( inputs, lora_b_weights, output_tensor, N, K, lora_indices_tensor, inputs.stride(0), inputs.stride(1), lora_b_weights.stride(0), lora_b_weights.stride(1), lora_b_weights.stride(2), output_tensor.stride(0), output_tensor.stride(1), BLOCK_K=BLOCK_K, EVEN_K=EVEN_K, ADD_INPUTS=ADD_INPUTS, CAST_TYPE=CAST_TYPE, **config, ) return try: bgmv_expand = torch.library.custom_op("lora::bgmv_expand", _bgmv_expand, mutates_args=["output_tensor"]) except AttributeError: bgmv_expand = _bgmv_expand