1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071 |
- import os
- from importlib.util import find_spec
- from typing import Any, Dict, List, Optional
- from torch.nn import Module
- from aphrodite.quantization.base_config import QuantizationConfig
# Values accepted for the NEURON_QUANT_DTYPE environment variable; any other
# value makes NeuronQuantConfig.__init__ raise ValueError.
# NOTE(review): presumably "s8" is signed int8 and "f8e4m3fn" is the FP8
# e4m3fn format -- confirm against the transformers_neuronx documentation.
SUPPORTED_QUANT_DTYPE_LIST = [
    "s8",
    "f8e4m3fn",
]
class NeuronQuantConfig(QuantizationConfig):
    """Quantization config for the Neuron backend.

    The quantized dtype is not a constructor argument: it is read from the
    ``NEURON_QUANT_DTYPE`` environment variable (default ``"s8"``) and must be
    one of ``SUPPORTED_QUANT_DTYPE_LIST``.
    """

    def __init__(
        self,
        dequant_dtype: str = "f16",
        quantize_method: str = "vector_dynamic",
    ) -> None:
        """Store the quantization settings and validate the quantized dtype.

        Args:
            dequant_dtype: Dequantization dtype name, forwarded unchanged to
                the transformers_neuronx config.
            quantize_method: Quantization method name, forwarded unchanged to
                the transformers_neuronx config.

        Raises:
            ValueError: If ``NEURON_QUANT_DTYPE`` is set to a value that is
                not in ``SUPPORTED_QUANT_DTYPE_LIST``.
        """
        self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8")
        if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST:
            # Adjacent string literals are concatenated verbatim, so each
            # fragment must carry its own separating space.
            raise ValueError(
                f"Neuron quantization datatype {self.quant_dtype} is not "
                "valid, the quantization datatype should match one of the "
                f"below types {SUPPORTED_QUANT_DTYPE_LIST}"
            )
        self.dequant_dtype = dequant_dtype
        self.quantize_method = quantize_method

    def get_name(self) -> str:
        """Return the identifier of this quantization scheme."""
        return "neuron_quant"

    def get_supported_act_dtypes(self) -> List[str]:
        """Return the supported quantization dtype names.

        A fresh list is returned so that callers mutating the result cannot
        alter the module-level list used for validation in ``__init__``.
        """
        return list(SUPPORTED_QUANT_DTYPE_LIST)

    @classmethod
    def get_min_capability(cls) -> int:
        """Always raise: this query is not meaningful for the Neuron backend.

        Raises:
            NotImplementedError: Unconditionally.
        """
        raise NotImplementedError(
            "This function should not be called with Neuron Backend"
        )

    @staticmethod
    def get_config_filenames() -> List[str]:
        """Return an empty list: no config filenames are declared here."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "NeuronQuantConfig":
        """Build a config from a dict with ``quantize_method`` and
        ``dequant_dtype`` entries.

        Both values are looked up through ``cls.get_from_keys``, which is
        defined outside this class body.
        NOTE(review): presumably that helper raises when a key is missing, so
        the constructor defaults are never used on this path -- confirm.
        """
        quantize_method = cls.get_from_keys(config, ["quantize_method"])
        dequant_dtype = cls.get_from_keys(config, ["dequant_dtype"])
        return cls(dequant_dtype=dequant_dtype, quantize_method=quantize_method)

    def get_quant_method(self, layer: Module, prefix: str) -> Optional[Any]:
        """Return the transformers_neuronx quantization config.

        ``layer`` and ``prefix`` are not used: the same config object is
        produced regardless of which layer is asked about.

        Raises:
            NotImplementedError: If ``transformers_neuronx`` cannot be located
                by the import system.
        """
        # find_spec only locates the package; the actual import happens in
        # get_quantization_config().
        if find_spec("transformers_neuronx") is None:
            raise NotImplementedError(
                "Neuron Quantization is only supported through"
                " transformers_neuronx."
            )
        return self.get_quantization_config()

    def get_scaled_act_names(self) -> List[str]:
        """Return an empty list: no scaled activation names are declared."""
        return []

    def get_quantization_config(self):
        """Build a ``transformers_neuronx.config.QuantizationConfig`` from the
        dtype and method stored on this instance.
        """
        # Imported lazily so this module can be loaded without
        # transformers_neuronx installed; aliased so it does not shadow the
        # base-class name QuantizationConfig inside this method.
        from transformers_neuronx.config import (
            QuantizationConfig as NeuronxQuantizationConfig,
        )

        return NeuronxQuantizationConfig(
            quant_dtype=self.quant_dtype,
            dequant_dtype=self.dequant_dtype,
            quantize_method=self.quantize_method,
        )
|