# utils.py
from typing import Tuple

from torch import nn
  3. def replace_submodule(model: nn.Module, module_name: str,
  4. new_module: nn.Module) -> nn.Module:
  5. """Replace a submodule in a model with a new module."""
  6. parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
  7. target_name = module_name.split(".")[-1]
  8. setattr(parent, target_name, new_module)
  9. return new_module
  10. def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
  11. """Parse the name of lora weights.
  12. args:
  13. name: the name of the fine-tuned LoRA, e.g.
  14. base_model.model.dense1.weight
  15. return:
  16. Tuple(module_name, is_lora_a):
  17. module_name: the name of the module, e.g. model.dense1,
  18. is_lora_a whether the tensor is lora_a or lora_b.
  19. """
  20. parts = name.split(".")
  21. assert parts[0] == "base_model"
  22. assert parts[1] == "model"
  23. if parts[-1] == "weight":
  24. if parts[-2] == "lora_A" or parts[-2] == "lora_B":
  25. return ".".join(parts[2:-2]), parts[-2] == "lora_A"
  26. else:
  27. # Handle the case where the tensor name is
  28. # "base_model.model.lm_head.weight"
  29. return ".".join(parts[2:]), True
  30. if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
  31. return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"
  32. raise ValueError(f"{name} is unsupported format")