# eetq.py
  1. from typing import Any, Dict, List, Optional
  2. from contextlib import suppress
  3. import torch
  4. from torch.nn.parameter import Parameter
  5. from aphrodite.modeling.layers.linear import LinearMethodBase, set_weight_attrs
  6. from aphrodite.quantization.base_config import \
  7. QuantizationConfig
  8. HAS_EETQ = False
  9. with suppress(ImportError):
  10. from eetq import w8_a16_gemm
  11. HAS_EETQ = True
  12. class EETQConfig(QuantizationConfig):
  13. """Config class for eetq.
  14. https://github.com/NetEase-FuXi/EETQ/tree/main
  15. """
  16. def __init__(
  17. self,
  18. weight_bits: int,
  19. zero_point: bool,
  20. ) -> None:
  21. self.weight_bits = weight_bits
  22. self.zero_point = zero_point
  23. if self.weight_bits != 8:
  24. raise ValueError(
  25. "Currently, only 8-bit weight quantization is supported for "
  26. f"EETQ, but got {self.weight_bits} bits.")
  27. def __repr__(self) -> str:
  28. return (f"EETQConfig(weight_bits={self.weight_bits}, "
  29. f"zero_point={self.zero_point})")
  30. def get_name(self) -> str:
  31. return "eetq"
  32. def get_supported_act_dtypes(self) -> List[torch.dtype]:
  33. return [torch.half]
  34. def get_min_capability(self) -> int:
  35. # The EETQ kernel only supports Turing or newer GPUs.
  36. return 70
  37. @staticmethod
  38. def get_config_filenames() -> List[str]:
  39. return [
  40. "quant_config.json",
  41. "quantize_config.json",
  42. ]
  43. @classmethod
  44. def from_config(cls, config: Dict[str, Any]) -> "EETQConfig":
  45. weight_bits = cls.get_from_keys(config, ["bits"])
  46. zero_point = cls.get_from_keys(config, ["zero_point"])
  47. return cls(weight_bits, zero_point)
  48. def get_linear_method(self) -> "EETQLinearMethod":
  49. return EETQLinearMethod(self)
  50. def get_scaled_act_names(self) -> List[str]:
  51. return []
  52. def merge_weight(self) -> bool:
  53. return True
  54. def quant_vocab(self) -> List[bool]:
  55. return [False, False]
  56. def support_fused_moe(self) -> bool:
  57. return False
  58. def rope_style(self) -> Optional[bool]:
  59. return None
  60. class EETQLinearMethod(LinearMethodBase):
  61. """Linear method for EETQ.
  62. Args:
  63. quant_config: The EETQ quantization config.
  64. """
  65. def __init__(self, quant_config: EETQConfig):
  66. self.quant_config = quant_config
  67. def create_weights(self, input_size_per_partition: int,
  68. output_partition_sizes: List[int], input_size: int,
  69. output_size: int,
  70. params_dtype: torch.dtype) -> Dict[str, Any]:
  71. output_size_per_partition = sum(output_partition_sizes)
  72. qweight = Parameter(torch.empty(input_size_per_partition,
  73. output_size_per_partition,
  74. dtype=torch.int8),
  75. requires_grad=False)
  76. weight_scales = Parameter(torch.empty(output_size_per_partition,
  77. dtype=torch.float16),
  78. requires_grad=False)
  79. set_weight_attrs(qweight, {
  80. "input_dim": 0,
  81. "output_dim": 1,
  82. })
  83. set_weight_attrs(weight_scales, {"input_dim": 0, "output_dim": 0})
  84. return {"qweight": qweight, "weight_scales": weight_scales}
  85. def apply_weights(self,
  86. weights: Dict[str, Any],
  87. x: torch.Tensor,
  88. bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  89. qweight = weights["qweight"].data
  90. weight_scales = weights["weight_scales"].data
  91. if HAS_EETQ:
  92. output = w8_a16_gemm(x, qweight, weight_scales)
  93. else:
  94. raise ImportError("You have not installed EETQ. Please refer to "
  95. "https://github.com/NetEase-FuXi/EETQ")
  96. return output
  97. def apply_moe_weights(self, w1: Dict[str,
  98. torch.Tensor], w2: Dict[str,
  99. torch.Tensor],
  100. x: torch.Tensor, gating_output: torch.Tensor,
  101. topk: int, renormalize: bool) -> torch.Tensor:
  102. raise NotImplementedError