|
@@ -172,6 +172,17 @@ class ModelConfig:
|
|
|
self.embedding_mode = any(
|
|
|
ModelRegistry.is_embedding_model(arch) for arch in architectures)
|
|
|
|
|
|
+ def _parse_quant_hf_config(self):
|
|
|
+ quant_cfg = getattr(self.hf_config, "quantization_config", None)
|
|
|
+ if quant_cfg is None:
|
|
|
+ # SparseML uses a "compression_config" with a "quantization_config".
|
|
|
+ compression_cfg = getattr(self.hf_config, "compression_config",
|
|
|
+ None)
|
|
|
+ if compression_cfg is not None:
|
|
|
+ quant_cfg = compression_cfg.get("quantization_config", None)
|
|
|
+
|
|
|
+ return quant_cfg
|
|
|
+
|
|
|
def _verify_quantization(self) -> None:
|
|
|
supported_quantization = [*QUANTIZATION_METHODS]
|
|
|
rocm_supported_quantization = ["gptq", "squeezellm"]
|
|
@@ -179,12 +190,12 @@ class ModelConfig:
|
|
|
self.quantization = self.quantization.lower()
|
|
|
|
|
|
# Parse quantization method from the HF model config, if available.
|
|
|
- quant_cfg = getattr(self.hf_config, "quantization_config", None)
|
|
|
+ quant_cfg = self._parse_quant_hf_config()
|
|
|
if quant_cfg is not None:
|
|
|
quant_method = quant_cfg.get("quant_method", "").lower()
|
|
|
|
|
|
# Detect which checkpoint is it
|
|
|
- for name, method in QUANTIZATION_METHODS.items():
|
|
|
+ for _, method in QUANTIZATION_METHODS.items():
|
|
|
quantization_override = method.override_quantization_method(
|
|
|
quant_cfg, self.quantization)
|
|
|
if quantization_override:
|