|
@@ -92,13 +92,12 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
|
|
|
is_lora_a whether the tensor is lora_a or lora_b.
|
|
|
"""
|
|
|
parts = name.split(".")
|
|
|
- assert parts[0] == "base_model"
|
|
|
- assert parts[1] == "model"
|
|
|
- if parts[-1] == "weight":
|
|
|
- assert parts[-2] == "lora_A" or parts[-2] == "lora_B"
|
|
|
- return ".".join(parts[2:-2]), parts[-2] == "lora_A"
|
|
|
|
|
|
- if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
|
|
|
- return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"
|
|
|
+ if len(parts) >= 2 and parts[0] == "base_model" and parts[1] == "model":
|
|
|
+ if parts[-1] == "weight":
|
|
|
+ if parts[-2] == "lora_A" or parts[-2] == "lora_B":
|
|
|
+ return ".".join(parts[2:-2]), parts[-2] == "lora_A"
|
|
|
+ elif parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
|
|
|
+ return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"
|
|
|
|
|
|
- raise ValueError(f"{name} is unsupported format")
|
|
|
+ raise ValueError(f"{name} is unsupported LoRA weight")
|