ソースを参照

Fix Mistral v0.3 weight loading

AlpinDale 7 ヶ月 前
コミット
7d0884de9a

+ 16 - 2
aphrodite/modeling/model_loader/loader.py

@@ -21,7 +21,8 @@ from aphrodite.modeling.model_loader.tensorizer import (
 from aphrodite.modeling.model_loader.utils import (get_model_architecture,
                                                    set_default_torch_dtype)
 from aphrodite.modeling.model_loader.weight_utils import (
-    download_weights_from_hf, filter_files_not_needed_for_inference,
+    download_safetensors_index_file_from_hf, download_weights_from_hf,
+    filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
     get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
     pt_weights_iterator, safetensors_weights_iterator)
 from aphrodite.modeling.models.vlm_base import VisionLanguageModelBase
@@ -186,7 +187,19 @@ class DefaultModelLoader(BaseModelLoader):
                     use_safetensors = True
                 break
 
-        if not use_safetensors:
+        if use_safetensors:
+            # For models like Mistral-7B-Instruct-v0.3
+            # there are both sharded safetensors files and a consolidated
+            # safetensors file. Using both breaks.
+            # Here, we download the `model.safetensors.index.json` and filter
+            # any files not found in the index.
+            if not is_local:
+                download_safetensors_index_file_from_hf(
+                    model_name_or_path, self.load_config.download_dir,
+                    revision)
+            hf_weights_files = filter_duplicate_safetensors_files(
+                hf_weights_files, hf_folder)
+        else:
             hf_weights_files = filter_files_not_needed_for_inference(
                 hf_weights_files)
 
@@ -490,6 +503,7 @@ class ShardedStateLoader(BaseModelLoader):
         max_size: Optional[int] = None,
     ) -> None:
         from safetensors.torch import save_file
+
         from aphrodite.distributed import get_tensor_model_parallel_rank
         if pattern is None:
             pattern = ShardedStateLoader.DEFAULT_PATTERN

+ 64 - 4
aphrodite/modeling/model_loader/weight_utils.py

@@ -12,14 +12,14 @@ import filelock
 import huggingface_hub.constants
 import numpy as np
 import torch
-from huggingface_hub import HfFileSystem, snapshot_download
+from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
+from loguru import logger
 from safetensors.torch import load_file, safe_open, save_file
 from tqdm.auto import tqdm
-from loguru import logger
+from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
 from aphrodite.common.config import LoadConfig, ModelConfig
-from aphrodite.quantization import (QuantizationConfig,
-                                    get_quantization_config)
+from aphrodite.quantization import QuantizationConfig, get_quantization_config
 from aphrodite.quantization.schema import QuantParamSchema
 
 # use system-level temp directory for file locks, so that multiple users
@@ -215,6 +215,66 @@ def download_weights_from_hf(
     return hf_folder
 
 
+def download_safetensors_index_file_from_hf(
+    model_name_or_path: str,
+    cache_dir: Optional[str],
+    revision: Optional[str] = None,
+) -> None:
+    """Download hf safetensors index file from Hugging Face Hub.
+    Args:
+        model_name_or_path (str): The model name or path.
+        cache_dir (Optional[str]): The cache directory to store the model
+            weights. If None, will use HF defaults.
+        revision (Optional[str]): The revision of the model.
+    """
+    # Use file lock to prevent multiple processes from
+    # downloading the same model weights at the same time.
+    with get_lock(model_name_or_path, cache_dir):
+        try:
+            # Download the safetensors index file.
+            hf_hub_download(
+                repo_id=model_name_or_path,
+                filename=SAFE_WEIGHTS_INDEX_NAME,
+                cache_dir=cache_dir,
+                revision=revision,
+                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
+            )
+        # If file not found on remote or locally, we should not fail since
+        # only some models will have SAFE_WEIGHTS_INDEX_NAME.
+        except huggingface_hub.utils.EntryNotFoundError:
+            logger.info(f"No {SAFE_WEIGHTS_INDEX_NAME} found in remote.")
+        except huggingface_hub.utils.LocalEntryNotFoundError:
+            logger.info(f"No {SAFE_WEIGHTS_INDEX_NAME} found in local cache.")
+
+
+# For models like Mistral-7B-v0.3, there are both sharded
+# safetensors files and a consolidated safetensors file.
+# Passing both of these to the weight loader functionality breaks.
+# So, we use the SAFE_WEIGHTS_INDEX_NAME to
+# look up which safetensors files should be used.
+def filter_duplicate_safetensors_files(hf_weights_files: List[str],
+                                       hf_folder: str) -> List[str]:
+    # model.safetensors.index.json is a mapping from keys in the
+    # torch state_dict to safetensors file holding that weight.
+    index_file_name = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME)
+    if not os.path.isfile(index_file_name):
+        return hf_weights_files
+
+    # Iterate through the weight_map (weight_name: safetensors files)
+    # to identify weights that we should use.
+    with open(index_file_name) as index_file:
+        weight_map = json.load(index_file)["weight_map"]
+    weight_files_in_index = set()
+    for weight_name in weight_map:
+        weight_files_in_index.add(
+            os.path.join(hf_folder, weight_map[weight_name]))
+    # Filter out any fields that are not found in the index file.
+    hf_weights_files = [
+        f for f in hf_weights_files if f in weight_files_in_index
+    ]
+    return hf_weights_files
+
+
 def filter_files_not_needed_for_inference(
         hf_weights_files: List[str]) -> List[str]:
     """