1
0
Эх сурвалжийг харах

feat: support loading lora adapters directly from HF

AlpinDale 6 сар өмнө
parent
commit
f92b9fc820

+ 2 - 2
aphrodite/endpoints/openai/serving_engine.py

@@ -40,7 +40,7 @@ class PromptAdapterPath:
 @dataclass
 class LoRAModulePath:
     name: str
-    local_path: str
+    path: str
 
 
 AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
@@ -80,7 +80,7 @@ class OpenAIServing:
                 LoRARequest(
                     lora_name=lora.name,
                     lora_int_id=i,
-                    lora_local_path=lora.local_path,
+                    lora_path=lora.path,
                 ) for i, lora in enumerate(lora_modules, start=1)
             ]
 

+ 39 - 3
aphrodite/lora/request.py

@@ -1,4 +1,5 @@
-from dataclasses import dataclass
+import warnings
+from dataclasses import dataclass, field
 from typing import Optional
 
 from aphrodite.adapter_commons.request import AdapterRequest
@@ -20,10 +21,25 @@ class LoRARequest:
 
     lora_name: str
     lora_int_id: int
-    lora_local_path: str
+    lora_path: str = ""
+    lora_local_path: Optional[str] = field(default=None, repr=False)
     long_lora_max_len: Optional[int] = None
     __hash__ = AdapterRequest.__hash__
 
+    def __post_init__(self):
+        if 'lora_local_path' in self.__dict__:
+            warnings.warn(
+                "The 'lora_local_path' attribute is deprecated "
+                "and will be removed in a future version. "
+                "Please use 'lora_path' instead.",
+                DeprecationWarning,
+                stacklevel=2)
+            if not self.lora_path:
+                self.lora_path = self.lora_local_path or ""
+
+        # Ensure lora_path is not empty
+        assert self.lora_path, "lora_path cannot be empty"
+
     @property
     def adapter_id(self):
         return self.lora_int_id
@@ -32,6 +48,26 @@ class LoRARequest:
     def name(self):
         return self.lora_name
 
+    @property
+    def path(self):
+        return self.lora_path
+
     @property
     def local_path(self):
-        return self.lora_local_path
+        warnings.warn(
+            "The 'local_path' attribute is deprecated "
+            "and will be removed in a future version. "
+            "Please use 'path' instead.",
+            DeprecationWarning,
+            stacklevel=2)
+        return self.lora_path
+
+    @local_path.setter
+    def local_path(self, value):
+        warnings.warn(
+            "The 'local_path' attribute is deprecated "
+            "and will be removed in a future version. "
+            "Please use 'path' instead.",
+            DeprecationWarning,
+            stacklevel=2)
+        self.lora_path = value

+ 45 - 0
aphrodite/lora/utils.py

@@ -1,5 +1,10 @@
+import os
 from typing import List, Optional, Set, Tuple, Type
 
+import huggingface_hub
+from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
+                                   HFValidationError, RepositoryNotFoundError)
+from loguru import logger
 from torch import nn
 from transformers import PretrainedConfig
 
@@ -103,3 +108,43 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
             return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"
 
     raise ValueError(f"{name} is unsupported LoRA weight")
+
+
+def get_adapter_absolute_path(lora_path: str) -> str:
+    """
+    Resolves the given lora_path to an absolute local path.
+    If the lora_path is identified as a Hugging Face model identifier,
+    it will download the model and return the local snapshot path.
+    Otherwise, it treats the lora_path as a local file path and
+    converts it to an absolute path.
+    Parameters:
+    lora_path (str): The path to the lora model, which can be an absolute path,
+                     a relative path, or a Hugging Face model identifier.
+    Returns:
+    str: The resolved absolute local path to the lora model.
+    """
+
+    # Check if the path is an absolute path. Return it whether or not it exists.
+    if os.path.isabs(lora_path):
+        return lora_path
+
+    # If the path starts with ~, expand the user home directory.
+    if lora_path.startswith('~'):
+        return os.path.expanduser(lora_path)
+
+    # Check if the expanded relative path exists locally.
+    if os.path.exists(lora_path):
+        return os.path.abspath(lora_path)
+
+    # If the path does not exist locally, assume it's a Hugging Face repo.
+    try:
+        local_snapshot_path = huggingface_hub.snapshot_download(
+            repo_id=lora_path)
+    except (HfHubHTTPError, RepositoryNotFoundError, EntryNotFoundError,
+            HFValidationError):
+        # Handle errors that may occur during the download.
+        # Return the original path instead of raising an error here.
+        logger.exception("Error downloading the HuggingFace model")
+        return lora_path
+
+    return local_snapshot_path

+ 6 - 4
aphrodite/lora/worker_manager.py

@@ -13,6 +13,7 @@ from aphrodite.lora.models import (LoRAModel, LoRAModelManager,
                                    LRUCacheLoRAModelManager,
                                    create_lora_manager)
 from aphrodite.lora.request import LoRARequest
+from aphrodite.lora.utils import get_adapter_absolute_path
 
 
 class WorkerLoRAManager(AbstractWorkerManager):
@@ -87,8 +88,9 @@ class WorkerLoRAManager(AbstractWorkerManager):
                         packed_modules_mapping[module])
                 else:
                     expected_lora_modules.append(module)
+            lora_path = get_adapter_absolute_path(lora_request.lora_path)
             lora = self._lora_model_cls.from_local_checkpoint(
-                lora_request.lora_local_path,
+                lora_path,
                 expected_lora_modules,
                 max_position_embeddings=self.max_position_embeddings,
                 lora_model_id=lora_request.lora_int_id,
@@ -100,12 +102,12 @@ class WorkerLoRAManager(AbstractWorkerManager):
                 embedding_padding_modules=self.embedding_padding_modules,
             )
         except Exception as e:
-            raise RuntimeError(
-                f"Loading lora {lora_request.lora_local_path} failed") from e
+            raise RuntimeError(f"Loading lora {lora_path} failed") from e
         if lora.rank > self.lora_config.max_lora_rank:
             raise ValueError(
                 f"LoRA rank {lora.rank} is greater than max_lora_rank "
-                f"{self.lora_config.max_lora_rank}.")
+                f"{self.lora_config.max_lora_rank}. Please launch the "
+                "engine with a higher max_lora_rank (--max-lora-rank).")
         if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
             raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} "
                              f"is greater than lora_extra_vocab_size "

+ 4 - 6
aphrodite/transformers_utils/tokenizer.py

@@ -134,15 +134,13 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args,
     if lora_request is None:
         return None
     try:
-        tokenizer = get_tokenizer(lora_request.lora_local_path, *args,
-                                  **kwargs)
+        tokenizer = get_tokenizer(lora_request.lora_path, *args, **kwargs)
     except OSError as e:
         # No tokenizer was found in the LoRA folder,
         # use base model tokenizer
-        logger.warning(
-            f"No tokenizer found in {lora_request.lora_local_path}, "
-            "using base model tokenizer instead. "
-            f"(Exception: {str(e)})")
+        logger.warning(f"No tokenizer found in {lora_request.lora_path}, "
+                       "using base model tokenizer instead. "
+                       f"(Exception: {str(e)})")
         tokenizer = None
     return tokenizer