Преглед на файлове

speedup lora loading times by resuing the cpu dummy lora

AlpinDale преди 8 месеца
родител
ревизия
b55381df0e
променени са 3 файла, в които са добавени 46 реда и са изтрити 18 реда
  1. 9 0
      aphrodite/lora/models.py
  2. 22 4
      aphrodite/lora/worker_manager.py
  3. 15 14
      aphrodite/task_handler/model_runner.py

+ 9 - 0
aphrodite/lora/models.py

@@ -118,6 +118,15 @@ class LoRAModel:
         self.rank = rank
         self.loras: Dict[str, LoRALayerWeights] = loras
 
+    def clone(self, lora_model_id: int) -> "LoRAModel":
+        """Return a copy of the object with different ids.
+        Will share the underlying tensors."""
+        return self.__class__(
+            lora_model_id,
+            rank=self.rank,
+            loras=self.loras.copy(),
+        )
+
     @property
     def extra_vocab_size(self) -> int:
         return max(lora.extra_vocab_size

+ 22 - 4
aphrodite/lora/worker_manager.py

@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod, abstractproperty
-from typing import Any, Dict, List, Set, Type
+from contextlib import contextmanager
+from typing import Any, Dict, List, Literal, Set, Type, Union
 
 import torch
 
@@ -23,6 +24,17 @@ class AbstractWorkerLoRAManager(ABC):
         self.device = device
         self.lora_config = lora_config
 
+        # If False, do not cache. If None, cache is empty.
+        self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
+
+    @contextmanager
+    def dummy_lora_cache(self):
+        """Use this context manager to reuse the dummy lora model
+        to avoid creating it repeatedly."""
+        self._cached_dummy_lora = None
+        yield
+        self._cached_dummy_lora = False
+
     @abstractproperty
     def is_enabled(self) -> bool:
         ...
@@ -172,9 +184,15 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
     def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
         if lora_request.lora_int_id in self.list_loras():
             return False
-        return self._lora_manager.add_lora(
-            self._lora_manager.create_dummy_lora(lora_request.lora_int_id,
-                                                 rank, self.embedding_modules))
+        if isinstance(self._cached_dummy_lora, LoRAModel):
+            dummy_lora = self._cached_dummy_lora.clone(
+                lora_request.lora_int_id)
+        else:
+            dummy_lora = self._lora_manager.create_dummy_lora(
+                lora_request.lora_int_id, rank, self.embedding_modules)
+            if self._cached_dummy_lora is None:
+                self._cached_dummy_lora = dummy_lora
+        return self._lora_manager.add_lora(dummy_lora)
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
         if lora_request.lora_int_id in self.list_loras():

+ 15 - 14
aphrodite/task_handler/model_runner.py

@@ -841,20 +841,21 @@ class ModelRunner:
         dummy_lora_requests = []
         dummy_lora_requests_per_seq = []
         if self.lora_config:
-            for idx in range(self.lora_config.max_loras):
-                lora_id = idx + 1
-                dummy_lora_request = LoRARequest(
-                    lora_name=f"warmup_{lora_id}",
-                    lora_int_id=lora_id,
-                    lora_local_path="/not/a/real/path",
-                )
-                self.lora_manager.add_dummy_lora(dummy_lora_request,
-                                                 rank=LORA_WARMUP_RANK)
-                dummy_lora_requests.append(dummy_lora_request)
-            dummy_lora_requests_per_seq = [
-                dummy_lora_requests[idx % len(dummy_lora_requests)]
-                for idx in range(max_num_seqs)
-            ]
+            with self.lora_manager.dummy_lora_cache():
+                for idx in range(self.lora_config.max_loras):
+                    lora_id = idx + 1
+                    dummy_lora_request = LoRARequest(
+                        lora_name=f"warmup_{lora_id}",
+                        lora_int_id=lora_id,
+                        lora_local_path="/not/a/real/path",
+                    )
+                    self.lora_manager.add_dummy_lora(dummy_lora_request,
+                                                     rank=LORA_WARMUP_RANK)
+                    dummy_lora_requests.append(dummy_lora_request)
+                dummy_lora_requests_per_seq = [
+                    dummy_lora_requests[idx % len(dummy_lora_requests)]
+                    for idx in range(max_num_seqs)
+                ]
 
         # Profile memory usage with max_num_sequences sequences and the total
         # number of tokens equal to max_num_batched_tokens.