import os
from importlib.util import find_spec
from typing import Any, Dict, List, Optional

from torch.nn import Module

from aphrodite.quantization.base_config import QuantizationConfig

# Quantization datatypes accepted via the NEURON_QUANT_DTYPE environment
# variable; anything else is rejected in NeuronQuantConfig.__init__.
SUPPORTED_QUANT_DTYPE_LIST = ["s8", "f8e4m3fn"]


class NeuronQuantConfig(QuantizationConfig):
    """Quantization config for the Neuron backend.

    The quantization datatype is read from the ``NEURON_QUANT_DTYPE``
    environment variable (default ``"s8"``) and must be one of
    ``SUPPORTED_QUANT_DTYPE_LIST``. Quantization itself is delegated to
    ``transformers_neuronx``. This class only carries the settings and
    builds the matching ``transformers_neuronx`` config object.
    """

    def __init__(
        self,
        dequant_dtype: str = "f16",
        quantize_method: str = "vector_dynamic",
    ) -> None:
        """Create the config and validate the environment-selected dtype.

        Args:
            dequant_dtype: Value forwarded unchanged to the
                ``transformers_neuronx`` config as ``dequant_dtype``.
            quantize_method: Value forwarded unchanged to the
                ``transformers_neuronx`` config as ``quantize_method``.

        Raises:
            ValueError: If ``NEURON_QUANT_DTYPE`` is set to a datatype not in
                ``SUPPORTED_QUANT_DTYPE_LIST``.
        """
        # The quantization dtype is configured out-of-band through the
        # environment, not through constructor arguments.
        self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8")
        if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST:
            # Fragments are separated by explicit spaces so the rendered
            # message reads correctly.
            raise ValueError(
                f"Neuron quantization datatype {self.quant_dtype} is not "
                "valid, the quantization datatype should match one of the "
                f"below types {SUPPORTED_QUANT_DTYPE_LIST}"
            )
        self.dequant_dtype = dequant_dtype
        self.quantize_method = quantize_method

    def get_name(self) -> str:
        """Return the registry name of this quantization method."""
        return "neuron_quant"

    def get_supported_act_dtypes(self) -> List[str]:
        """Return the supported quantization datatype names.

        A fresh copy is returned so callers cannot mutate the module-level
        ``SUPPORTED_QUANT_DTYPE_LIST`` used for validation.
        """
        return list(SUPPORTED_QUANT_DTYPE_LIST)

    @classmethod
    def get_min_capability(cls) -> int:
        """Not applicable to Neuron; always raises.

        Raises:
            NotImplementedError: Unconditionally.
        """
        raise NotImplementedError(
            "This function should not be called with Neuron Backend"
        )

    @staticmethod
    def get_config_filenames() -> List[str]:
        """Return an empty list.

        NOTE(review): presumably this means no checkpoint-side config file
        is looked up for this method; confirm against the base class.
        """
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "NeuronQuantConfig":
        """Build a config from a dict.

        Reads the ``"quantize_method"`` and ``"dequant_dtype"`` entries via the
        base-class helper ``get_from_keys``. ``quant_dtype`` is still taken from
        the environment in ``__init__``.
        """
        quantize_method = cls.get_from_keys(config, ["quantize_method"])
        dequant_dtype = cls.get_from_keys(config, ["dequant_dtype"])
        return cls(dequant_dtype=dequant_dtype, quantize_method=quantize_method)

    def get_quant_method(self, layer: Module, prefix: str) -> Optional[Any]:
        """Return the ``transformers_neuronx`` quantization config.

        ``layer`` and ``prefix`` are accepted for interface compatibility but
        are not used. Every layer receives the same config.

        Raises:
            NotImplementedError: If ``transformers_neuronx`` is not installed.
        """
        # find_spec checks installability without importing the package.
        if find_spec("transformers_neuronx") is not None:
            return self.get_quantization_config()
        raise NotImplementedError(
            "Neuron Quantization is only supported through"
            " transformers_neuronx."
        )

    def get_scaled_act_names(self) -> List[str]:
        """Return an empty list (no scaled activation names are declared)."""
        return []

    def get_quantization_config(self):
        """Build a ``transformers_neuronx.config.QuantizationConfig``.

        The object is built from this instance's ``quant_dtype``,
        ``dequant_dtype`` and ``quantize_method``.
        """
        # Imported lazily (transformers_neuronx is optional) and aliased so
        # it does not shadow the Aphrodite base class of the same name.
        from transformers_neuronx.config import (
            QuantizationConfig as NeuronxQuantizationConfig,
        )

        return NeuronxQuantizationConfig(
            quant_dtype=self.quant_dtype,
            dequant_dtype=self.dequant_dtype,
            quantize_method=self.quantize_method,
        )