from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
from loguru import logger

from aphrodite import _custom_ops as ops
from aphrodite.distributed import get_tensor_model_parallel_rank
from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.quantization.utils.fp6_utils import (_SPLIT_K_MAP,
                                                    from_scaled_tc_fpx,
                                                    to_scaled_tc_fpx)


class QuantLLMFPConfig(QuantizationConfig):
    """Config for the QuantLLM FP quantizer. It supports fp2, fp3, fp4,
    fp5, fp6, and fp7.

    Reference: https://arxiv.org/abs/2401.14112

    Args:
        weight_bits: the target quantization bits; must be one of
            2, 3, 4, 5, 6, or 7.
    """

    def __init__(
        self,
        weight_bits: int = 6,
        exp_bits: int = 2,
    ) -> None:
        self.weight_bits = weight_bits
        self.exponent_bits = exp_bits
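        # One bit is the sign, so with the default weight_bits=6 and
        # exp_bits=2 the layout is E2M3: 1 sign + 2 exponent + 3 mantissa.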
        self.mantissa_bits = weight_bits - self.exponent_bits - 1

        self.valid_types = [torch.float16]

        if self.weight_bits not in [2, 3, 4, 5, 6, 7]:
            raise ValueError(
                "Currently, only 2- to 7-bit quantization is supported "
                "for QuantLLM FP quantization, but got "
                f"{self.weight_bits} bits.")

        if get_tensor_model_parallel_rank() == 0:
            logger.info(f"Loading model in FP{self.weight_bits}_E"
                        f"{self.exponent_bits}M{self.mantissa_bits} format.")

    def __repr__(self) -> str:
        return (f"QuantLLMFPConfig(weight_bits={self.weight_bits}, "
                f"exponent_bits={self.exponent_bits})")

    @classmethod
    def get_name(cls) -> str:
        return "QuantLLMFP"

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "QuantLLMFPConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        exp_bits = cls.get_from_keys(config, ["exp_bits"])
        return cls(weight_bits=weight_bits, exp_bits=exp_bits)

    def get_linear_method(self) -> "QuantLLMFPLinearMethod":
        return QuantLLMFPLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # TODO: verify the minimum compute capability actually required
        # by the fp_eXmY kernels.
        return 80

    @staticmethod
    def get_config_filenames() -> List[str]:
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    def get_quant_method(
            self,
            layer: torch.nn.Module,
            prefix: str) -> Optional["QuantLLMFPLinearMethod"]:
        if isinstance(layer, LinearBase):
            return QuantLLMFPLinearMethod(self)
        return None


class QuantLLMFPLinearMethod(LinearMethodBase):
    """Linear method for the QuantLLMFP quantizer.

    Args:
        quant_config: the QuantLLMFP quantization config.
    """

    def __init__(self, quant_config: QuantLLMFPConfig):
        self.quant_config = quant_config
        self.weight = None

    def create_weights(self,
                       layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int],
                       input_size: int,
                       output_size: int,
                       params_dtype: torch.dtype,
                       weight_loader=None,
                       **extra_weight_attrs):
        del output_size
        del input_size
        output_size_per_partition = sum(output_partition_sizes)
        weight = QuantLLMFPParameter(
            torch.Size((output_size_per_partition, input_size_per_partition)),
            params_dtype=params_dtype,
            quant_config=self.quant_config,
        )
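        # Tag the sharding axes so tensor-parallel weight loading splits
        # the packed buffer along the right dimensions.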
        set_weight_attrs(weight, {
            "input_dim": 1,
            "output_dim": 0,
        })
        layer.register_parameter("weight", weight)

        def quant_weight_loader(param, loaded_weight, *args, **kwargs):
            # Calls the original weight loader (if any), quantizes the
            # result, and then loads the quantized parameter.
            if weight_loader is not None:
                orig_param_data = param.data
                param.data = param.quant_llmdequantize()
                weight_loader(param, loaded_weight, *args, **kwargs)
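                # The loader wrote into the dequantized buffer; swap it
                # back so `param.data` again holds the packed uint8 weights
                # and `loaded_weight` holds the freshly loaded fp16 tensor.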
                param.data, loaded_weight = orig_param_data, param.data

            param.quant_llmquantize_(loaded_weight.cuda())

        extra_weight_attrs["weight_loader"] = quant_weight_loader
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        weight = layer.weight
        weights = weight.data
        scales = weight.scales
        out_dim, in_dim = weights.shape
        bsize = x.shape[0]
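        # Choose a split-K factor for the dequant-GEMM kernel: the lookup
        # table is bucketed by batch size in steps of 64 and keyed on the
        # output dim; batches larger than 768 fall back to splitK = 1.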
        splitK = _SPLIT_K_MAP[(bsize - 1) // 64].get(
            out_dim, 1) if bsize <= 768 else 1
        out = ops.fp_eXmY_linear_forward_cuda(
            self.quant_config.exponent_bits,
            self.quant_config.mantissa_bits,
            x, weights, scales, splitK)
        if bias is not None:
            out = out + bias
        return out


class QuantLLMFPParameter(nn.Parameter):
    """QuantLLMFP quantized parameter class that implements fp2-fp7
    quantization. Weights are stored in quantized form on the GPU and
    can be applied directly to float16 activations.
    """

    def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype,
                quant_config: QuantLLMFPConfig):
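        # Packed storage: each quantized row occupies
        # orig_shape[1] * weight_bits / 8 bytes, plus one fp16 scale per
        # output channel.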
        data = torch.empty(torch.Size(
            (orig_shape[0], orig_shape[1] * quant_config.weight_bits // 8)),
            dtype=torch.uint8)
        self = torch.Tensor._make_subclass(cls, data, data.requires_grad)
        self.scales = torch.empty(orig_shape[0], dtype=torch.float16)
        self.quant_config = quant_config
        self.orig_shape = orig_shape
        return self

    def quant_llmquantize_(self, tensor: torch.Tensor):
        assert tensor.device.type == "cuda" and tensor.dtype != torch.int8
        data, scales = to_scaled_tc_fpx(
            tensor.data, self.quant_config.exponent_bits,
            self.quant_config.mantissa_bits)
        self.data.copy_(data)
        self.scales.copy_(scales)

    def quant_llmdequantize(self, output_dtype=None):
        output_dtype = output_dtype or torch.get_default_dtype()
        return from_scaled_tc_fpx(self.data, self.quant_config.exponent_bits,
                                  self.quant_config.mantissa_bits,
                                  self.scales).to(output_dtype)
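

if __name__ == "__main__":
    # Hedged round-trip sketch, not part of the module's API: it assumes a
    # CUDA device and that the fpx kernels behind to_scaled_tc_fpx /
    # from_scaled_tc_fpx are available. It quantizes a random fp16 matrix
    # to FP6_E2M3 and reconstructs it, mirroring how quant_llmquantize_ and
    # quant_llmdequantize use the same helpers above.
    exp_bits, mantissa_bits = 2, 3  # 1 sign + 2 exponent + 3 mantissa = fp6
    weight = torch.randn(256, 1024, dtype=torch.float16, device="cuda")
    packed, scales = to_scaled_tc_fpx(weight, exp_bits, mantissa_bits)
    recovered = from_scaled_tc_fpx(packed, exp_bits, mantissa_bits, scales)
    print("max abs error:",
          (weight - recovered.to(torch.float16)).abs().max().item())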