@@ -420,6 +420,16 @@ class ShardedStateLoader(BaseModelLoader):
             result[k] = t
         return result
 
+    def _prepare_weights(self, model_name_or_path: str,
+                         revision: Optional[str]):
+        if os.path.isdir(model_name_or_path):
+            return model_name_or_path
+        else:
+            allow_patterns = ["*.safetensors"]
+            return download_weights_from_hf(model_name_or_path,
+                                            self.load_config.download_dir,
+                                            allow_patterns, revision)
+
     def load_model(self, *, model_config: ModelConfig,
                    device_config: DeviceConfig,
                    lora_config: Optional[LoRAConfig],
@@ -430,6 +440,8 @@ class ShardedStateLoader(BaseModelLoader):
         from safetensors.torch import safe_open
 
         from aphrodite.distributed import get_tensor_model_parallel_rank
 
+        local_model_path = self._prepare_weights(model_config.model,
+                                                 model_config.revision)
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
                 model = _initialize_model(model_config, self.load_config,
@@ -437,7 +449,7 @@ class ShardedStateLoader(BaseModelLoader):
                                           cache_config)
         rank = get_tensor_model_parallel_rank()
         pattern = os.path.join(
-            model_config.model,
+            local_model_path,
             self.pattern.format(rank=rank, part="*"),
         )
         filepaths = glob.glob(pattern)
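
For reference, the dispatch that `_prepare_weights` introduces can be sketched standalone. This sketch assumes `download_weights_from_hf` ultimately delegates to `huggingface_hub.snapshot_download` (the case in the upstream helper, but treat the exact wrapper as an assumption); `prepare_weights` and its `download_dir` parameter here are illustrative names, not the library's API:

```python
import os
from typing import Optional

from huggingface_hub import snapshot_download


def prepare_weights(model_name_or_path: str,
                    revision: Optional[str] = None,
                    download_dir: Optional[str] = None) -> str:
    """Return a local directory containing *.safetensors files.

    Standalone sketch of ShardedStateLoader._prepare_weights: a local
    directory is used as-is; anything else is treated as a Hugging Face
    Hub repo id, and only the safetensors shards are downloaded.
    """
    if os.path.isdir(model_name_or_path):
        # Already on disk: nothing to download.
        return model_name_or_path
    # Hub repo id: fetch only *.safetensors, skipping *.bin and other files.
    return snapshot_download(model_name_or_path,
                             revision=revision,
                             cache_dir=download_dir,
                             allow_patterns=["*.safetensors"])
```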
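
The switch from `model_config.model` to `local_model_path` matters because the glob must run against a directory that exists on disk, which a Hub repo id is not. A minimal sketch of how each tensor-parallel rank then resolves its own shard files; the filename template matches the upstream default, but treat the exact string and the paths as assumptions:

```python
import glob
import os

# Default filename template used by ShardedStateLoader upstream;
# the exact string is an assumption for this illustration.
SHARD_PATTERN = "model-rank-{rank}-part-{part}.safetensors"

local_model_path = "/path/to/sharded/checkpoint"  # hypothetical directory
rank = 0  # tensor-parallel rank of the current worker

# Each rank globs only its own shards; part="*" matches every part file
# that rank wrote, e.g. model-rank-0-part-0.safetensors.
pattern = os.path.join(local_model_path,
                       SHARD_PATTERN.format(rank=rank, part="*"))
filepaths = glob.glob(pattern)
```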