|
@@ -1,5 +1,10 @@
|
|
|
+import os
|
|
|
from typing import List, Optional, Set, Tuple, Type
|
|
|
|
|
|
+import huggingface_hub
|
|
|
+from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
|
|
|
+ HFValidationError, RepositoryNotFoundError)
|
|
|
+from loguru import logger
|
|
|
from torch import nn
|
|
|
from transformers import PretrainedConfig
|
|
|
|
|
@@ -103,3 +108,43 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
|
|
|
return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"
|
|
|
|
|
|
raise ValueError(f"{name} is unsupported LoRA weight")
|
|
|
+
|
|
|
+
|
|
|
+def get_adapter_absolute_path(lora_path: str) -> str:
|
|
|
+ """
|
|
|
+ Resolves the given lora_path to an absolute local path.
|
|
|
+ If the lora_path is identified as a Hugging Face model identifier,
|
|
|
+ it will download the model and return the local snapshot path.
|
|
|
+ Otherwise, it treats the lora_path as a local file path and
|
|
|
+ converts it to an absolute path.
|
|
|
+ Parameters:
|
|
|
+ lora_path (str): The path to the lora model, which can be an absolute path,
|
|
|
+ a relative path, or a Hugging Face model identifier.
|
|
|
+ Returns:
|
|
|
+ str: The resolved absolute local path to the lora model.
|
|
|
+ """
|
|
|
+
|
|
|
+ # Check if the path is an absolute path. Return it no matter exists or not.
|
|
|
+ if os.path.isabs(lora_path):
|
|
|
+ return lora_path
|
|
|
+
|
|
|
+ # If the path starts with ~, expand the user home directory.
|
|
|
+ if lora_path.startswith('~'):
|
|
|
+ return os.path.expanduser(lora_path)
|
|
|
+
|
|
|
+ # Check if the expanded relative path exists locally.
|
|
|
+ if os.path.exists(lora_path):
|
|
|
+ return os.path.abspath(lora_path)
|
|
|
+
|
|
|
+ # If the path does not exist locally, assume it's a Hugging Face repo.
|
|
|
+ try:
|
|
|
+ local_snapshot_path = huggingface_hub.snapshot_download(
|
|
|
+ repo_id=lora_path)
|
|
|
+ except (HfHubHTTPError, RepositoryNotFoundError, EntryNotFoundError,
|
|
|
+ HFValidationError):
|
|
|
+ # Handle errors that may occur during the download
|
|
|
+ # Return original path instead instead of throwing error here
|
|
|
+ logger.exception("Error downloading the HuggingFace model")
|
|
|
+ return lora_path
|
|
|
+
|
|
|
+ return local_snapshot_path
|