neuron_quant.py

import os
from importlib.util import find_spec
from typing import Any, Dict, List, Optional

from torch.nn import Module

from aphrodite.quantization.base_config import QuantizationConfig
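
# Quantized dtypes accepted by the Neuron backend: "s8" (signed int8) and
# "f8e4m3fn" (8-bit float, E4M3).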
SUPPORTED_QUANT_DTYPE_LIST = ["s8", "f8e4m3fn"]


class NeuronQuantConfig(QuantizationConfig):
    """Quantization config for the Neuron backend (int8 by default; fp8
    "f8e4m3fn" is also supported)."""

    def __init__(
        self,
        dequant_dtype: str = "f16",
        quantize_method: str = "vector_dynamic",
    ) -> None:
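        # The quantized dtype comes from the NEURON_QUANT_DTYPE environment
        # variable and defaults to "s8" (int8).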
        self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8")
        if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST:
            raise ValueError(
                f"Neuron quantization datatype {self.quant_dtype} is not "
                "valid; the quantization datatype should be one of "
                f"{SUPPORTED_QUANT_DTYPE_LIST}.")
        self.dequant_dtype = dequant_dtype
        self.quantize_method = quantize_method

    def get_name(self) -> str:
        return "neuron_quant"

    def get_supported_act_dtypes(self) -> List[str]:
        return SUPPORTED_QUANT_DTYPE_LIST
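
    # Compute capability is a GPU concept and is not checked on Neuron.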
    @classmethod
    def get_min_capability(cls) -> int:
        raise NotImplementedError(
            "This function should not be called with Neuron Backend")

    @staticmethod
    def get_config_filenames() -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "NeuronQuantConfig":
        quantize_method = cls.get_from_keys(config, ["quantize_method"])
        dequant_dtype = cls.get_from_keys(config, ["dequant_dtype"])
        return cls(dequant_dtype=dequant_dtype,
                   quantize_method=quantize_method)
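
    # Quantization on Neuron is delegated to transformers_neuronx; if that
    # package is not importable, quantization is unsupported.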
    def get_quant_method(self, layer: Module, prefix: str) -> Optional[Any]:
        if find_spec("transformers_neuronx") is not None:
            return self.get_quantization_config()
        else:
            raise NotImplementedError(
                "Neuron Quantization is only supported through"
                " transformers_neuronx.")

    def get_scaled_act_names(self) -> List[str]:
        return []
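
    # Translate this config into the QuantizationConfig expected by
    # transformers_neuronx; the deferred import keeps this module importable
    # when transformers_neuronx is absent.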
    def get_quantization_config(self):
        from transformers_neuronx.config import QuantizationConfig
        return QuantizationConfig(
            quant_dtype=self.quant_dtype,
            dequant_dtype=self.dequant_dtype,
            quantize_method=self.quantize_method,
        )
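

# Usage sketch (illustrative only): assumes aphrodite is installed and that a
# model's quantization config carries the "quantize_method" and
# "dequant_dtype" keys consumed by from_config above.
if __name__ == "__main__":
    example_cfg = {"quantize_method": "vector_dynamic", "dequant_dtype": "f16"}
    quant_config = NeuronQuantConfig.from_config(example_cfg)
    print(quant_config.get_name())                  # "neuron_quant"
    print(quant_config.get_supported_act_dtypes())  # ["s8", "f8e4m3fn"]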