# pylint: disable=unused-argument
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

from aphrodite.common.config import LoRAConfig
from aphrodite.lora.punica import add_lora, add_lora_slice, bgmv
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.megatron.communication_op import (
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
    tensor_model_parallel_gather,
)
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              RowParallelLinear,
                                              QKVParallelLinear,
                                              MergedColumnParallelLinear)
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
from aphrodite.modeling.megatron.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from aphrodite.modeling.megatron.utils import split_tensor_along_last_dim

if TYPE_CHECKING:
    pass


def _get_lora_device(base_layer: nn.Module) -> torch.device:
    """Identify the device for positioning the LoRA tensors."""
    device = None
    try:
        device = base_layer.weight.device
    except AttributeError:
        try:
            linear_weights = base_layer.linear_weights
            if isinstance(linear_weights, dict):
                tensor_values = [
                    v for v in linear_weights.values()
                    if isinstance(v, torch.Tensor)
                ]
                if tensor_values:
                    device = tensor_values[0].device
        except AttributeError:
            pass
    if device is None:
        raise ValueError(f"Base layer not supported: {base_layer}")
    return device


def _apply_lora(
    x: torch.Tensor,
    lora_a_stacked: torch.Tensor,
    lora_b_stacked: torch.Tensor,
    indices: torch.Tensor,
    output: torch.Tensor,
):
    """Applies lora to each input.

    This method applies all loras to each input. It uses the
    indices vector to determine which lora yields the
    correct output. An index of -1 means no lora should be
    applied. This method adds the final lora results to the
    output.

    Input shapes:
        x:              (batch_size, hidden_dim)
        lora_a_stacked: (num_loras, lora_rank, hidden_dim)
        lora_b_stacked: (num_loras, output_dim, lora_rank)
        indices:        (batch_size)
        output:         (batch_size, output_dim)
    """
    org_output = output
    x = x.view(-1, x.shape[-1])
    output = output.view(-1, output.shape[-1])
    indices = indices.view(-1)
    add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0)
    return output.view_as(org_output)
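

# A minimal, unvectorized reference (not used by the engine) of what the fused
# punica `add_lora` call above is assumed to compute, using the shapes in the
# `_apply_lora` docstring: (num_loras, lora_rank, hidden_dim) for lora_a and
# (num_loras, output_dim, lora_rank) for lora_b.
def _reference_apply_lora(
    x: torch.Tensor,
    lora_a_stacked: torch.Tensor,
    lora_b_stacked: torch.Tensor,
    indices: torch.Tensor,
    output: torch.Tensor,
    scale: float = 1.0,
) -> torch.Tensor:
    for i, lora_idx in enumerate(indices.tolist()):
        if lora_idx < 0:
            continue  # -1 means "no LoRA" for this token.
        # Shrink to the LoRA rank, then expand to the output dimension.
        shrink = lora_a_stacked[lora_idx] @ x[i]
        output[i] += scale * (lora_b_stacked[lora_idx] @ shrink)
    return output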


def _apply_lora_packed_nslice(
    x: torch.Tensor,
    lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    indices: torch.Tensor,
    output: torch.Tensor,
    output_slices: Tuple[int, ...],
):
    """Applies lora to each input.

    This method applies all loras to each input. It uses the
    indices vector to determine which lora yields the
    correct output. An index of -1 means no lora should be
    applied. This method adds the final lora results to the
    output.

    This method is used for layers that are composed of multiple sublayers
    (slices) packed together.

    Input shapes:
        x:              (batch_size, hidden_dim)
        lora_a_stacked: 3 element tuple of (num_loras, lora_rank, hidden_dim)
        lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank)
        indices:        (batch_size)
        output:         (batch_size, q_slice_size + 2*kv_slice_size)
        output_slices:  n element tuple of (slice_size...), where n is the
                        number of slices
    """
    org_output = output
    x = x.view(-1, x.shape[-1])
    output = output.view(-1, output.shape[-1])
    indices = indices.view(-1)
    offset_left = 0
    for slice_idx in range(len(output_slices)):
        add_lora_slice(output, x, lora_a_stacked[slice_idx],
                       lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left,
                       output_slices[slice_idx])
        offset_left += output_slices[slice_idx]
    return output.view_as(org_output)


@dataclass
class LoRAMapping:
    # Per every token in input_ids:
    index_mapping: Tuple[int, ...]
    # Per sampled token:
    prompt_mapping: Tuple[int, ...]

    def __post_init__(self):
        self.index_mapping = tuple(self.index_mapping)
        self.prompt_mapping = tuple(self.prompt_mapping)


class BaseLayerWithLoRA(nn.Module):

    def create_lora_weights(self, max_loras: int, lora_config: LoRAConfig,
                            model_config: PretrainedConfig) -> None:
        """Initializes lora matrices."""
        ...

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Overwrites lora tensors at index."""
        ...

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        """Sets the mapping indices."""
        ...


class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:

        lora_vocab_start_idx = self.base_layer.org_vocab_size
        weights_idx = None
        if self.base_layer.vocab_end_index > lora_vocab_start_idx:
            # We can start adding lora weights
            weights_idx = max(
                lora_vocab_start_idx - self.base_layer.vocab_start_index, 0)
            self.embeddings_slice = (self.base_layer.vocab_start_index -
                                     self.base_layer.org_vocab_size +
                                     weights_idx,
                                     self.base_layer.vocab_end_index -
                                     self.base_layer.org_vocab_size)
            self.embeddings_weights = self.base_layer.weight.data[weights_idx:]
            self.embeddings_weights.fill_(0)
        else:
            self.embeddings_slice = None
            self.embeddings_weights = None

        self.embeddings_tensors = torch.zeros(
            (
                max_loras,
                lora_config.lora_extra_vocab_size,
                self.base_layer.embedding_dim,
            ),
            dtype=self.base_layer.weight.dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                self.base_layer.org_vocab_size +
                lora_config.lora_extra_vocab_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.embedding_dim,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked_2d = self.lora_a_stacked.view(
            self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
            self.lora_a_stacked.shape[2],
        )
        self.indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None
        self.embeddings_indices = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
            lora_a, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1]].copy_(embeddings_tensor, non_blocking=True)
            if self.embeddings_slice is not None:
                # TODO(yard1): Optimize this copy, we don't need to copy
                # everything, just the modified part
                embeddings = self.embeddings_tensors.view(
                    self.embeddings_tensors.shape[0] *
                    self.embeddings_tensors.shape[1],
                    self.embeddings_tensors.shape[2]
                )[self.embeddings_slice[0]:self.embeddings_slice[1]]
                self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.embeddings_indices = embeddings_indices
        self.indices_len = indices_len

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        added_tokens_mask = x > self.base_layer.org_vocab_size - 1
        indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x)
        full_lora_a_embeddings = F.embedding(
            x + indices,
            self.lora_a_stacked_2d,
        )
        indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x)
        full_output = self.base_layer.forward(
            x.add_(indices * added_tokens_mask))

        full_output_org = full_output
        if full_output.ndim == 3:
            full_output = full_output.view(
                full_output.shape[0] * full_output.shape[1], -1)
        if full_lora_a_embeddings.ndim == 3:
            full_lora_a_embeddings = full_lora_a_embeddings.view(
                full_lora_a_embeddings.shape[0] *
                full_lora_a_embeddings.shape[1], -1)
        bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        return full_output.view_as(full_output_org)


class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.lora_a_stacked = torch.zeros(
            max_loras,
            1,
            lora_config.max_lora_rank,
            self.base_layer.weight.shape[1],
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            max_loras,
            1,
            self.base_layer.weight.shape[0],
            lora_config.max_lora_rank,
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )

        self.indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None
        # Output dimension of this rank's shard (dim 2 of the stacked lora_b).
        self.output_dim = self.lora_b_stacked.shape[2]

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.indices_len = indices_len

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer.linear_weights, x, bias)
        _apply_lora(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
        )
        return output

    def forward(self, input_):
        """Forward of ColumnParallelLinear

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)

        # Matrix multiply.
        output_parallel = self.apply_weights(input_, bias)
        if self.base_layer.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @property
    def linear_weights(self):
        return self.base_layer.linear_weights


class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
    packed together (e.g. gate_proj + up_proj -> gate_up_proj).

    This means we have 2 LoRAs, each applied to one half of the layer.

    Both slices must have the same size.
    """

    def __init__(self, base_layer: MergedColumnParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        n_slices = 2
        if not (len(self.base_layer.output_sizes) == n_slices
                and self.base_layer.output_sizes[0]
                == self.base_layer.output_sizes[1]):
            raise ValueError(
                "LoRAColumnParallelLinear2Slice requires 2 slices with "
                "the same size.")
        self.tp_size = get_tensor_model_parallel_world_size()

        device = _get_lora_device(self.base_layer)
        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.input_size,
                dtype=lora_config.lora_dtype,
                device=device,
            ) for _ in range(n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                self.base_layer.output_size // 2,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=device,
            ) for _ in range(n_slices))

        self.indices: Optional[torch.Tensor] = None
        self.output_dim = self.lora_b_stacked[0].shape[2]

    def reset_lora(self, index: int):
        self.lora_a_stacked[0][index] = 0
        self.lora_a_stacked[1][index] = 0
        self.lora_b_stacked[0][index] = 0
        self.lora_b_stacked[1][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)

        if self.tp_size > 1:
            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
            shard_size = self.output_dim
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            lora_b = (lora_b[0][:, start_idx:end_idx],
                      lora_b[1][:, start_idx:end_idx])

        if lora_a[0] is not None:
            self.lora_a_stacked[0][
                index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
                    lora_a[0].T, non_blocking=True)
            self.lora_b_stacked[0][
                index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
                    lora_b[0].T, non_blocking=True)
        if lora_a[1] is not None:
            self.lora_a_stacked[1][
                index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
                    lora_a[1].T, non_blocking=True)
            self.lora_b_stacked[1][
                index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
                    lora_b[1].T, non_blocking=True)

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer.linear_weights, x, bias)
        _apply_lora_packed_nslice(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
            (self.output_dim, self.output_dim),
        )
        return output
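

# A minimal, unvectorized reference (not used by the engine) of what the
# packed-slice path is assumed to compute: every slice has its own
# (lora_a, lora_b) pair and its contribution is added into `output` at a
# running column offset. Shapes follow the `_apply_lora_packed_nslice`
# docstring; `scale` mirrors the hard-coded 1.0 passed to `add_lora_slice`.
def _reference_apply_lora_packed_nslice(
    x: torch.Tensor,
    lora_a_stacked: Tuple[torch.Tensor, ...],
    lora_b_stacked: Tuple[torch.Tensor, ...],
    indices: torch.Tensor,
    output: torch.Tensor,
    output_slices: Tuple[int, ...],
    scale: float = 1.0,
) -> torch.Tensor:
    offset = 0
    for lora_a, lora_b, slice_size in zip(lora_a_stacked, lora_b_stacked,
                                          output_slices):
        for i, lora_idx in enumerate(indices.tolist()):
            if lora_idx < 0:
                continue  # -1 means "no LoRA" for this token.
            shrink = lora_a[lora_idx] @ x[i]
            output[i, offset:offset + slice_size] += scale * (
                lora_b[lora_idx] @ shrink)
        offset += slice_size
    return output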


class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 3 sublayers (slices)
    packed together in qkv proj fashion
    (q_proj + k_proj + v_proj -> qkv_proj).

    This means we have 3 LoRAs, each applied to one slice of the layer.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.q_shard_id = tp_rank
        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas

        device = _get_lora_device(self.base_layer)
        # q, k, v
        self.lora_a_stacked = (
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.input_size,
                dtype=lora_config.lora_dtype,
                device=device,
            ),
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.input_size,
                dtype=lora_config.lora_dtype,
                device=device,
            ),
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.input_size,
                dtype=lora_config.lora_dtype,
                device=device,
            ),
        )
        self.lora_b_stacked = (
            torch.zeros(
                max_loras,
                1,
                self.q_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=device,
            ),
            torch.zeros(
                max_loras,
                1,
                self.kv_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=device,
            ),
            torch.zeros(
                max_loras,
                1,
                self.kv_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=device,
            ),
        )

        self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size,
                              self.kv_proj_shard_size)
        self.packed_indices: Optional[torch.Tensor] = None
        self.standard_indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[0][index] = 0
        self.lora_b_stacked[0][index] = 0
        self.lora_a_stacked[1][index] = 0
        self.lora_b_stacked[1][index] = 0
        self.lora_a_stacked[2][index] = 0
        self.lora_b_stacked[2][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)

        if self.tp_size > 1:
            if lora_b[0] is not None:
                lora_b_q = lora_b[0][:, self.q_proj_shard_size *
                                     self.q_shard_id:self.q_proj_shard_size *
                                     (self.q_shard_id + 1)]
                self.lora_b_stacked[0][
                    index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
                        lora_b_q.T, non_blocking=True)
            if lora_b[1] is not None:
                lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
                                     self.kv_shard_id:self.kv_proj_shard_size *
                                     (self.kv_shard_id + 1)]
                self.lora_b_stacked[1][
                    index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
                        lora_b_k.T, non_blocking=True)
            if lora_b[2] is not None:
                lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
                                     self.kv_shard_id:self.kv_proj_shard_size *
                                     (self.kv_shard_id + 1)]
                self.lora_b_stacked[2][
                    index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
                        lora_b_v.T, non_blocking=True)
        else:
            if lora_b[0] is not None:
                self.lora_b_stacked[0][
                    index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
                        lora_b[0].T, non_blocking=True)
            if lora_b[1] is not None:
                self.lora_b_stacked[1][
                    index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
                        lora_b[1].T, non_blocking=True)
            if lora_b[2] is not None:
                self.lora_b_stacked[2][
                    index, 0, :lora_b[2].shape[1], :lora_b[2].shape[0]].copy_(
                        lora_b[2].T, non_blocking=True)

        if lora_a[0] is not None:
            self.lora_a_stacked[0][
                index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
                    lora_a[0].T, non_blocking=True)
        if lora_a[1] is not None:
            self.lora_a_stacked[1][
                index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
                    lora_a[1].T, non_blocking=True)
        if lora_a[2] is not None:
            self.lora_a_stacked[2][
                index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
                    lora_a[2].T, non_blocking=True)

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer.linear_weights, x, bias)
        _apply_lora_packed_nslice(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
            self.output_slices,
        )
        return output
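

# Worked example (illustrative, not used by the engine): how the QKV wrapper
# above picks its lora_b columns under tensor parallelism. The default sizes
# below are hypothetical; the real values come from the wrapped
# QKVParallelLinear, whose num_heads / num_kv_heads are already per-rank
# counts.
def _example_qkv_lora_b_slices(tp_rank: int,
                               num_kv_head_replicas: int = 1,
                               q_proj_shard_size: int = 1024,
                               kv_proj_shard_size: int = 256):
    """Return the (start, end) column ranges of the full-width lora_b tensors
    that rank `tp_rank` would copy in QKVParallelLinearWithLora.set_lora."""
    q_shard_id = tp_rank
    kv_shard_id = tp_rank // num_kv_head_replicas
    q_cols = (q_proj_shard_size * q_shard_id,
              q_proj_shard_size * (q_shard_id + 1))
    kv_cols = (kv_proj_shard_size * kv_shard_id,
               kv_proj_shard_size * (kv_shard_id + 1))
    return q_cols, kv_cols, kv_cols  # (q, k, v) column ranges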


class RowParallelLinearWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: RowParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        device = _get_lora_device(self.base_layer)
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.input_size,
            ),
            dtype=lora_config.lora_dtype,
            device=device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.output_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=device,
        )

        self.indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.base_layer.tp_size > 1:
            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
            shard_size = self.base_layer.weight.shape[1]
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            lora_a = lora_a[start_idx:end_idx, :]

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.indices_len = indices_len

    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer.linear_weights, x)
        _apply_lora(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
        )
        return output

    def forward(self, input_):
        """Forward of RowParallelLinear

        Args:
            input_: tensor whose last dimension is `input_size`. If
                    `input_is_parallel` is set, then the last dimension
                    is `input_size // tp_size`.

        Returns:
            - output
            - bias
        """
        # Set up backprop all-reduce.
        if self.base_layer.input_is_parallel:
            input_parallel = input_
        else:
            # TODO: simplify code below
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.base_layer.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        output_parallel = self.apply_weights(input_parallel)
        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.base_layer.skip_bias_add:
            output = (output_ + self.base_layer.bias
                      if self.base_layer.bias is not None else output_)
            output_bias = None
        else:
            output = output_
            output_bias = self.base_layer.bias
        return output, output_bias

    @property
    def weight(self):
        return self.base_layer.weight
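

# Note on the row-parallel case above: the input x is split along its last
# dimension across ranks, so lora_a is sharded by the matching input rows
# (shard_size = weight.shape[1]) while lora_b is kept whole. Each rank
# computes (x_r @ A_r.T) @ B.T locally, and since x @ A.T == sum_r(x_r @ A_r.T),
# the partial LoRA contributions are merged by the same
# tensor_model_parallel_all_reduce that already combines the base layer's
# partial outputs; no extra communication is needed for LoRA.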


class SamplerWithLoRA(BaseLayerWithLoRA):

    def __init__(
        self,
        base_layer: Sampler,
        hidden_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device

    @property
    def vocab_size(self):
        return self.base_layer.vocab_size

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        return self.base_layer.include_gpu_probs_tensor

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
        if not (32000 <= self.base_layer.vocab_size <= 33024):
            raise ValueError("When using LoRA, vocab size must satisfy "
                             "32000 <= vocab_size <= 33024")
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                # Pad for kernel compatibility
                math.ceil(self.base_layer.vocab_size /
                          lora_config.lora_vocab_padding_size) *
                lora_config.lora_vocab_padding_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.embeddings_tensors = torch.full(
            (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
            fill_value=float("-inf"),
            dtype=self.dtype,
            device=self.device,
        )
        self.indices = None
        self.indices_padded = None
        self.indices_len = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = float("-inf")

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index,
                :embeddings_tensor.shape[0],
                :embeddings_tensor.shape[1],
            ] = embeddings_tensor

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = sampler_indices
        self.indices_padded = sampler_indices_padded
        self.indices_len = indices_len

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        embedding: torch.Tensor,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Get the logits for the next tokens.
        logits = torch.matmul(hidden_states, embedding.t())
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        if logits is None:
            return None

        lora_logits = torch.empty(
            self.embeddings_tensors.shape[0] + 1,
            self.embeddings_tensors.shape[1],
            hidden_states.shape[0],
            dtype=self.embeddings_tensors.dtype,
            device=self.embeddings_tensors.device,
        )
        torch.matmul(self.embeddings_tensors,
                     hidden_states.T,
                     out=lora_logits[:-1])
        lora_logits[-1] = float("-inf")
        lora_logits = lora_logits.mT
        lora_logits = (lora_logits.reshape(
            lora_logits.shape[0] * lora_logits.shape[1],
            lora_logits.shape[2],
        ).index_select(0,
                       self.indices_padded[:self.indices_len[2]]).nan_to_num_(
                           nan=float("-inf"),
                           posinf=float("inf"),
                           neginf=float("-inf")))
        logits[:,
               self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
               lora_logits.shape[1]] = lora_logits

        _apply_lora(
            hidden_states,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[1]],
            logits,
        )

        # Remove paddings in vocab (if any).
        logits = logits[:, :self.base_layer.vocab_size]

        return logits

    def forward(self, *args, **kwargs):
        return type(self.base_layer).forward(self, *args, **kwargs)


def from_layer(
        layer: nn.Module,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None) -> BaseLayerWithLoRA:
    supported_layer_types = {
        VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
        ColumnParallelLinear: ColumnParallelLinearWithLoRA,
        QKVParallelLinear: QKVParallelLinearWithLora,
        MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
        RowParallelLinear: RowParallelLinearWithLoRA,
    }
    for src_layer_type, lora_layer_type in supported_layer_types.items():
        if type(layer) is src_layer_type:  # pylint: disable=unidiomatic-typecheck
            ret = lora_layer_type(layer)
            ret.create_lora_weights(max_loras, lora_config, model_config)
            return ret
    return layer


def from_layer_sampler(
    layer: Sampler,
    lm_head: ParallelLMHead,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> SamplerWithLoRA:
    ret = SamplerWithLoRA(layer, lm_head.embedding_dim, lm_head.weight.dtype,
                          lm_head.weight.device)
    ret.create_lora_weights(max_loras, lora_config, model_config)
    return ret
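

# Usage sketch (illustrative): the LoRA manager normally drives these
# wrappers, but the expected call order is roughly the following.
# `base_linear`, `lora_config`, the rank-8 tensors and the index tensors
# below are hypothetical placeholders.
#
#   lora_layer = from_layer(base_linear, max_loras=4, lora_config=lora_config)
#   lora_layer.set_lora(
#       index=0,
#       lora_a=torch.zeros(input_size, 8),   # (input_dim, rank)
#       lora_b=torch.zeros(8, output_size),  # (rank, output_dim)
#       embeddings_tensor=None,
#   )
#   lora_layer.set_mapping(base_indices, sampler_indices,
#                          sampler_indices_padded, embeddings_indices,
#                          indices_len)
#   output, bias = lora_layer(hidden_states)  # LoRA applied per token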