from contextlib import suppress
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                              set_weight_attrs)
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.quantization.quip_utils import (
    get_packed_abs_grid,
    get_hadK,
    matmul_hadUt_cuda,
    matmul_hadU_cuda,
)

HAS_QUANTS = False
with suppress(ImportError):
    from aphrodite._quant_C import quant_ops as ops
    HAS_QUANTS = True
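

# QuIP#-style quantization (see the reference in QuipConfig below): each
# weight matrix is stored as packed E8P codebook indices (Qidxs) plus a
# global scale (Wscale) and per-channel sign vectors (SU, SV), and the
# forward pass wraps the quantized matmul in Hadamard transforms for
# incoherence processing. This summary is a reading of the kernels used
# below; see https://cornell-relaxml.github.io/quip-sharp/ for details.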
class QuipConfig(QuantizationConfig):
    """Config class for Quip.

    Reference: https://cornell-relaxml.github.io/quip-sharp/
    """

    def __init__(self, codebook: str, use_rand: bool) -> None:
        if not HAS_QUANTS:
            raise ImportError("Could not find the quantization kernels.")
        self.codebook = codebook
        self.use_rand = use_rand

        if self.codebook != "E8P12":
            raise ValueError("Currently, only E8P12 is supported for "
                             f"Quip, but got {self.codebook}.")

    def __repr__(self) -> str:
        return (f"QuipConfig(codebook={self.codebook}, "
                f"use_rand={self.use_rand})")

    @classmethod
    def get_name(cls) -> str:
        return "quip"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # 80 denotes CUDA compute capability 8.0 (Ampere), the minimum the
        # quantization kernels support.
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantization_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "QuipConfig":
        codebook = cls.get_from_keys(config, ["codebook"])
        use_rand = cls.get_from_keys(config, ["use_rand"])
        return cls(codebook, use_rand)
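
    # An illustrative quantization_config.json as consumed by from_config
    # above (the values must match those the checkpoint was quantized with):
    # {
    #     "codebook": "E8P12",
    #     "use_rand": true
    # }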

    def get_linear_method(self) -> "QuipLinearMethod":
        return QuipLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []

    def merge_weight(self) -> bool:
        return False

    def quant_vocab(self) -> List[bool]:
        return [False, False]

    def support_fused_moe(self) -> bool:
        return False

    def rope_style(self) -> Optional[bool]:
        return None


class QuipLinearMethod(LinearMethodBase):
    """Linear method for Quip.

    Args:
        quant_config: The Quip quantization config.
    """

    def __init__(self, quant_config: QuipConfig):
        self.quant_config = quant_config
        self.grid_packed_abs = get_packed_abs_grid().to(device="cuda")
        # Each int16 entry of Qidxs packs 8 codebook-quantized weights.
        self.pack = 8
        self.idx_dtype = torch.int16

    def create_weights(
        self,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        output_size_per_partition = sum(output_partition_sizes)
        if (input_size != input_size_per_partition
                or output_size != output_size_per_partition):
            raise ValueError(
                "Currently Quip doesn't support tensor parallelism yet")

        had_left, K_left, q_in_features = get_hadK(input_size,
                                                   self.quant_config.use_rand)
        had_right, K_right, q_out_features = get_hadK(
            output_size, self.quant_config.use_rand)
        weights = {
            "K_left": K_left,
            "K_right": K_right,
            "q_in_features": q_in_features,
            "q_out_features": q_out_features,
        }
        if had_left is not None:
            weights["had_left"] = Parameter(
                had_left.to(dtype=params_dtype, device="cuda"),
                requires_grad=False,
            )
            set_weight_attrs(weights["had_left"], {"ignore_warning": True})
        if had_right is not None:
            weights["had_right"] = Parameter(
                had_right.to(dtype=params_dtype, device="cuda"),
                requires_grad=False,
            )
            set_weight_attrs(weights["had_right"], {"ignore_warning": True})
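        # Packed codebook indices: q_in_features // 8 int16 columns, each
        # encoding 8 weights (decompression in apply_weights expands the
        # column count back by a factor of 8).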
        Qidxs = Parameter(
            torch.empty(q_out_features,
                        q_in_features // self.pack,
                        device="cuda",
                        dtype=self.idx_dtype),
            requires_grad=False,
        )
        set_weight_attrs(Qidxs, {"ignore_warning": True})
        Wscale = Parameter(
            torch.ones((), dtype=torch.float, device="cuda"),
            requires_grad=False,
        )
        set_weight_attrs(Wscale, {"ignore_warning": True})
        SU = Parameter(
            torch.ones(
                input_size,
                device="cuda",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(SU, {"ignore_warning": True})
        SV = Parameter(
            torch.ones(
                output_size,
                device="cuda",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(SV, {"ignore_warning": True})
        weights.update({
            "Qidxs": Qidxs,
            "Wscale": Wscale,
            "SU": SU,
            "SV": SV,
        })
        return weights

    def apply_weights(self,
                      weights: Dict[str, Any],
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
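        # The forward pass mirrors the incoherence processing set up in
        # create_weights: flip input signs with SU, apply a Hadamard
        # transform, run the quantized matmul against the packed codebook,
        # apply the output-side Hadamard transform, then flip output signs
        # with SV.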

        # First run: fold the scalar Wscale into a Python float and drop
        # SU/SV when they are all ones (no sign flips to apply), so later
        # calls skip the extra work.
        if isinstance(weights["Wscale"], torch.Tensor):
            weights["Wscale"] = weights["Wscale"].item()
            if "SU" in weights and torch.all(weights["SU"] > 0):
                del weights["SU"]
            if "SV" in weights and torch.all(weights["SV"] > 0):
                del weights["SV"]

        reshaped_x = x.reshape(-1, x.shape[-1])
        out_dim = weights["Qidxs"].shape[0]

        if "SU" in weights:
            reshaped_x = reshaped_x * weights["SU"]
        reshaped_x = matmul_hadUt_cuda(reshaped_x,
                                       weights.get("had_left", None),
                                       weights["K_left"],
                                       weights["q_in_features"],
                                       weights["Wscale"])

        m, n = weights["Qidxs"].shape
        if reshaped_x.size(0) < 32:
            # Small batches use the fused kernel that looks up codebook
            # entries on the fly.
            out = ops.quip_gemv(reshaped_x, weights["Qidxs"],
                                self.grid_packed_abs)
        else:
            # Larger batches decompress the packed indices into a dense
            # fp16 weight matrix and fall back to a regular GEMM.
            W_decompressed = torch.empty(m,
                                         n * 8,
                                         dtype=torch.float16,
                                         device=x.device)
            ops.quip_decompress(weights["Qidxs"], self.grid_packed_abs,
                                W_decompressed)
            out = reshaped_x @ W_decompressed.T

        out = matmul_hadU_cuda(out,
                               weights.get("had_right", None),
                               weights["K_right"],
                               weights["q_out_features"])[..., :out_dim]
        if "SV" in weights:
            out = out * weights["SV"]
        out = out.view(*x.shape[:-1], out.shape[-1])
        out = out + bias if bias is not None else out
        return out

    def apply_moe_weights(self, w1: Dict[str, torch.Tensor],
                          w2: Dict[str, torch.Tensor], x: torch.Tensor,
                          gating_output: torch.Tensor, topk: int,
                          renormalize: bool) -> torch.Tensor:
        raise NotImplementedError
|