utils.py

from typing import List, Optional, Set, Tuple, Type

from torch import nn
from transformers import PretrainedConfig

from aphrodite.common.config import LoRAConfig
from aphrodite.lora.fully_sharded_layers import (
    ColumnParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA)
# being imported for _all_lora_classes below
# yapf conflicts with isort for this block
# yapf: disable
from aphrodite.lora.layers import (BaseLayerWithLoRA,
                                   ColumnParallelLinearWithLoRA,
                                   LinearScalingRotaryEmbeddingWithLora,
                                   LogitsProcessorWithLoRA,
                                   MergedColumnParallelLinearWithLoRA,
                                   MergedQKVParallelLinearWithLora,
                                   QKVParallelLinearWithLora,
                                   RowParallelLinearWithLoRA,
                                   VocabParallelEmbeddingWithLoRA)
# yapf: enable
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead

_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
    VocabParallelEmbeddingWithLoRA,
    ColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    QKVParallelLinearWithLora,
    MergedQKVParallelLinearWithLora,
    RowParallelLinearWithLoRA,
    LogitsProcessorWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithShardedLora,
    RowParallelLinearWithShardedLoRA,
    LinearScalingRotaryEmbeddingWithLora,
}
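

# Each class in `_all_lora_classes` implements `can_replace_layer`;
# `from_layer` below wraps a base layer with whichever LoRA class accepts it,
# or returns the layer unchanged when none match.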
def from_layer(layer: nn.Module,
               max_loras: int,
               lora_config: LoRAConfig,
               packed_modules_list: List,
               model_config: Optional[PretrainedConfig] = None) -> nn.Module:
    for lora_cls in _all_lora_classes:
        # specifying kwargs so they can be easily accessed in decorator
        if lora_cls.can_replace_layer(source_layer=layer,
                                      lora_config=lora_config,
                                      packed_modules_list=packed_modules_list,
                                      model_config=model_config):
            ret = lora_cls(layer)
            ret.create_lora_weights(max_loras, lora_config, model_config)
            return ret
    return layer
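
# Illustrative call (hypothetical names and values), wrapping a packed QKV
# projection from a loaded model:
#
#     lora_layer = from_layer(model.layers[0].self_attn.qkv_proj,
#                             max_loras=8,
#                             lora_config=lora_config,
#                             packed_modules_list=["q_proj", "k_proj", "v_proj"],
#                             model_config=model_config)
#
# If no registered class can replace the layer, it is returned unchanged.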


def from_layer_logits_processor(
    layer: LogitsProcessor,
    lm_head: ParallelLMHead,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA:
    ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
                                  lm_head.weight.dtype, lm_head.weight.device,
                                  lm_head.get_sharded_to_full_mapping())
    ret.create_lora_weights(max_loras, lora_config, model_config)
    return ret
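
# Illustrative only (hypothetical variable names): the logits processor is
# wrapped together with its `ParallelLMHead`, e.g.
#
#     logits_processor = from_layer_logits_processor(
#         logits_processor, lm_head, max_loras=8, lora_config=lora_config)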


def replace_submodule(model: nn.Module, module_name: str,
                      new_module: nn.Module) -> nn.Module:
    """Replace a submodule in a model with a new module."""
    parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
    target_name = module_name.split(".")[-1]
    setattr(parent, target_name, new_module)
    return new_module
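
# Illustrative only (hypothetical module path): swapping a LoRA-wrapped layer
# back into the model,
#
#     old = model.get_submodule("model.layers.0.mlp.down_proj")
#     new = from_layer(old, max_loras, lora_config, [], model_config)
#     replace_submodule(model, "model.layers.0.mlp.down_proj", new)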


def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
    """Parse the name of a fine-tuned LoRA weight.

    Args:
        name: the name of the LoRA weight, e.g.
            base_model.model.dense1.weight

    Returns:
        Tuple(module_name, is_lora_a):
            module_name: the name of the module, e.g. model.dense1
            is_lora_a: whether the tensor is lora_A (True) or lora_B (False).
    """
    parts = name.split(".")
    if len(parts) >= 2 and parts[0] == "base_model" and parts[1] == "model":
        if parts[-1] == "weight":
            if parts[-2] == "lora_A" or parts[-2] == "lora_B":
                return ".".join(parts[2:-2]), parts[-2] == "lora_A"
        elif parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
            return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"
    raise ValueError(f"{name} is an unsupported LoRA weight")
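
# Illustrative parses (hypothetical weight names following the PEFT-style
# "base_model.model.<module path>.<lora tensor>" convention):
#
#     parse_fine_tuned_lora_name(
#         "base_model.model.layers.0.self_attn.q_proj.lora_A.weight")
#     # -> ("layers.0.self_attn.q_proj", True)
#
#     parse_fine_tuned_lora_name("base_model.model.embed_tokens.lora_embedding_B")
#     # -> ("embed_tokens", False)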