utils.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
import logging
from typing import Optional, Tuple

from torch import nn
  4. logger = logging.getLogger(__name__)
  5. def replace_submodule(model: nn.Module, module_name: str,
  6. new_module: nn.Module) -> nn.Module:
  7. """Replace a submodule in a model with a new module."""
  8. parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
  9. target_name = module_name.split(".")[-1]
  10. setattr(parent, target_name, new_module)
  11. return new_module
  12. def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
  13. """Parse the name of lora weights.
  14. args:
  15. name: the name of the fine-tuned LoRA, e.g.
  16. base_model.model.dense1.weight
  17. return:
  18. Tuple(module_name, is_lora_a):
  19. module_name: the name of the module, e.g. model.dense1,
  20. is_lora_a whether the tensor is lora_a or lora_b.
  21. """
  22. parts = name.split(".")
  23. assert parts[0] == "base_model"
  24. assert parts[1] == "model"
  25. if parts[-1] == "weight":
  26. if parts[-2] == "lora_A" or parts[-2] == "lora_B":
  27. return ".".join(parts[2:-2]), parts[-2] == "lora_A"
  28. else:
  29. return None
  30. if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
  31. return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"
  32. raise ValueError(f"{name} is unsupported format")