123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322 |
- from typing import Any, Dict, List, Optional
- import torch
- from loguru import logger
- from torch.nn.parameter import Parameter
- from aphrodite import _custom_ops as ops
- from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
- set_weight_attrs)
- from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
- from aphrodite.quantization.base_config import QuantizationConfig
- from aphrodite.quantization.utils.marlin_utils import (
- apply_gptq_marlin_linear, check_gptq_marlin_supported, marlin_is_k_full,
- marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
- marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
- verify_gptq_marlin_supported, verify_marlin_supports_shape)
class GPTQMarlinConfig(QuantizationConfig):
    """Config class for GPTQ Marlin.

    Stores the GPTQ checkpoint parameters (bit width, group size, activation
    ordering, symmetry and whether the LM head is quantized). At construction
    time it checks, via ``verify_gptq_marlin_supported``, that the Marlin
    kernel supports the combination.
    """

    def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
                 is_sym: bool, lm_head_quantized: bool) -> None:
        """Store and validate the quantization parameters.

        Args:
            weight_bits: Bits per quantized weight; ``32 // weight_bits``
                values are packed into each int32.
            group_size: Number of input channels sharing one scale, or -1
                for a single group per output channel.
            desc_act: Activation-reordering (act_order) flag. It is forced to
                False when ``group_size == -1``.
            is_sym: Symmetric-quantization flag (the checkpoint's ``sym`` key).
            lm_head_quantized: Whether the LM head is quantized as well.
        """
        if desc_act and group_size == -1:
            # In this case, act_order == True is the same as act_order == False
            # (since we have only one group per output channel)
            desc_act = False

        self.weight_bits = weight_bits
        self.pack_factor = 32 // self.weight_bits  # packed into int32
        self.group_size = group_size
        self.desc_act = desc_act
        self.is_sym = is_sym
        self.lm_head_quantized = lm_head_quantized

        # Verify supported on platform.
        verify_gptq_marlin_supported(num_bits=self.weight_bits,
                                     group_size=self.group_size,
                                     is_sym=self.is_sym)

    def __repr__(self) -> str:
        # Report every constructor argument, including is_sym. Otherwise two
        # configs that differ only in symmetry would print identically.
        return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act}, "
                f"is_sym={self.is_sym}, "
                f"lm_head_quantized={self.lm_head_quantized})")

    @classmethod
    def get_name(cls) -> str:
        """Return the registry name of this quantization method."""
        return "gptq_marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes accepted: float16 and bfloat16."""
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required (80).

        Presumably this is compute capability 8.0; confirm against the
        ``QuantizationConfig`` base contract.
        """
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the checkpoint files that hold the quantization config."""
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
        """Build a config from a checkpoint's quantization dict.

        Reads the required keys ``bits``, ``group_size``, ``desc_act`` and
        ``sym``, plus the optional ``lm_head`` key (default False).
        """
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        is_sym = cls.get_from_keys(config, ["sym"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                 default=False)
        return cls(weight_bits, group_size, desc_act, is_sym,
                   lm_head_quantized)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        """Decide whether a GPTQ checkpoint should run on the Marlin kernel.

        Args:
            hf_quant_cfg: The checkpoint's quantization config dict.
            user_quant: The quantization method the user requested, or None.

        Returns:
            ``"gptq_marlin"`` when the checkpoint is Marlin-compatible and the
            user requested nothing, ``"marlin"`` or ``"gptq_marlin"``.
            Otherwise None. When the user explicitly requested ``"gptq"``
            for a convertible checkpoint, a hint is logged first.
        """
        can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)

        # "gptq_marlin" must be accepted too. The hint logged below tells users
        # to pass quantization=gptq_marlin; rejecting it here would return None
        # for a convertible checkpoint even though the user asked for exactly
        # this method.
        is_valid_user_quant = user_quant in (None, "marlin", "gptq_marlin")

        if can_convert and is_valid_user_quant:
            msg = (f"The model is convertible to {cls.get_name()} during "
                   f"runtime. Using {cls.get_name()} kernel.")
            logger.info(msg)
            return cls.get_name()

        if can_convert and user_quant == "gptq":
            logger.info("Detected that the model can run with gptq_marlin"
                        ", however you specified quantization=gptq explicitly,"
                        " so forcing gptq. Use quantization=gptq_marlin for"
                        " faster inference")
        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["GPTQMarlinLinearMethod"]:
        """Return the Marlin linear method for layers this config quantizes.

        ``LinearBase`` layers are always handled. ``ParallelLMHead`` is
        handled only when ``lm_head_quantized`` is set. ``prefix`` is unused.
        """
        if (isinstance(layer, LinearBase) or
                (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
            return GPTQMarlinLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return activation names needing scaling; there are none here."""
        return []

    @classmethod
    def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
        """Return whether a checkpoint quantization config can run on Marlin.

        The config must declare ``quant_method`` "gptq" (case-insensitive)
        and provide ``bits``, ``group_size``, ``sym`` and ``desc_act``. The
        final decision is delegated to ``check_gptq_marlin_supported``.
        """
        # Extract data from quant config. ``or ""`` also covers an explicit
        # ``"quant_method": null`` entry, which would otherwise fail on .lower().
        quant_method = (quant_config.get("quant_method") or "").lower()
        num_bits = quant_config.get("bits", None)
        group_size = quant_config.get("group_size", None)
        sym = quant_config.get("sym", None)
        desc_act = quant_config.get("desc_act", None)

        if quant_method != "gptq":
            return False

        # If we cannot find the info needed in the config, cannot convert.
        if (num_bits is None or group_size is None or sym is None
                or desc_act is None):
            return False

        return check_gptq_marlin_supported(
            num_bits=num_bits,
            group_size=group_size,
            is_sym=sym,
            min_capability=cls.get_min_capability())
class GPTQMarlinLinearMethod(LinearMethodBase):
    """Linear method that executes GPTQ checkpoints with the Marlin kernel.

    Args:
        quant_config: The GPTQ Marlin quantization config.
    """

    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Allocate and register the GPTQ tensors on ``layer``.

        Registers ``qweight``, ``g_idx``, ``scales`` and ``qzeros`` (in that
        order) as non-trainable parameters. Each is tagged through
        ``set_weight_attrs`` with ``extra_weight_attrs`` plus dim/packing
        metadata, presumably consumed by the weight loader. Finally the
        partition sizes and ``is_k_full`` are stored on the layer.
        """
        del output_size  # only the per-partition sizes are used
        cfg = self.quant_config
        pack = cfg.pack_factor

        out_features = sum(output_partition_sizes)
        # The input dimension is sharded exactly when the layer is row-parallel.
        is_row_parallel = input_size_per_partition != input_size

        # A group size of -1 means a single group spanning the whole input.
        group_size = input_size if cfg.group_size == -1 else cfg.group_size

        verify_marlin_supports_shape(
            output_size_per_partition=out_features,
            input_size_per_partition=input_size_per_partition,
            input_size=input_size,
            group_size=group_size)

        # Decide whether scales/zero-points are sharded or replicated.
        if marlin_repeat_scales_on_all_ranks(cfg.desc_act, cfg.group_size,
                                             is_row_parallel):
            # input_dim == None: weight_loader will repeat the scales on each
            # GPU in the TP>1 case, so every rank holds all groups.
            scales_and_zp_input_dim = None
            num_groups = input_size // group_size
        else:
            # input_dim == 0: weight_loader will shard the scales in the TP>1
            # case, so each rank holds only the groups of its partition.
            scales_and_zp_input_dim = 0
            num_groups = input_size_per_partition // group_size

        def frozen(*shape: int, dtype: torch.dtype, **kwargs) -> Parameter:
            # Uninitialised tensor wrapped as a non-trainable Parameter;
            # presumably filled in by the weight loader.
            return Parameter(torch.empty(*shape, dtype=dtype, **kwargs),
                             requires_grad=False)

        # Quantized weights: `pack` values per int32 along the input dim.
        qweight = frozen(input_size_per_partition // pack, out_features,
                         dtype=torch.int32)
        set_weight_attrs(qweight, {
            **extra_weight_attrs,
            "input_dim": 0,
            "output_dim": 1,
            "packed_dim": 0,
            "pack_factor": pack,
        })

        # Activation order: one int32 entry per input channel of the partition.
        g_idx = frozen(input_size_per_partition, dtype=torch.int32)
        # Ignore warning from fused linear layers such as QKVParallelLinear.
        set_weight_attrs(g_idx, {
            **extra_weight_attrs,
            "input_dim": 0,
            "ignore_warning": True,
        })

        # Scales: one row per group, in the activation dtype.
        scales = frozen(num_groups, out_features, dtype=params_dtype)
        set_weight_attrs(scales, {
            **extra_weight_attrs,
            "input_dim": scales_and_zp_input_dim,
            "output_dim": 1,
        })

        # Quantized zero-points, allocated on the meta device (no storage).
        # Nothing in this class reads them after loading:
        # process_weights_after_loading sets layer.zp to an empty tensor.
        qzeros = frozen(num_groups, out_features // pack,
                        dtype=torch.int32, device="meta")
        set_weight_attrs(qzeros, {
            **extra_weight_attrs,
            "input_dim": scales_and_zp_input_dim,
            "output_dim": 1,
            "packed_dim": 1,
            "pack_factor": pack,
        })

        for param_name, param in (("qweight", qweight), ("g_idx", g_idx),
                                  ("scales", scales), ("qzeros", qzeros)):
            layer.register_parameter(param_name, param)

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = out_features
        layer.input_size = input_size
        layer.is_k_full = marlin_is_k_full(cfg.desc_act, is_row_parallel)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the loaded AutoGPTQ tensors to the Marlin layout in place.

        Checkpoints are serialized in AutoGPTQ format, which is different
        from the marlin format. This runs after the weights are loaded. It
        handles the repacking, including the activation-reordering case.
        """
        cfg = self.quant_config
        device = layer.qweight.device

        # Allocate marlin workspace.
        layer.workspace = marlin_make_workspace(
            layer.output_size_per_partition, device)

        if not cfg.desc_act:
            # No activation reordering: hand the kernel empty index tensors.
            layer.g_idx = marlin_make_empty_g_idx(device)
            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
        else:
            # Sort g_idx and keep the permutation; it is reused by the repack.
            sorted_g_idx, sort_indices = marlin_sort_g_idx(layer.g_idx)
            layer.g_idx_sort_indices = sort_indices
            replace_tensor(layer, "g_idx", sorted_g_idx)

        # No zero-point.
        layer.zp = marlin_make_empty_g_idx(device)

        # Repack weights from autogptq format to marlin format.
        repacked_qweight = ops.gptq_marlin_repack(
            layer.qweight,
            perm=layer.g_idx_sort_indices,
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            num_bits=cfg.weight_bits)
        replace_tensor(layer, "qweight", repacked_qweight)

        # Permute scales from autogptq format to marlin format. With
        # act-order the full input size is used, otherwise the partition size.
        if cfg.desc_act:
            scales_size_k = layer.input_size
        else:
            scales_size_k = layer.input_size_per_partition
        permuted_scales = marlin_permute_scales(
            layer.scales,
            size_k=scales_size_k,
            size_n=layer.output_size_per_partition,
            group_size=cfg.group_size)
        replace_tensor(layer, "scales", permuted_scales)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the quantized linear layer on ``x`` via Marlin.

        Uses the tensors prepared by ``process_weights_after_loading`` and
        forwards everything to ``apply_gptq_marlin_linear``.
        """
        return apply_gptq_marlin_linear(
            input=x,
            weight=layer.qweight,
            weight_scale=layer.scales,
            weight_zp=layer.zp,
            g_idx=layer.g_idx,
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=layer.workspace,
            num_bits=self.quant_config.weight_bits,
            output_size_per_partition=layer.output_size_per_partition,
            input_size_per_partition=layer.input_size_per_partition,
            is_k_full=layer.is_k_full,
            bias=bias,
        )
|