@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod, abstractproperty
-from typing import Any, Dict, List, Set, Type
+from contextlib import contextmanager
+from typing import Any, Dict, List, Literal, Set, Type, Union
 
 import torch
 
@@ -23,6 +24,17 @@ class AbstractWorkerLoRAManager(ABC):
         self.device = device
         self.lora_config = lora_config
 
+        # If False, do not cache. If None, cache is empty.
+        self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
+
+    @contextmanager
+    def dummy_lora_cache(self):
+        """Use this context manager to reuse the dummy lora model
+        to avoid creating it repeatedly."""
+        self._cached_dummy_lora = None
+        yield
+        self._cached_dummy_lora = False
+
     @abstractproperty
     def is_enabled(self) -> bool:
         ...
@@ -172,9 +184,15 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
     def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
         if lora_request.lora_int_id in self.list_loras():
            return False
-        return self._lora_manager.add_lora(
-            self._lora_manager.create_dummy_lora(lora_request.lora_int_id,
-                                                 rank, self.embedding_modules))
+        if isinstance(self._cached_dummy_lora, LoRAModel):
+            dummy_lora = self._cached_dummy_lora.clone(
+                lora_request.lora_int_id)
+        else:
+            dummy_lora = self._lora_manager.create_dummy_lora(
+                lora_request.lora_int_id, rank, self.embedding_modules)
+        if self._cached_dummy_lora is None:
+            self._cached_dummy_lora = dummy_lora
+        return self._lora_manager.add_lora(dummy_lora)
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
         if lora_request.lora_int_id in self.list_loras():
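
Usage note (not part of the diff): a minimal sketch of how the new context manager could be driven when warming up several dummy LoRAs, assuming a WorkerLoRAManager instance named worker_lora_manager, a max_loras count, and a hypothetical LORA_WARMUP_RANK constant. Inside the with block the cache sentinel is None, so the first add_dummy_lora call builds and caches a dummy LoRAModel, and later calls clone it under a new id instead of re-creating the weights; on exit the sentinel reverts to False and caching is disabled again.

    from vllm.lora.request import LoRARequest

    # Sketch only: worker_lora_manager, max_loras, LORA_WARMUP_RANK and the
    # LoRARequest arguments below are assumptions, not taken from this diff.
    LORA_WARMUP_RANK = 8

    with worker_lora_manager.dummy_lora_cache():
        for lora_id in range(1, max_loras + 1):
            dummy_request = LoRARequest(
                lora_name=f"warmup_{lora_id}",
                lora_int_id=lora_id,
                lora_local_path="/not/a/real/path",
            )
            # First iteration creates and caches the dummy LoRAModel;
            # later iterations clone it with a new lora_int_id.
            worker_lora_manager.add_dummy_lora(dummy_request,
                                               rank=LORA_WARMUP_RANK)
    # Leaving the context resets _cached_dummy_lora to False (caching off).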