utils.py

from typing import List, Optional, Set, Tuple, Type

from torch import nn
from transformers import PretrainedConfig

from aphrodite.common.config import LoRAConfig
from aphrodite.lora.fully_sharded_layers import (
    ColumnParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora,
    RowParallelLinearWithShardedLoRA)
# being imported for _all_lora_classes below
# yapf conflicts with isort for this block
# yapf: disable
from aphrodite.lora.layers import (BaseLayerWithLoRA,
                                   ColumnParallelLinearWithLoRA,
                                   LinearScalingRotaryEmbeddingWithLora,
                                   LogitsProcessorWithLoRA,
                                   MergedColumnParallelLinearWithLoRA,
                                   MergedQKVParallelLinearWithLora,
                                   QKVParallelLinearWithLora,
                                   RowParallelLinearWithLoRA,
                                   VocabParallelEmbeddingWithLoRA)
# yapf: enable
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead

_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
    VocabParallelEmbeddingWithLoRA,
    ColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    QKVParallelLinearWithLora,
    MergedQKVParallelLinearWithLora,
    RowParallelLinearWithLoRA,
    LogitsProcessorWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
    QKVParallelLinearWithShardedLora,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithShardedLora,
    RowParallelLinearWithShardedLoRA,
    LinearScalingRotaryEmbeddingWithLora,
}


def from_layer(layer: nn.Module,
               max_loras: int,
               lora_config: LoRAConfig,
               packed_modules_list: List,
               model_config: Optional[PretrainedConfig] = None) -> nn.Module:
    for lora_cls in _all_lora_classes:
        # specifying kwargs so they can be easily accessed in decorator
        if lora_cls.can_replace_layer(source_layer=layer,
                                      lora_config=lora_config,
                                      packed_modules_list=packed_modules_list,
                                      model_config=model_config):
            ret = lora_cls(layer)
            ret.create_lora_weights(max_loras, lora_config, model_config)
            return ret
    return layer
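

# Illustrative sketch (not part of the original module): how from_layer might
# be applied when building a LoRA-enabled model. `model`, `lora_config` and
# `model_config` are assumed to come from the surrounding Aphrodite runtime,
# and the packed-modules handling is simplified; replace_submodule is defined
# later in this file.
#
#     for module_name, module in model.named_modules():
#         wrapped = from_layer(module,
#                              max_loras=8,
#                              lora_config=lora_config,
#                              packed_modules_list=[],
#                              model_config=model_config)
#         if wrapped is not module:
#             replace_submodule(model, module_name, wrapped)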


def from_layer_logits_processor(
    layer: LogitsProcessor,
    lm_head: ParallelLMHead,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA:
    ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
                                  lm_head.weight.dtype, lm_head.weight.device,
                                  lm_head.get_sharded_to_full_mapping())
    ret.create_lora_weights(max_loras, lora_config, model_config)
    return ret
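

# Illustrative sketch (not part of the original module): wrapping a model's
# logits processor so LoRA-extended vocabulary rows can be scored too. The
# `logits_processor` and `lm_head` attributes on `model` are assumptions about
# the surrounding model definition.
#
#     model.logits_processor = from_layer_logits_processor(
#         model.logits_processor,
#         model.lm_head,
#         max_loras=8,
#         lora_config=lora_config,
#         model_config=model_config)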


def replace_submodule(model: nn.Module, module_name: str,
                      new_module: nn.Module) -> nn.Module:
    """Replace a submodule in a model with a new module."""
    parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
    target_name = module_name.split(".")[-1]
    setattr(parent, target_name, new_module)
    return new_module
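

# Example (illustrative, plain torch): swap the last Linear of a small
# Sequential for one with a different output size. Submodules of a Sequential
# are addressed by their index as a string.
#
#     model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
#     replace_submodule(model, "2", nn.Linear(8, 4))
#     assert isinstance(model[2], nn.Linear) and model[2].out_features == 4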


def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
    """Parse the name of a fine-tuned LoRA weight.

    Args:
        name: the name of the fine-tuned LoRA weight, e.g.
            base_model.model.dense1.weight

    Returns:
        Tuple(module_name, is_lora_a):
            module_name: the name of the module, e.g. model.dense1
            is_lora_a: whether the tensor is lora_A or lora_B.
    """
    parts = name.split(".")
    if len(parts) >= 2 and parts[0] == "base_model" and parts[1] == "model":
        if parts[-1] == "weight":
            if parts[-2] == "lora_A" or parts[-2] == "lora_B":
                return ".".join(parts[2:-2]), parts[-2] == "lora_A"
        elif parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
            return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"
    raise ValueError(f"{name} is an unsupported LoRA weight")
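

# Examples (illustrative): parsing typical PEFT-style checkpoint keys. The
# layer path used here is hypothetical.
#
#     >>> parse_fine_tuned_lora_name(
#     ...     "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight")
#     ('model.layers.0.self_attn.q_proj', True)
#     >>> parse_fine_tuned_lora_name(
#     ...     "base_model.model.model.embed_tokens.lora_embedding_B")
#     ('model.embed_tokens', False)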