eetq.py
from typing import Any, Dict, List, Optional
from contextlib import suppress

import torch
from torch.nn.parameter import Parameter

from aphrodite.modeling.layers.linear import LinearMethodBase, set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig

HAS_EETQ = False
with suppress(ImportError):
    from eetq import w8_a16_gemm
    HAS_EETQ = True


class EETQConfig(QuantizationConfig):
    """Config class for EETQ.

    Reference: https://github.com/NetEase-FuXi/EETQ/tree/main
    """

    def __init__(
        self,
        weight_bits: int,
        zero_point: bool,
    ) -> None:
        self.weight_bits = weight_bits
        self.zero_point = zero_point
        if self.weight_bits != 8:
            raise ValueError(
                "Currently, only 8-bit weight quantization is supported for "
                f"EETQ, but got {self.weight_bits} bits.")

    def __repr__(self) -> str:
        return (f"EETQConfig(weight_bits={self.weight_bits}, "
                f"zero_point={self.zero_point})")

    def get_name(self) -> str:
        return "eetq"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half]

    def get_min_capability(self) -> int:
        # The EETQ kernel requires compute capability 7.0 (Volta) or newer.
        return 70

    @staticmethod
    def get_config_filenames() -> List[str]:
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "EETQConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        return cls(weight_bits, zero_point)
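
    # Example (hypothetical file contents): from_config only needs the
    # "bits" and "zero_point" keys, so a minimal quant_config.json such as
    #     {"bits": 8, "zero_point": false}
    # parses to EETQConfig(weight_bits=8, zero_point=False). The exact JSON
    # layout depends on the tool that quantized the checkpoint.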

    def get_linear_method(self) -> "EETQLinearMethod":
        return EETQLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []


class EETQLinearMethod(LinearMethodBase):
    """Linear method for EETQ.

    Args:
        quant_config: The EETQ quantization config.
    """

    def __init__(self, quant_config: EETQConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        output_size_per_partition = sum(output_partition_sizes)
        # int8 weight matrix of shape [input_size, output_size], plus one
        # fp16 dequantization scale per output channel.
        qweight = Parameter(torch.empty(input_size_per_partition,
                                        output_size_per_partition,
                                        dtype=torch.int8),
                            requires_grad=False)
        weight_scales = Parameter(torch.empty(output_size_per_partition,
                                              dtype=torch.float16),
                                  requires_grad=False)
        set_weight_attrs(qweight, {
            "input_dim": 0,
            "output_dim": 1,
        })
        set_weight_attrs(weight_scales, {"input_dim": 0, "output_dim": 0})
        layer.register_parameter("qweight", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("weight_scales", weight_scales)
        set_weight_attrs(weight_scales, extra_weight_attrs)
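
    # Usage sketch (shapes only; parameter values are illustrative):
    #
    #     method = EETQLinearMethod(EETQConfig(weight_bits=8,
    #                                          zero_point=False))
    #     layer = torch.nn.Module()
    #     method.create_weights(layer, input_size_per_partition=4096,
    #                           output_partition_sizes=[4096],
    #                           input_size=4096, output_size=4096,
    #                           params_dtype=torch.float16)
    #     # layer.qweight: int8 [4096, 4096]
    #     # layer.weight_scales: fp16 [4096]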

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        qweight = layer.qweight.data
        weight_scales = layer.weight_scales.data
        if HAS_EETQ:
            # NOTE: bias is accepted for interface compatibility but is not
            # applied here.
            output = w8_a16_gemm(x, qweight, weight_scales)
        else:
            raise ImportError("You have not installed EETQ. Please refer to "
                              "https://github.com/NetEase-FuXi/EETQ")
        return output
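
    # w8_a16_gemm multiplies fp16 activations x [..., input_size] against
    # the int8 weights using the per-output-channel fp16 scales, returning
    # fp16 [..., output_size]. Numerically it behaves like
    # x @ (qweight.half() * weight_scales), modulo any kernel-internal
    # weight layout EETQ applies when packing the checkpoint.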

    def apply_moe_weights(self, w1: Dict[str, torch.Tensor],
                          w2: Dict[str, torch.Tensor], x: torch.Tensor,
                          gating_output: torch.Tensor, topk: int,
                          renormalize: bool) -> torch.Tensor:
        raise NotImplementedError
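

# Pure-PyTorch reference for what the EETQ kernel computes, useful for unit
# tests on machines without EETQ installed. A sketch under the assumption
# that qweight is a plain [input_size, output_size] int8 matrix; the real
# kernel consumes weights preprocessed into its own layout, so this is
# illustrative of the math, not a drop-in replacement. The helper name is
# hypothetical.
def _w8_a16_gemm_reference(x: torch.Tensor, qweight: torch.Tensor,
                           weight_scales: torch.Tensor) -> torch.Tensor:
    # Dequantize per output channel, then do a standard fp16 matmul.
    weight = qweight.to(x.dtype) * weight_scales
    return x @ weight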