# pylint: disable=unused-argument
from typing import TYPE_CHECKING, List, Optional, Union

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from aphrodite.common.config import LoRAConfig
from aphrodite.distributed.communication_op import (
    tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
from aphrodite.distributed.parallel_state import get_tensor_model_parallel_rank
from aphrodite.lora.layers import (ColumnParallelLinearWithLoRA,
                                   MergedColumnParallelLinearWithLoRA,
                                   MergedQKVParallelLinearWithLora,
                                   QKVParallelLinearWithLora,
                                   RowParallelLinearWithLoRA)

if TYPE_CHECKING:
    pass

def _fully_sharded_can_replace(can_replace):
    """
    Decorator that adds the fully sharded LoRA condition
    (``lora_config.fully_sharded_loras``) to a layer's replacement check.
    Intended to wrap can_replace_layer().
    """

    def dec(*args, **kwargs):
        return (can_replace(*args, **kwargs)
                and kwargs["lora_config"].fully_sharded_loras)

    return dec

# These layers are based on the tensor parallelism strategy given in
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters.
# 2023, https://arxiv.org/abs/2311.03285.
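# Rough sketch of the fully sharded flow the column-parallel classes below
# implement (pseudocode only; A_shard / B_shard stand for this rank's slices
# of lora_a_stacked / lora_b_stacked, they are not attributes of the layers):
#
#   shrink:  partial = x @ A_shard.T          # A sharded along the rank dim
#   gather:  full    = all_gather(partial)    # recover the full LoRA rank
#   expand:  out    += full @ B_shard.T       # B keeps its output-dim shard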
class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
    """
    Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = self.lora_a_stacked.shape[2]
        start_idx = tp_rank * shard_size
        lora_a = lora_a[:, start_idx:start_idx + shard_size]
        return lora_a

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        x = x.view(-1, x.shape[-1])
        output, out_orig_shape = output.view(-1,
                                             output.shape[-1]), output.shape
        buffer = torch.zeros(
            (x.shape[0], self.lora_a_stacked.shape[2]),
            dtype=torch.float32,
            device=x.device,
        )
        # shrink with the rank-sharded LoRA A, then all-gather so every rank
        # holds the full low-rank activation before expanding with LoRA B
        self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
        buffer = tensor_model_parallel_all_gather(buffer)
        self.punica_wrapper.add_expand(output,
                                       buffer,
                                       self.lora_b_stacked,
                                       add_input=True)
        # now have column partitioned output
        output = output.view(*out_orig_shape)
        return output

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )

def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora):
    """
    MergedColumnParallelLinearWithShardedLoRA and
    MergedQKVParallelLinearWithShardedLora share the same
    LoRA weight application method.

    The main difference is the per-slice step (shard_size) used for lora_b,
    which can vary between slices for MergedQKVParallelLinearWithShardedLora
    but is constant for MergedColumnParallelLinearWithShardedLoRA.
    """
    # expecting 2 for column parallel and 3 for qkv
    n = len(layer.lora_a_stacked)
    output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)

    x = x.view(-1, x.shape[-1])
    output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
    buffers = torch.zeros(
        (n, x.shape[0], layer.lora_a_stacked[0].shape[2]),
        dtype=torch.float32,
        device=x.device,
    )
    # shrink each packed slice with its rank-sharded LoRA A
    for idx in range(n):
        layer.punica_wrapper.add_shrink(buffers[idx], x,
                                        layer.lora_a_stacked[idx], 1.0)

    # recover the full LoRA rank on every tensor parallel rank
    buffers = tensor_model_parallel_all_gather(buffers)
    # expand each slice into its region of the packed output
    left_offset = 0
    for idx in range(n):
        shard_size = layer.lora_b_stacked[idx].shape[2]
        layer.punica_wrapper.add_expand_slice(
            output,
            buffers[idx],
            layer.lora_b_stacked[idx],
            left_offset,
            shard_size,
            add_input=True,
        )
        left_offset += shard_size

    output = output.view(*out_orig_shape)
    # now have column partitioned and packed output
    return output

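# Illustrative example of the packed expand offsets in _mcp_apply (the
# numbers are made up): with per-rank lora_b output shards of sizes
# [1024, 256, 256] for a merged QKV layer, slice 0 is written at columns
# [0:1024), slice 1 at [1024:1280) and slice 2 at [1280:1536) of the packed
# output; for a merged column-parallel (gate/up) layer both shard sizes are
# equal, so the step between slices is constant.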
class MergedColumnParallelLinearWithShardedLoRA(
        MergedColumnParallelLinearWithLoRA):
    """
    Differs from MergedColumnParallelLinearWithLoRA by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        if lora_a[0] is None or lora_a[1] is None:
            return lora_a
        output_shard_size = self.lora_a_stacked[0].shape[2]
        output_start_idx = self.tp_rank * output_shard_size
        lora_a = [
            lora_a[0][:,
                      output_start_idx:output_start_idx + output_shard_size],
            lora_a[1][:,
                      output_start_idx:output_start_idx + output_shard_size],
        ]
        return lora_a

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )

class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
    """
    Differs from QKVParallelLinearWithLora by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = self.lora_a_stacked.shape[2]
        start_idx = tp_rank * shard_size
        lora_a = lora_a[:, start_idx:start_idx + shard_size]
        return lora_a

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        x = x.view(-1, x.shape[-1])
        output, out_orig_shape = output.view(-1,
                                             output.shape[-1]), output.shape
        buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
                             dtype=torch.float32,
                             device=x.device)
        self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
        buffer = tensor_model_parallel_all_gather(buffer)
        self.punica_wrapper.add_expand(output,
                                       buffer,
                                       self.lora_b_stacked,
                                       add_input=True)
        # now have column partitioned output
        output = output.view(*out_orig_shape)
        return output

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )

class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
    """
    Differs from MergedQKVParallelLinearWithLora by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        if lora_a[0] is None or lora_a[1] is None or lora_a[2] is None:
            return lora_a
        shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
        start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
        lora_a = [
            lora_a[0][:, start_idx[0]:start_idx[0] + shard_size[0]],
            lora_a[1][:, start_idx[1]:start_idx[1] + shard_size[1]],
            lora_a[2][:, start_idx[2]:start_idx[2] + shard_size[2]],
        ]
        return lora_a

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )

class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
    """
    Differs from RowParallelLinearWithLoRA by slicing the
    LoRA B's also.

    Based on S-LoRA, slicing happens along the output dim.
    This yields a combined partial sum from the row parallel base
    layer and column partitioned output from the LoRA.
    """

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        shard_size = self.lora_b_stacked.shape[2]
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        lora_b = lora_b[:, start_idx:end_idx]
        return lora_b

    def apply(self, x: torch.Tensor) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x)

        x = x.view(-1, x.shape[-1])
        output, out_orig_shape = output.view(-1,
                                             output.shape[-1]), output.shape
        buffer = torch.zeros(
            (x.shape[0], self.lora_a_stacked.shape[2]),
            dtype=torch.float32,
            device=x.device,
        )
        self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
        buffer = tensor_model_parallel_all_reduce(buffer)

        # Following S-LoRA, this fuses the all_gather and all_reduce by
        # adding the column partitioned LoRA output to a slice of the output
        # tensor, which is a partial sum due to row parallelism. All that
        # remains is a standard all_reduce. Note that the output here is
        # therefore not the same as that of a normal row parallel layer:
        # it still has to be reduced before being used.
        shard_size = self.lora_b_stacked.shape[2]
        start_idx = self.tp_rank * shard_size
        self.punica_wrapper.add_expand_slice(output, buffer,
                                             self.lora_b_stacked, start_idx,
                                             shard_size)
        output = output.view(*out_orig_shape)
        return output
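    # Sketch of the fused reduction performed in apply() above (illustrative
    # only, names are shorthand): each tensor parallel rank r returns
    #     partial_r = x_r @ W_r.T, with its LoRA delta added only into the
    #                 column slice [r * shard_size : (r + 1) * shard_size)
    # The caller's single all_reduce then sums the base partials into the
    # full output, while each LoRA column slice is contributed by exactly
    # one rank, so no separate all_gather of the LoRA output is needed.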
    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )