@@ -60,6 +60,7 @@ class LoRAModel(AdapterModel):
         rank: int,
         loras: Dict[str, LoRALayerWeights],
         scaling_factor: Optional[float] = None,
+        runtime_scaling: Optional[float] = 1.0,
     ) -> None:
         """
         Args:
@@ -73,10 +74,23 @@ class LoRAModel(AdapterModel):
         # Scaling factor for long context lora model. None if it is not
         # fine tuned for the long context.
         self.scaling_factor = scaling_factor
+        self.runtime_scaling = runtime_scaling
         assert (lora_model_id >
                 0), f"a valid lora id should be greater than 0, got {self.id}"
         self.rank = rank
-        self.loras: Dict[str, LoRALayerWeights] = loras
+        self.loras: Dict[str, LoRALayerWeights] = {
+            name: LoRALayerWeights(
+                module_name=weight.module_name,
+                rank=weight.rank,
+                lora_alpha=weight.lora_alpha,
+                lora_a=weight.lora_a,
+                lora_b=weight.lora_b,
+                embeddings_tensor=weight.embeddings_tensor,
+                scaling=weight.scaling,
+                runtime_scaling=runtime_scaling,
+            )
+            for name, weight in loras.items()
+        }
 
     def clone(self, lora_model_id: int) -> "LoRAModel":
         """Return a copy of the object with different ids.
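
To illustrate the intent of threading `runtime_scaling` through the constructor, here is a minimal, self-contained sketch. The `LoRALayerWeights` stand-in below is hypothetical: it only mirrors the constructor keywords used in this hunk, and the assumption that the effective scale works out to `scaling * runtime_scaling` at apply time is mine, not something this diff shows.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class LoRALayerWeights:
    """Hypothetical stand-in mirroring the constructor keywords in the diff."""
    module_name: str
    rank: int
    lora_alpha: int
    lora_a: torch.Tensor
    lora_b: torch.Tensor
    embeddings_tensor: Optional[torch.Tensor] = None
    scaling: Optional[float] = None
    runtime_scaling: float = 1.0

    def __post_init__(self) -> None:
        # Conventional LoRA scaling (alpha / rank) when none is supplied.
        if self.scaling is None:
            self.scaling = self.lora_alpha / self.rank

    def apply(self, x: torch.Tensor) -> torch.Tensor:
        # Assumption: the runtime multiplier composes with the base scaling;
        # only the LoRA delta is returned here, not the base model output.
        return (x @ self.lora_a @ self.lora_b) * self.scaling * self.runtime_scaling


if __name__ == "__main__":
    torch.manual_seed(0)
    base = LoRALayerWeights(
        module_name="layers.0.mlp",
        rank=8,
        lora_alpha=16,
        lora_a=torch.randn(32, 8),  # (input_dim, rank)
        lora_b=torch.randn(8, 32),  # (rank, output_dim)
    )
    x = torch.randn(1, 32)

    # Rebuild the same weights with a different runtime multiplier, as the
    # dict comprehension in the diff does for every adapter layer.
    halved = LoRALayerWeights(**{**vars(base), "runtime_scaling": 0.5})
    assert torch.allclose(halved.apply(x), base.apply(x) * 0.5)
```

Note the design choice visible in the diff itself: rather than mutating the incoming `loras` mapping, the constructor rebuilds each `LoRALayerWeights`, so the caller's loaded adapter weights stay untouched and the same adapter can be instantiated with different runtime multipliers.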