123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391 |
- from typing import Callable, List, Optional, Tuple, Union
- import torch
- from torch.nn import Parameter
- from aphrodite.distributed import get_tensor_model_parallel_rank
# Public parameter classes exported by ``from ... import *``.
__all__ = [
    "BaseAphroditeParameter",
    "PackedAphroditeParameter",
    "PerTensorScaleParameter",
    "ModelWeightParameter",
    "ChannelQuantScaleParameter",
    "GroupQuantScaleParameter",
    "PackedColumnParameter",
    "RowAphroditeParameter",
]
class BaseAphroditeParameter(Parameter):
    """
    Base parameter for Aphrodite linear layers. Extends torch.nn.Parameter
    with a linear weight loader: the loaded weight is copied into the
    parameter when one of the ``load_*`` methods is called.

    Instances are always created with ``requires_grad=False``. In this base
    class every ``load_*`` method loads the incoming tensor as-is (no
    sharding); subclasses override them to slice per tensor-parallel rank.
    """

    def __new__(cls, data: torch.Tensor, **kwargs):
        # Subclass-specific keyword arguments (weight_loader, output_dim, ...)
        # are consumed by __init__; Parameter.__new__ only needs the data.
        return super().__new__(cls, data=data, requires_grad=False)

    def __init__(self, data: torch.Tensor, weight_loader: Callable):
        """
        Initialize the BaseAphroditeParameter.

        :param data: torch tensor with the parameter data (consumed by
            ``__new__``)
        :param weight_loader: weight loader callable, exposed through the
            ``weight_loader`` property
        """
        self._weight_loader = weight_loader

    @property
    def weight_loader(self):
        return self._weight_loader

    def _assert_and_load(self, loaded_weight: torch.Tensor):
        """
        Copy ``loaded_weight`` into the parameter after checking shapes.

        :raises AssertionError: if the shapes differ.
        """
        # Checked explicitly instead of with ``assert`` so the guard also runs
        # under ``python -O``: Tensor.copy_ broadcasts, so a mismatched but
        # broadcastable checkpoint tensor would otherwise be copied silently.
        # AssertionError is kept so existing callers see the same type.
        if self.data.shape != loaded_weight.shape:
            raise AssertionError(
                f"Shape mismatch when loading {type(self).__name__}: "
                f"parameter has shape {tuple(self.data.shape)}, "
                f"loaded weight has shape {tuple(loaded_weight.shape)}")
        self.data.copy_(loaded_weight)

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
        self._assert_and_load(loaded_weight)

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        self._assert_and_load(loaded_weight)

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        # Shard kwargs are ignored: the base parameter is not sharded.
        self._assert_and_load(loaded_weight)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        # Shard kwargs are ignored: the base parameter is not sharded.
        self._assert_and_load(loaded_weight)
class _ColumnAphroditeParameter(BaseAphroditeParameter):
    """
    Private base implementing column-parallel weight loading
    (``load_column_parallel_weight``, ``load_merged_column_weight`` and
    ``load_qkv_weight``). It serves QKV and MLP layers whose sub-matrices are
    not already fused on disk and is invoked from the weight loaders of the
    column-parallel linear layers. ``output_dim`` is the tensor dimension
    that is split across tensor-parallel ranks.
    """

    def __init__(self, output_dim: int, **kwargs):
        self._output_dim = output_dim
        super().__init__(**kwargs)

    @property
    def output_dim(self):
        return self._output_dim

    def _packed_along_output_dim(self) -> bool:
        """
        True for packed parameters whose ``packed_dim`` equals
        ``output_dim``; their logical shard size/offset must then be rescaled
        with ``adjust_shard_indexes_for_packing`` before slicing.
        """
        packed_types = (PackedColumnParameter, PackedAphroditeParameter)
        return (isinstance(self, packed_types)
                and self.packed_dim == self.output_dim)

    def _copy_into_shard(self, loaded_weight: torch.Tensor, dest_start,
                         length, src_start) -> None:
        """
        Along ``output_dim``, copy ``length`` elements of ``loaded_weight``
        beginning at ``src_start`` into this parameter beginning at
        ``dest_start``.
        """
        axis = self.output_dim
        destination = self.data.narrow(axis, dest_start, length)
        source = loaded_weight.narrow(axis, src_start, length)
        assert destination.shape == source.shape
        destination.copy_(source)

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
        # Each rank owns one contiguous, equally sized block of the
        # checkpoint tensor; the block length is this parameter's extent.
        rank = get_tensor_model_parallel_rank()
        width = self.data.shape[self.output_dim]
        self._copy_into_shard(loaded_weight, 0, width, rank * width)

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        offset = kwargs.get("shard_offset")
        size = kwargs.get("shard_size")
        if self._packed_along_output_dim():
            size, offset = self.adjust_shard_indexes_for_packing(
                shard_offset=offset, shard_size=size)
        rank = get_tensor_model_parallel_rank()
        self._copy_into_shard(loaded_weight, offset, size, rank * size)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        offset = kwargs.get("shard_offset")
        size = kwargs.get("shard_size")
        if self._packed_along_output_dim():
            size, offset = self.adjust_shard_indexes_for_packing(
                shard_offset=offset, shard_size=size)
        rank = get_tensor_model_parallel_rank()
        # Q reads block ``rank`` of the checkpoint tensor; for K/V,
        # ``num_heads`` consecutive ranks read the same block (presumably the
        # KV-head replication factor -- confirm against the caller).
        if kwargs.get("shard_id") == "q":
            source_block = rank
        else:
            source_block = rank // kwargs.get("num_heads")
        self._copy_into_shard(loaded_weight, offset, size,
                              source_block * size)
class RowAphroditeParameter(BaseAphroditeParameter):
    """
    Parameter supporting row-parallel loading (``load_row_parallel_weight``):
    each tensor-parallel rank copies its own contiguous block of the
    checkpoint tensor along ``input_dim``, which must be supplied.
    """

    def __init__(self, input_dim: int, **kwargs):
        self._input_dim = input_dim
        super().__init__(**kwargs)

    @property
    def input_dim(self):
        return self._input_dim

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        rank = get_tensor_model_parallel_rank()
        axis = self.input_dim
        width = self.data.shape[axis]
        local_block = loaded_weight.narrow(axis, rank * width, width)
        # NOTE(review): narrow() keeps the number of dimensions and rejects
        # 0-dim inputs, so this scalar fallback looks unreachable; kept as-is.
        if local_block.dim() == 0:
            local_block = local_block.reshape(1)
        assert self.data.shape == local_block.shape
        self.data.copy_(local_block)
class ModelWeightParameter(_ColumnAphroditeParameter, RowAphroditeParameter):
    """
    Linear-layer weight loadable with either column or row parallelism.
    Through the cooperative ``__init__`` chain it takes both ``output_dim``
    and ``input_dim``.
    """
class GroupQuantScaleParameter(_ColumnAphroditeParameter,
                               RowAphroditeParameter):
    """
    Scales for group-quantized weights. Supports both column- and
    row-parallel loading, so it takes both ``output_dim`` and ``input_dim``.
    """
class ChannelQuantScaleParameter(_ColumnAphroditeParameter):
    """
    Scales for channel-wise quantized weights. Adds nothing to
    _ColumnAphroditeParameter: column-parallel loading along ``output_dim``.
    """
class PerTensorScaleParameter(BaseAphroditeParameter):
    """
    Parameter class for scales where the number of scales equals the number
    of logical matrices in a fused linear layer (e.g. for QKV, three scales
    are loaded from disk). Relevant to weights with per-tensor quantization.
    During weight loading each incoming scale is written into the slot
    selected by its shard id.

    Note: additional parameter manipulation may be handled for each
    quantization config specifically, within process_weights_after_loading.
    """

    def __init__(self, **kwargs):
        # Slot index of each QKV sub-matrix within this scale tensor.
        self.qkv_idxs = {"q": 0, "k": 1, "v": 2}
        super().__init__(**kwargs)

    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
        """Map a shard id (an int, or "q"/"k"/"v") to its integer slot."""
        if isinstance(shard_id, int):
            return shard_id
        # Non-int ids are assumed to name a QKV sub-matrix.
        assert isinstance(shard_id, str)
        assert shard_id in self.qkv_idxs
        return self.qkv_idxs[shard_id]

    def load_merged_column_weight(self, *args, **kwargs):
        self._load_into_shard_id(*args, **kwargs)

    def load_qkv_weight(self, *args, **kwargs):
        self._load_into_shard_id(*args, **kwargs)

    def load_column_parallel_weight(self, *args, **kwargs):
        """
        Load a scale for a column-parallel layer.

        With a shard id (keyword ``shard_id`` or second positional argument)
        the scale is written into that slot. Without one there is no slot to
        select, so the tensor is loaded as-is via the base implementation
        (as the inherited row-parallel path does) instead of failing with
        ``TypeError`` for the missing ``shard_id``.
        """
        if "shard_id" in kwargs or len(args) >= 2:
            self._load_into_shard_id(*args, **kwargs)
        else:
            super().load_column_parallel_weight(*args, **kwargs)

    def _load_into_shard_id(self, loaded_weight: torch.Tensor,
                            shard_id: Union[str, int], **kwargs):
        """
        Slice the parameter data at ``shard_id`` and copy ``loaded_weight``
        into that slot. A loaded tensor with a leading dimension of size 1
        is unwrapped first.
        """
        param_data = self.data
        shard_id = self._shard_id_as_int(shard_id)

        # AutoFP8 scales do not have a shape
        # compressed-tensors scales do have a shape
        if len(loaded_weight.shape) != 0:
            assert loaded_weight.shape[0] == 1
            loaded_weight = loaded_weight[0]

        param_data = param_data[shard_id]
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
class PackedColumnParameter(_ColumnAphroditeParameter):
    """
    Parameter whose values are packed on disk and which supports column
    parallelism only. Packing is described by ``packed_factor`` (shard sizes
    and offsets are divided by it), ``packed_dim`` (the packed dimension) and
    an optional ``marlin_tile_size``. PackedAphroditeParameter documents the
    same packed properties.
    """

    def __init__(self,
                 packed_factor: int,
                 packed_dim: int,
                 marlin_tile_size: Optional[int] = None,
                 **kwargs):
        self._packed_dim = packed_dim
        self._packed_factor = packed_factor
        self._marlin_tile_size = marlin_tile_size
        super().__init__(**kwargs)

    @property
    def packed_factor(self):
        return self._packed_factor

    @property
    def packed_dim(self):
        return self._packed_dim

    @property
    def marlin_tile_size(self):
        return self._marlin_tile_size

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        """
        Rescale a logical shard: divide by ``packed_factor``, then multiply
        by ``marlin_tile_size`` when it is set.
        Returns ``(shard_size, shard_offset)``.
        """
        return _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=self.packed_factor,
            marlin_tile_size=self.marlin_tile_size,
        )
class PackedAphroditeParameter(ModelWeightParameter):
    """
    Model weight that is packed on disk.
    Example: GPTQ Marlin weights are int4 or int8 values packed into int32.

    Extends ModelWeightParameter with the packed factor, the packed
    dimension and, optionally, a Marlin tile size for Marlin kernels. When
    fused linear layers load this weight, shard_size and shard_offset are
    adjusted for the packing and, if set, the Marlin tile size.
    """

    def __init__(self,
                 packed_factor: int,
                 packed_dim: int,
                 marlin_tile_size: Optional[int] = None,
                 **kwargs):
        self._packed_dim = packed_dim
        self._packed_factor = packed_factor
        self._marlin_tile_size = marlin_tile_size
        super().__init__(**kwargs)

    @property
    def packed_factor(self):
        return self._packed_factor

    @property
    def packed_dim(self):
        return self._packed_dim

    @property
    def marlin_tile_size(self):
        return self._marlin_tile_size

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        """
        Return ``(shard_size, shard_offset)`` divided by ``packed_factor``
        and, when ``marlin_tile_size`` is set, multiplied by it.
        """
        return _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=self.packed_factor,
            marlin_tile_size=self.marlin_tile_size,
        )
- def _adjust_shard_indexes_for_marlin(shard_size, shard_offset,
- marlin_tile_size):
- return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
- def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
- marlin_tile_size):
- shard_size = shard_size // packed_factor
- shard_offset = shard_offset // packed_factor
- if marlin_tile_size is not None:
- return _adjust_shard_indexes_for_marlin(
- shard_size=shard_size,
- shard_offset=shard_offset,
- marlin_tile_size=marlin_tile_size)
- return shard_size, shard_offset
- # Qweights in HQQ need to be reshaped such that the shape of the stored tensors
- # is the actual shape used in weight multiplication. This is needed to correctly
- # repack to Marlin. We also store shard size and offsets in order to be able to
- # correctly unpack (shard by shard) from 4-bit to 8-bit.
class HQQQweightParameter(PackedAphroditeParameter):
    """
    Packed HQQ quantized weight.

    Every incoming checkpoint tensor is reshaped to ``(-1, self.shape[1])``
    before the regular packed loading logic runs. Every load also appends an
    ``(offset, size)`` pair to ``self.shard_offsets``, recording where the
    shard landed along ``output_dim``.
    """

    def __init__(self, packed_factor: int, packed_dim: int, **kwargs):
        # No Marlin tile-size adjustment: marlin_tile_size is fixed to None.
        super().__init__(packed_factor, packed_dim, None, **kwargs)
        # (offset, size) of each loaded shard, in call order.
        self.shard_offsets: List[Tuple[int, int]] = []
        # Same value as ``packed_factor``, also exposed under this name.
        self.pack_factor = packed_factor

    def _hqq_record_shard(self, shard_kwargs) -> None:
        """Append the packing-adjusted (offset, size) of a fused shard."""
        size, offset = self.adjust_shard_indexes_for_packing(
            shard_offset=shard_kwargs.get("shard_offset"),
            shard_size=shard_kwargs.get("shard_size"))
        self.shard_offsets.append((offset, size))

    def _hqq_reshape(self, loaded_weight: torch.Tensor) -> torch.Tensor:
        """Reshape a checkpoint tensor to this parameter's column count."""
        return loaded_weight.reshape(-1, self.shape[1])

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        self._hqq_record_shard(kwargs)
        super().load_merged_column_weight(self._hqq_reshape(loaded_weight),
                                          **kwargs)

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        # A row-parallel load covers the whole output extent in one shard.
        self.shard_offsets.append((0, self.shape[self.output_dim]))
        super().load_row_parallel_weight(self._hqq_reshape(loaded_weight))

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        self._hqq_record_shard(kwargs)
        super().load_qkv_weight(self._hqq_reshape(loaded_weight), **kwargs)
- # Zero points and scales in HQQ must also be reshaped to their actual shapes.
class HQQZeroScaleParameter(GroupQuantScaleParameter):
    """
    HQQ zero points / scales. Each incoming checkpoint tensor is reshaped to
    ``(-1, self.shape[1])`` and then loaded by the regular group-quant logic.
    """

    def _hqq_reshape(self, loaded_weight: torch.Tensor) -> torch.Tensor:
        """Reshape a checkpoint tensor to this parameter's column count."""
        return loaded_weight.reshape(-1, self.shape[1])

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        super().load_merged_column_weight(self._hqq_reshape(loaded_weight),
                                          **kwargs)

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        super().load_row_parallel_weight(self._hqq_reshape(loaded_weight))

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        super().load_qkv_weight(self._hqq_reshape(loaded_weight), **kwargs)
|