
lora: add scaling factor support for LoRA at runtime

AlpinDale · 3 months ago · commit e49018808a
3 changed files with 23 additions and 3 deletions
  1. aphrodite/lora/lora.py (+3 -2)
  2. aphrodite/lora/models.py (+15 -1)
  3. aphrodite/lora/request.py (+5 -0)

+ 3 - 2
aphrodite/lora/lora.py

@@ -18,6 +18,7 @@ class LoRALayerWeights:
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor] = None,
         scaling: Optional[float] = None,
+        runtime_scaling: Optional[float] = 1.0,
     ) -> None:
         self.module_name = module_name
         self.rank = rank
@@ -27,9 +28,9 @@ class LoRALayerWeights:
         self.embeddings_tensor = embeddings_tensor
 
         if scaling is None:
-            self.scaling = self.lora_alpha / self.rank
+            self.scaling = (self.lora_alpha / self.rank) * runtime_scaling
         else:
-            self.scaling = scaling
+            self.scaling = scaling * runtime_scaling
 
     def optimize(self) -> "LoRALayerWeights":
         """Optimize the LoRA by merging the scaling into lora_b."""

+ 15 - 1
aphrodite/lora/models.py

@@ -60,6 +60,7 @@ class LoRAModel(AdapterModel):
         rank: int,
         loras: Dict[str, LoRALayerWeights],
         scaling_factor: Optional[float] = None,
+        runtime_scaling: Optional[float] = 1.0,
     ) -> None:
         """
         Args:
@@ -73,10 +74,23 @@ class LoRAModel(AdapterModel):
         # Scaling factor for long context lora model. None if it is not
         # fine tuned for the long context.
         self.scaling_factor = scaling_factor
+        self.runtime_scaling = runtime_scaling
         assert (lora_model_id >
                 0), f"a valid lora id should be greater than 0, got {self.id}"
         self.rank = rank
-        self.loras: Dict[str, LoRALayerWeights] = loras
+        self.loras = {
+            name: LoRALayerWeights(
+                module_name=weight.module_name,
+                rank=weight.rank,
+                lora_alpha=weight.lora_alpha,
+                lora_a=weight.lora_a,
+                lora_b=weight.lora_b,
+                embeddings_tensor=weight.embeddings_tensor,
+                scaling=weight.scaling,
+                runtime_scaling=runtime_scaling
+            )
+            for name, weight in loras.items()
+        }
 
     def clone(self, lora_model_id: int) -> "LoRAModel":
         """Return a copy of the object with different ids.

+ 5 - 0
aphrodite/lora/request.py

@@ -26,6 +26,7 @@ class LoRARequest(
     lora_name: str
     lora_int_id: int
     lora_path: str = ""
+    scaling_factor: Optional[float] = 1.0
     lora_local_path: Optional[str] = msgspec.field(default=None)
     long_lora_max_len: Optional[int] = None
     __hash__ = AdapterRequest.__hash__
@@ -44,6 +45,10 @@ class LoRARequest(
         # Ensure lora_path is not empty
         assert self.lora_path, "lora_path cannot be empty"
 
+        # Scaling factor must be non-negative
+        assert self.scaling_factor is None or self.scaling_factor >= 0, \
+            "scaling_factor must be non-negative"
+
     @property
     def adapter_id(self):
         return self.lora_int_id
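Note (not part of the commit): a hedged usage sketch of the new request field. The wiring from LoRARequest.scaling_factor to LoRAModel.runtime_scaling is not shown in this diff, so this only demonstrates constructing the request; the name and path are arbitrary examples.

    from aphrodite.lora.request import LoRARequest

    request = LoRARequest(
        lora_name="my-adapter",        # arbitrary example name
        lora_int_id=1,
        lora_path="/path/to/adapter",  # arbitrary example path
        scaling_factor=0.5,            # down-weights the adapter at runtime
    )
    assert request.adapter_id == 1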