1
0

eetq.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. from contextlib import suppress
  2. from typing import Any, Dict, List, Optional
  3. import torch
  4. from torch.nn.parameter import Parameter
  5. from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
  6. from aphrodite.modeling.utils import set_weight_attrs
  7. from aphrodite.quantization.base_config import QuantizationConfig
  8. HAS_EETQ = False
  9. with suppress(ImportError):
  10. from eetq import w8_a16_gemm
  11. HAS_EETQ = True
  12. class EETQConfig(QuantizationConfig):
  13. """Config class for eetq.
  14. https://github.com/NetEase-FuXi/EETQ/tree/main
  15. """
  16. def __init__(
  17. self,
  18. weight_bits: int,
  19. zero_point: bool,
  20. ) -> None:
  21. self.weight_bits = weight_bits
  22. self.zero_point = zero_point
  23. if self.weight_bits != 8:
  24. raise ValueError(
  25. "Currently, only 8-bit weight quantization is supported for "
  26. f"EETQ, but got {self.weight_bits} bits.")
  27. def __repr__(self) -> str:
  28. return (f"EETQConfig(weight_bits={self.weight_bits}, "
  29. f"zero_point={self.zero_point})")
  30. def get_name(self) -> str:
  31. return "eetq"
  32. def get_supported_act_dtypes(self) -> List[torch.dtype]:
  33. return [torch.half]
  34. @classmethod
  35. def get_min_capability(cls) -> int:
  36. # The EETQ kernel only supports Turing or newer GPUs.
  37. return 70
  38. @staticmethod
  39. def get_config_filenames() -> List[str]:
  40. return [
  41. "quant_config.json",
  42. "quantize_config.json",
  43. ]
  44. @classmethod
  45. def from_config(cls, config: Dict[str, Any]) -> "EETQConfig":
  46. weight_bits = cls.get_from_keys(config, ["bits"])
  47. zero_point = cls.get_from_keys(config, ["zero_point"])
  48. return cls(weight_bits, zero_point)
  49. def get_quant_method(self, layer: torch.nn.Module,
  50. prefix: str) -> Optional["EETQLinearMethod"]:
  51. if isinstance(layer, LinearBase):
  52. return EETQLinearMethod(self)
  53. return None
  54. def get_scaled_act_names(self) -> List[str]:
  55. return []
  56. class EETQLinearMethod(LinearMethodBase):
  57. """Linear method for EETQ.
  58. Args:
  59. quant_config: The EETQ quantization config.
  60. """
  61. def __init__(self, quant_config: EETQConfig):
  62. self.quant_config = quant_config
  63. def create_weights(self, layer: torch.nn.Module,
  64. input_size_per_partition: int,
  65. output_partition_sizes: List[int], input_size: int,
  66. output_size: int, params_dtype: torch.dtype,
  67. **extra_weight_attrs):
  68. output_size_per_partition = sum(output_partition_sizes)
  69. qweight = Parameter(torch.empty(input_size_per_partition,
  70. output_size_per_partition,
  71. dtype=torch.int8),
  72. requires_grad=False)
  73. weight_scales = Parameter(torch.empty(output_size_per_partition,
  74. dtype=torch.float16),
  75. requires_grad=False)
  76. set_weight_attrs(qweight, {
  77. "input_dim": 0,
  78. "output_dim": 1,
  79. })
  80. set_weight_attrs(weight_scales, {"input_dim": 0, "output_dim": 0})
  81. layer.register_parameter("qweight", qweight)
  82. set_weight_attrs(qweight, extra_weight_attrs)
  83. layer.register_parameter("weight_scales", weight_scales)
  84. set_weight_attrs(weight_scales, extra_weight_attrs)
  85. def apply(self,
  86. layer: torch.nn.Module,
  87. x: torch.Tensor,
  88. bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  89. qweight = layer.qweight.data
  90. weight_scales = layer.weight_scales.data
  91. if HAS_EETQ:
  92. output = w8_a16_gemm(x, qweight, weight_scales)
  93. else:
  94. raise ImportError("You have not installed EETQ. Please refer to "
  95. "https://github.com/NetEase-FuXi/EETQ")
  96. return output
  97. def apply_moe_weights(self, w1: Dict[str,
  98. torch.Tensor], w2: Dict[str,
  99. torch.Tensor],
  100. x: torch.Tensor, gating_output: torch.Tensor,
  101. topk: int, renormalize: bool) -> torch.Tensor:
  102. raise NotImplementedError