123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373 |
- # Supports AQLM compression, see https://github.com/Vahe1994/AQLM
- # and https://arxiv.org/pdf/2401.06118.pdf
- import math
- from typing import Any, Dict, List, Optional
- import torch
- import torch.nn.functional as F
- from torch.nn.parameter import Parameter
- from aphrodite import _custom_ops as ops
- from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
- from aphrodite.modeling.utils import set_weight_attrs
- from aphrodite.quantization.base_config import QuantizationConfig
def get_int_dtype(nbits: int) -> torch.dtype:
    """Pick the narrowest signed torch integer dtype with >= ``nbits`` bits.

    Candidate widths are tried in ascending order (8, 16, 32, 64) and the
    first one that is at least ``nbits`` wide is returned.

    Args:
        nbits: number of bits used by one codebook index.

    Returns:
        One of ``torch.int8``, ``torch.int16``, ``torch.int32`` or
        ``torch.int64``.

    Raises:
        ValueError: if ``nbits`` is larger than 64.
    """
    width_to_dtype = (
        (8, torch.int8),
        (16, torch.int16),
        (32, torch.int32),
        (64, torch.int64),
    )
    for width, dtype in width_to_dtype:
        if nbits <= width:
            return dtype
    raise ValueError(f"No dtype available for {nbits}-bit codebooks")
@torch.inference_mode()
def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
    """Reinterpret stored integer codes as values modulo ``2**nbits``.

    The input is first widened to int64 and then reduced modulo
    ``2**nbits``, so negative entries of a narrow signed dtype (e.g. ``-1``
    in int8 with ``nbits=8``) come back as their unsigned counterpart
    (``255``).

    Args:
        data: integer tensor of stored codes.
        nbits: number of bits per code.

    Returns:
        An int64 tensor with the same shape as ``data``.
    """
    modulus = 2**nbits
    widened = data.to(torch.int64)
    return torch.remainder(widened, modulus)
def dequantize_weight(codes: torch.Tensor,
                      codebooks: torch.Tensor,
                      scales: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Rebuild a float weight matrix from additive quantization codes.
    Differentiable.

    Every (out_group, in_group) cell stores one index per codebook; the cell
    is reconstructed as the sum of the selected vectors, one from each
    codebook.

    :param codes: tensor of integer quantization codes, shape
        [*dims, num_out_groups, num_in_groups, num_codebooks]
    :param codebooks: tensor of vectors for each quantization code, shape
        [num_codebooks, codebook_size, out_group_size, in_group_size]
    :param scales: optional factor the grouped weight is multiplied by; must
        be broadcastable with
        [*dims, num_out_groups, num_in_groups, out_group_size, in_group_size]
    :return: reconstructed weight tensor of shape
        [*dims, num_out_groups * out_group_size,
         num_in_groups * in_group_size]
    """
    leading_dims = list(codes.shape[:-3])
    num_out_groups, num_in_groups, _ = codes.shape[-3:]
    (num_codebooks, codebook_size, out_group_size,
     in_group_size) = codebooks.shape

    # All codebooks are looked up through one flattened table, so the indices
    # of codebook k are shifted by k * codebook_size.
    offsets = torch.arange(0,
                           num_codebooks * codebook_size,
                           codebook_size,
                           device=codes.device)  # [num_codebooks]
    bag_indices = codes.flatten(0, -2) + offsets
    lookup_table = codebooks.flatten(0, 1).flatten(-2, -1)

    # One "bag" per (dims, out_group, in_group) cell; summing the bag adds the
    # contributions of all codebooks.
    # Result: [prod(dims) * num_out_groups * num_in_groups,
    #          out_group_size * in_group_size]
    summed = F.embedding_bag(bag_indices, lookup_table, mode="sum")

    grouped = summed.view(
        leading_dims +
        [num_out_groups, num_in_groups, out_group_size, in_group_size])
    if scales is not None:
        grouped = grouped.mul(scales)

    # Bring each group's rows next to its group index before collapsing the
    # (group, within-group) pairs into plain feature dimensions.
    out_features = num_out_groups * out_group_size
    in_features = num_in_groups * in_group_size
    interleaved = grouped.swapaxes(-3, -2)
    return interleaved.reshape(leading_dims + [out_features, in_features])
def dequantize_gemm(
    input: torch.Tensor,  # [..., in_features]
    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
    codebooks: torch.
    Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """Dequantize the full weight, then apply it as a linear layer.

    Args:
        input: activations to multiply with the reconstructed weight.
        codes: stored quantization codes.
        codebooks: codebook vectors; dimension 1 is the codebook size.
        scales: scale factors forwarded to ``dequantize_weight``.
        bias: optional bias forwarded to ``F.linear``.

    Returns:
        ``F.linear(input, dequantized_weight, bias)``.
    """
    # bit_length() - 1 equals log2(codebook_size) when the size is a power of
    # two.
    codebook_size = codebooks.shape[1]
    nbits = codebook_size.bit_length() - 1
    unsigned_codes = unpack_int_data(codes, nbits)
    weight = dequantize_weight(unsigned_codes, codebooks, scales)
    return F.linear(input, weight, bias)
# Generic dequantization, slow but flexible.
def generic_dequantize_gemm(
    input: torch.Tensor,  # [..., in_features]
    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
    codebooks: torch.
    Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
    output_partition_sizes: List[int],
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """Run ``dequantize_gemm`` once per output partition and stitch results.

    The rows of ``codes``/``scales``/``bias`` and the leading dimension of
    ``codebooks`` are split according to ``output_partition_sizes``; each
    shard is dequantized and multiplied on its own and written into its slice
    of the last dimension of the output.

    Args:
        input: activations, last dimension is the input feature size.
        codes: stored quantization codes for all partitions.
        codebooks: codebooks of all partitions concatenated along dim 0; each
            partition owns ``codebooks.shape[0] // len(partitions)`` of them.
        scales: per-output-row scales; ``scales.shape[0]`` is the total
            output size.
        output_partition_sizes: number of output rows in each partition.
        bias: optional bias covering all output rows.

    Returns:
        Tensor of shape ``input.shape[:-1] + (scales.shape[0],)``.
    """
    total_rows = scales.shape[0]
    output = torch.empty(input.shape[:-1] + (total_rows, ),
                         dtype=input.dtype,
                         device=input.device)

    # Break the inputs and codebooks apart, then combine the outputs. Per the
    # original author's note, this turned out faster than doing 3 de-quants
    # followed by 1 big multiply.
    books_per_shard = codebooks.shape[0] // len(output_partition_sizes)

    assert (total_rows == codes.shape[0])
    assert (sum(output_partition_sizes) == total_rows)

    row_start = 0
    book_start = 0
    for shard_rows in output_partition_sizes:
        shard_codes = codes.narrow(0, row_start, shard_rows)
        shard_codebooks = codebooks.narrow(0, book_start, books_per_shard)
        shard_scales = scales.narrow(0, row_start, shard_rows)
        if bias is None:
            shard_bias = None
        else:
            shard_bias = bias.narrow(0, row_start, shard_rows)

        shard_output = dequantize_gemm(input, shard_codes, shard_codebooks,
                                       shard_scales, shard_bias)

        # Write the shard result into its column range of the output.
        destination = output.narrow(-1, row_start, shard_rows)
        assert (destination.shape == shard_output.shape)
        destination.copy_(shard_output)

        row_start += shard_rows
        book_start += books_per_shard
    return output
# Optimized dequantize/decompression kernels, supports 1x16 and 2x8
# at 6 and 9 times faster than the generic version above, respectively.
def optimized_dequantize_gemm(
    input: torch.Tensor,  # [..., in_features]
    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
    codebooks: torch.
    Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
    output_partition_sizes: List[int],
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """Decompress the weights with ``ops.aqlm_dequant`` and apply them.

    The scales are not applied by the decompression call; they are folded in
    here, either into the output (no bias) or into the weights (with bias).

    Args:
        input: activations, last dimension is the input feature size.
        codes: stored quantization codes.
        codebooks: codebook vectors.
        scales: per-output-row scales, leading dimension is the output size.
        output_partition_sizes: output rows per partition, forwarded to the
            decompression op.
        bias: optional bias.

    Returns:
        The result of the scaled linear transformation.
    """
    weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

    if bias is not None:
        # Scaling the output would scale the bias too, so the scales are
        # applied to the weight rows (in place) before the multiply.
        row_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
            -1, weights.shape[1])
        weights *= row_scales
        return F.linear(input, weights, bias)

    # Scaling the output is fastest, so we do that when possible.
    output = F.linear(input, weights, bias)
    original_shape = output.shape
    output_2d = output.view(-1, output.size(-1))
    column_scales = scales.view(-1, scales.shape[0]).expand(
        output_2d.shape[0], -1)
    # In-place multiply on the 2-D view also updates ``output``.
    output_2d *= column_scales
    return output.view(original_shape)
class AQLMConfig(QuantizationConfig):
    """Quantization settings for AQLM-compressed weights.

    Reference: https://github.com/Vahe1994/AQLM
    """

    def __init__(
        self,
        in_group_size: int,
        nbits_per_codebook: int,
        num_codebooks: int,
        out_group_size: int,
    ) -> None:
        """Store the four AQLM hyper-parameters and derive ``pack_factor``.

        ``pack_factor`` is ``in_group_size * out_group_size``.
        """
        self.in_group_size = in_group_size
        self.nbits_per_codebook = nbits_per_codebook
        self.num_codebooks = num_codebooks
        self.out_group_size = out_group_size

        # out_group_size > 1 is untested, and probably won't work as-is.
        assert self.out_group_size == 1

        self.pack_factor = self.in_group_size * self.out_group_size

    def __repr__(self) -> str:
        """Return ``AQLMConfig(<name>=<value>, ...)`` for the four settings."""
        shown = ("in_group_size", "nbits_per_codebook", "num_codebooks",
                 "out_group_size")
        rendered = ", ".join(f"{name}={getattr(self, name)}"
                             for name in shown)
        return f"AQLMConfig({rendered})"

    @classmethod
    def get_name(cls) -> str:
        """Return the identifier of this quantization method."""
        return "aqlm"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the supported activation dtypes (half precision only)."""
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value required.

        NOTE(review): presumably 60 encodes CUDA compute capability 6.0 --
        confirm against the base class contract.
        """
        return 60

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the extra config file names needed (none)."""
        return []  # no extra configs.

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AQLMConfig":
        """Build a config from a dict, reading one key per setting."""
        wanted_keys = ("in_group_size", "nbits_per_codebook",
                       "num_codebooks", "out_group_size")
        # The tuple order matches the positional order of __init__.
        values = [cls.get_from_keys(config, [key]) for key in wanted_keys]
        return cls(*values)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["AQLMLinearMethod"]:
        """Return an ``AQLMLinearMethod`` for linear layers, else ``None``."""
        if not isinstance(layer, LinearBase):
            return None
        return AQLMLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        """Return the names of scaled activations (none)."""
        return []
class AQLMLinearMethod(LinearMethodBase):
    """Linear method for AQLM.

    Args:
        quant_config: The AQLM quantization config.
    """

    def __init__(self, quant_config: AQLMConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Allocate and register ``codes``, ``codebooks`` and ``scales``.

        All three are uninitialized, non-trainable parameters registered on
        ``layer``; each also receives ``extra_weight_attrs``.

        Args:
            layer: module the parameters are registered on.
            input_size_per_partition: input features held by this partition.
            output_partition_sizes: output rows of each fused partition.
            input_size: unused.
            output_size: unused.
            params_dtype: dtype for codebooks and scales; must be half.
            **extra_weight_attrs: attributes attached to every parameter.

        Raises:
            ValueError: if ``params_dtype`` is not half, or if the input or
                output size is not a multiple of its group size.
        """
        del output_size  # Unused.
        del input_size  # Unused.

        if params_dtype != torch.half:
            raise ValueError("Only half is currently supported by aqlm")

        cfg = self.quant_config
        if input_size_per_partition % cfg.in_group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % cfg.out_group_size != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        # Codes: one index per (output row, input group, codebook).
        # A separate pack factor along the output dimension is conceivable,
        # but out_group_size > 1 is not supported, so a single pack factor is
        # recorded. NOTE(review): the earlier comment tied "packed_dim" to
        # QKVLinear loading -- confirm which dimension that loader expects.
        codes_storage = torch.empty(
            output_size_per_partition,
            input_size_per_partition // cfg.pack_factor,
            cfg.num_codebooks,
            dtype=get_int_dtype(cfg.nbits_per_codebook),
        )
        codes = Parameter(codes_storage, requires_grad=False)
        set_weight_attrs(
            codes,
            {
                "input_dim": 1,
                "output_dim": 0,
                "packed_dim": 1,
                "pack_factor": cfg.pack_factor,
            },
        )

        # Codebooks: every fused partition brings its own num_codebooks
        # codebooks, stacked along dim 0.
        codebooks_storage = torch.empty(
            cfg.num_codebooks * len(output_partition_sizes),
            2**cfg.nbits_per_codebook,
            cfg.out_group_size,
            cfg.in_group_size,
            dtype=params_dtype,
        )
        codebooks = Parameter(codebooks_storage, requires_grad=False)
        set_weight_attrs(
            codebooks,
            {
                # metadata indicates fixed size concatenated along dim 0
                "is_metadata": True,
                "output_partition_sizes": output_partition_sizes
            },
        )

        # Scales: one value per output group, kept 4-D.
        scales_shape = (
            output_size_per_partition // cfg.out_group_size,
            1,
            1,
            1,
        )
        scales = Parameter(
            torch.empty(scales_shape, dtype=params_dtype),
            requires_grad=False,
        )
        set_weight_attrs(
            scales,
            {
                "output_dim": 0,
                "packed_dim": 0,
                "pack_factor": cfg.out_group_size
            },
        )

        named_params = (
            ("codes", codes),
            ("codebooks", codebooks),
            ("scales", scales),
        )
        for param_name, param in named_params:
            layer.register_parameter(param_name, param)
            set_weight_attrs(param, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Multiply ``x`` by the layer's quantized weight.

        Formats with a dedicated kernel (in-group 8, out-group 1, and either
        two 256-entry codebooks or one 65536-entry codebook) use
        ``ops.aqlm_gemm`` for at most 6 input rows and
        ``optimized_dequantize_gemm`` otherwise. All other formats go through
        ``generic_dequantize_gemm``.

        Args:
            layer: module holding ``codes``, ``codebooks`` and ``scales``.
            x: activations, last dimension is the input feature size.
            bias: optional bias.

        Returns:
            The output of the selected GEMM routine.
        """
        codebooks = layer.codebooks
        codes = layer.codes
        scales = layer.scales
        output_partition_sizes = getattr(codebooks, "output_partition_sizes",
                                         [])
        gemm_args = (x, codes, codebooks, scales, output_partition_sizes,
                     bias)

        num_codebooks = codes.shape[2]
        codebook_size = codebooks.shape[1]
        out_group_size = codebooks.shape[2]
        in_group_size = codebooks.shape[3]

        # We support these formats with dedicated gemm and decompression
        # kernels.
        kernel_layouts = ((256, 2), (65536, 1))
        has_dedicated_kernel = (
            in_group_size == 8 and out_group_size == 1
            and (codebook_size, num_codebooks) in kernel_layouts)

        if not has_dedicated_kernel:
            # fall back all unoptimized formats
            return generic_dequantize_gemm(*gemm_args)

        # thresholds determined by timings on an A6000, one GPU
        if math.prod(x.shape[:-1]) <= 6:
            return ops.aqlm_gemm(*gemm_args)
        return optimized_dequantize_gemm(*gemm_args)
|