瀏覽代碼

make ShardedStateLoader work with HF downloads

AlpinDale 7 月之前
父節點
當前提交
d9d3640165
共有 1 個文件被更改,包括 13 次插入和 1 次刪除
  1. 13 1
      aphrodite/modeling/model_loader/loader.py

+ 13 - 1
aphrodite/modeling/model_loader/loader.py

@@ -420,6 +420,16 @@ class ShardedStateLoader(BaseModelLoader):
                     result[k] = t
         return result
 
+    def _prepare_weights(self, model_name_or_path: str,
+                         revision: Optional[str]):
+        if os.path.isdir(model_name_or_path):
+            return model_name_or_path
+        else:
+            allow_patterns = ["*.safetensors"]
+            return download_weights_from_hf(model_name_or_path,
+                                            self.load_config.download_dir,
+                                            allow_patterns, revision)
+
     def load_model(self, *, model_config: ModelConfig,
                    device_config: DeviceConfig,
                    lora_config: Optional[LoRAConfig],
@@ -430,6 +440,8 @@ class ShardedStateLoader(BaseModelLoader):
         from safetensors.torch import safe_open
 
         from aphrodite.distributed import get_tensor_model_parallel_rank
+        local_model_path = self._prepare_weights(model_config.model,
+                                                 model_config.revision)
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
                 model = _initialize_model(model_config, self.load_config,
@@ -437,7 +449,7 @@ class ShardedStateLoader(BaseModelLoader):
                                           cache_config)
             rank = get_tensor_model_parallel_rank()
             pattern = os.path.join(
-                model_config.model,
+                local_model_path,
                 self.pattern.format(rank=rank, part="*"),
             )
             filepaths = glob.glob(pattern)