from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig


class DeepSpeedFPConfig(QuantizationConfig):
    """Config for DeepSpeed FP quantizer. It supports fp6 and fp8.

    Args:
        weight_bits: the target quantization bits, 6 or 8.
        group_size: group size for quantization, defaults to 512.
    """

    def __init__(
        self,
        weight_bits: int = 8,
        group_size: int = 512,
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.valid_types = [torch.bfloat16, torch.float16]

        if self.weight_bits not in (4, 6, 8, 12):
            raise ValueError(
                "Currently, only 4-bit, 6-bit, 8-bit, and 12-bit weight"
                " quantization is supported for DeepSpeed FP quantization, "
                f"but got {self.weight_bits} bits.")

    def __repr__(self) -> str:
        return (f"DeepSpeedFPConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size})")

    @classmethod
    def get_name(cls) -> str:
        return "DeepSpeedFP"

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(weight_bits=weight_bits, group_size=group_size)

    def get_linear_method(self) -> "DeepSpeedFPLinearMethod":
        return DeepSpeedFPLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        # TODO: figure out the actual minimum compute capability required.
        return 60

    @staticmethod
    def get_config_filenames() -> List[str]:
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["DeepSpeedFPLinearMethod"]:
        if isinstance(layer, LinearBase):
            return DeepSpeedFPLinearMethod(self)
        return None
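

# Minimal usage sketch (assumption: this helper is illustrative only and not
# part of the original module). It shows how the "bits" and "group_size" keys
# read by from_config() above map onto a DeepSpeedFPConfig instance.
def _example_build_config() -> DeepSpeedFPConfig:
    quant_config_json = {"bits": 8, "group_size": 512}
    cfg = DeepSpeedFPConfig.from_config(quant_config_json)
    assert cfg.weight_bits == 8 and cfg.group_size == 512
    return cfg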


class DeepSpeedFPLinearMethod(LinearMethodBase):
    """Linear method for the DeepSpeedFP quantizer.

    Args:
        quant_config: the DeepSpeedFP quantization config.
    """

    def __init__(self, quant_config: DeepSpeedFPConfig):
        self.quant_config = quant_config
        self.weight = None

    def create_weights(self,
                       layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int],
                       input_size: int,
                       output_size: int,
                       params_dtype: torch.dtype,
                       weight_loader=None,
                       **extra_weight_attrs):
        del output_size
        del input_size
        output_size_per_partition = sum(output_partition_sizes)
        weight = DeepSpeedFPParameter(
            torch.Size((output_size_per_partition, input_size_per_partition)),
            params_dtype=params_dtype,
            quant_config=self.quant_config,
        )
        set_weight_attrs(weight, {
            "input_dim": 1,
            "output_dim": 0,
        })
        layer.register_parameter("weight", weight)

        def quant_weight_loader(param, loaded_weight, *args, **kwargs):
            # Calls the original weight loader (if any), quantizes the result,
            # and then loads the quantized parameter.
            if weight_loader is not None:
                orig_param_data = param.data
                param.data = param.ds_dequantize()
                weight_loader(param, loaded_weight, *args, **kwargs)
                param.data, loaded_weight = orig_param_data, param.data
            param.ds_quantize_(loaded_weight.cuda())

        extra_weight_attrs["weight_loader"] = quant_weight_loader
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        weight = layer.weight
        y = weight.ds_dequantize()
        return F.linear(x, y, bias)
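

# Minimal load-path sketch (assumption: illustrative only, not part of the
# original module). It mirrors what the model loader does with the wrapper
# installed by create_weights(): the checkpoint tensor is handed to the
# parameter's weight_loader, which ends by re-quantizing it into the int8
# storage of the DeepSpeedFPParameter.
def _example_load_weight(layer: torch.nn.Module,
                         checkpoint_tensor: torch.Tensor) -> None:
    param = layer.weight  # DeepSpeedFPParameter registered by create_weights()
    loader = getattr(param, "weight_loader", None)
    if loader is not None:
        # quant_weight_loader defined above: optionally dequantize, run the
        # original loader, then ds_quantize_ the result back into param.data.
        loader(param, checkpoint_tensor)
    else:
        param.ds_quantize_(checkpoint_tensor.cuda())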


class DeepSpeedFPParameter(nn.Parameter):
    """
    DeepSpeedFP quantized parameter class that implements fp8/fp6
    quantization via DeepSpeed. Weights are stored in quantized form on
    GPUs, and can be dequantized on-the-fly when needed by the model.
    """

    def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype,
                quant_config: DeepSpeedFPConfig):
        try:
            import deepspeed
            if deepspeed.__version__ < "0.14.2":
                raise ImportError("deepspeed version is too old. Please "
                                  "install deepspeed>=0.14.2.")
            from deepspeed.ops.fp_quantizer import FP_Quantize
        except ImportError as err:
            raise ImportError("Please install deepspeed>=0.14.2 via "
                              "`pip install deepspeed>=0.14.2` to use "
                              "the deepspeedfp quantizer.") from err
        data = torch.empty(
            (
                orig_shape.numel() // quant_config.group_size,
                quant_config.group_size * quant_config.weight_bits // 8 + 4,
            ),
            dtype=torch.int8)
        self = torch.Tensor._make_subclass(cls, data, data.requires_grad)
        self.orig_shape = orig_shape
        self.quant_config = quant_config
        self.fp_quantizer = FP_Quantize(group_size=quant_config.group_size)
        self.fp_quantizer.orig_shape = orig_shape
        self.fp_quantizer.orig_dtype = params_dtype
        return self

    def ds_quantize_(self, tensor: torch.Tensor):
        assert tensor.device.type == "cuda" and tensor.dtype != torch.int8
        return self.data.copy_(
            self.fp_quantizer.quantize(
                tensor.data,
                q_bits=self.quant_config.weight_bits,
            ))

    def ds_dequantize(self, fp_out=None) -> torch.Tensor:
        """
        Return a tensor containing the dequantized weights of this parameter.
        """
        assert self.data.device.type == "cuda" and self.data.dtype == torch.int8
        return self.fp_quantizer.dequantize(
            self.data, fp_out=fp_out, q_bits=self.quant_config.weight_bits)

    def ds_selective_dequantize(self, indices, fp_out=None) -> torch.Tensor:
        """
        Return a tensor where only the weights at `indices` are dequantized
        (to save HBM -> SRAM bandwidth).
        """
        assert self.data.device.type == "cuda" and self.data.dtype == torch.int8
        return self.fp_quantizer.selective_dequantize(
            self.data,
            indices,
            fp_out=fp_out,
            q_bits=self.quant_config.weight_bits)
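

# Minimal round-trip sketch (assumption: illustrative only; requires a CUDA
# device and deepspeed>=0.14.2). A bf16 weight is packed into the grouped
# int8 fp8 representation and then reconstructed on the fly, which is the
# same path apply() takes at inference time.
def _example_roundtrip() -> torch.Tensor:
    cfg = DeepSpeedFPConfig(weight_bits=8, group_size=512)
    w = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
    qparam = DeepSpeedFPParameter(w.shape, params_dtype=w.dtype,
                                  quant_config=cfg)
    # In the engine, the module owning this parameter is moved to the GPU
    # before weights are loaded; emulate that here so the asserts pass.
    qparam.data = qparam.data.cuda()
    qparam.ds_quantize_(w)         # pack into (num_groups, packed_bytes) int8
    return qparam.ds_dequantize()  # approximate bf16 reconstruction of w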