Prechádzať zdrojové kódy

automatically detect sparseml models

AlpinDale 7 mesiacov pred
rodič
commit
072aec1062
1 zmenil súbory, kde vykonal 13 pridanie a 2 odobranie
  1. 13 2
      aphrodite/common/config.py

+ 13 - 2
aphrodite/common/config.py

@@ -172,6 +172,17 @@ class ModelConfig:
         self.embedding_mode = any(
             ModelRegistry.is_embedding_model(arch) for arch in architectures)
 
+    def _parse_quant_hf_config(self):
+        quant_cfg = getattr(self.hf_config, "quantization_config", None)
+        if quant_cfg is None:
+            # SparseML uses a "compression_config" with a "quantization_config".
+            compression_cfg = getattr(self.hf_config, "compression_config",
+                                      None)
+            if compression_cfg is not None:
+                quant_cfg = compression_cfg.get("quantization_config", None)
+
+        return quant_cfg
+
     def _verify_quantization(self) -> None:
         supported_quantization = [*QUANTIZATION_METHODS]
         rocm_supported_quantization = ["gptq", "squeezellm"]
@@ -179,12 +190,12 @@ class ModelConfig:
             self.quantization = self.quantization.lower()
 
         # Parse quantization method from the HF model config, if available.
-        quant_cfg = getattr(self.hf_config, "quantization_config", None)
+        quant_cfg = self._parse_quant_hf_config()
         if quant_cfg is not None:
             quant_method = quant_cfg.get("quant_method", "").lower()
 
             # Detect which checkpoint is it
-            for name, method in QUANTIZATION_METHODS.items():
+            for _, method in QUANTIZATION_METHODS.items():
                 quantization_override = method.override_quantization_method(
                     quant_cfg, self.quantization)
                 if quantization_override: