eetq.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. from contextlib import suppress
  2. from typing import Any, Dict, List, Optional
  3. import torch
  4. from torch.nn.parameter import Parameter
  5. from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
  6. from aphrodite.modeling.utils import set_weight_attrs
  7. from aphrodite.quantization.base_config import QuantizationConfig
  8. HAS_EETQ = False
  9. with suppress(ImportError):
  10. from eetq import w8_a16_gemm
  11. HAS_EETQ = True
  12. class EETQConfig(QuantizationConfig):
  13. """Config class for eetq.
  14. https://github.com/NetEase-FuXi/EETQ/tree/main
  15. """
  16. def __init__(
  17. self,
  18. weight_bits: int,
  19. zero_point: bool,
  20. ) -> None:
  21. self.weight_bits = weight_bits
  22. self.zero_point = zero_point
  23. if self.weight_bits != 8:
  24. raise ValueError(
  25. "Currently, only 8-bit weight quantization is supported for "
  26. f"EETQ, but got {self.weight_bits} bits.")
  27. def __repr__(self) -> str:
  28. return (f"EETQConfig(weight_bits={self.weight_bits}, "
  29. f"zero_point={self.zero_point})")
  30. def get_name(self) -> str:
  31. return "eetq"
  32. def get_supported_act_dtypes(self) -> List[torch.dtype]:
  33. return [torch.half]
  34. def get_min_capability(self) -> int:
  35. # The EETQ kernel only supports Turing or newer GPUs.
  36. return 70
  37. @staticmethod
  38. def get_config_filenames() -> List[str]:
  39. return [
  40. "quant_config.json",
  41. "quantize_config.json",
  42. ]
  43. @classmethod
  44. def from_config(cls, config: Dict[str, Any]) -> "EETQConfig":
  45. weight_bits = cls.get_from_keys(config, ["bits"])
  46. zero_point = cls.get_from_keys(config, ["zero_point"])
  47. return cls(weight_bits, zero_point)
  48. def get_quant_method(
  49. self, layer: torch.nn.Module) -> Optional["EETQLinearMethod"]:
  50. if isinstance(layer, LinearBase):
  51. return EETQLinearMethod(self)
  52. return None
  53. def get_scaled_act_names(self) -> List[str]:
  54. return []
  55. class EETQLinearMethod(LinearMethodBase):
  56. """Linear method for EETQ.
  57. Args:
  58. quant_config: The EETQ quantization config.
  59. """
  60. def __init__(self, quant_config: EETQConfig):
  61. self.quant_config = quant_config
  62. def create_weights(self, layer: torch.nn.Module,
  63. input_size_per_partition: int,
  64. output_partition_sizes: List[int], input_size: int,
  65. output_size: int, params_dtype: torch.dtype,
  66. **extra_weight_attrs):
  67. output_size_per_partition = sum(output_partition_sizes)
  68. qweight = Parameter(torch.empty(input_size_per_partition,
  69. output_size_per_partition,
  70. dtype=torch.int8),
  71. requires_grad=False)
  72. weight_scales = Parameter(torch.empty(output_size_per_partition,
  73. dtype=torch.float16),
  74. requires_grad=False)
  75. set_weight_attrs(qweight, {
  76. "input_dim": 0,
  77. "output_dim": 1,
  78. })
  79. set_weight_attrs(weight_scales, {"input_dim": 0, "output_dim": 0})
  80. layer.register_parameter("qweight", qweight)
  81. set_weight_attrs(qweight, extra_weight_attrs)
  82. layer.register_parameter("weight_scales", weight_scales)
  83. set_weight_attrs(weight_scales, extra_weight_attrs)
  84. def apply(self,
  85. layer: torch.nn.Module,
  86. x: torch.Tensor,
  87. bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  88. qweight = layer.qweight.data
  89. weight_scales = layer.weight_scales.data
  90. if HAS_EETQ:
  91. output = w8_a16_gemm(x, qweight, weight_scales)
  92. else:
  93. raise ImportError("You have not installed EETQ. Please refer to "
  94. "https://github.com/NetEase-FuXi/EETQ")
  95. return output
  96. def apply_moe_weights(self, w1: Dict[str,
  97. torch.Tensor], w2: Dict[str,
  98. torch.Tensor],
  99. x: torch.Tensor, gating_output: torch.Tensor,
  100. topk: int, renormalize: bool) -> torch.Tensor:
  101. raise NotImplementedError