utils.py

import os
from typing import List, Optional, Set, Tuple, Type

import huggingface_hub
from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
                                   HFValidationError, RepositoryNotFoundError)
from loguru import logger
from torch import nn
from transformers import PretrainedConfig

from aphrodite.common.config import LoRAConfig
from aphrodite.lora.fully_sharded_layers import (
    ColumnParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora,
    RowParallelLinearWithShardedLoRA)
# being imported for _all_lora_classes below
# yapf conflicts with isort for this block
# yapf: disable
from aphrodite.lora.layers import (BaseLayerWithLoRA,
                                   ColumnParallelLinearWithLoRA,
                                   LinearScalingRotaryEmbeddingWithLora,
                                   LogitsProcessorWithLoRA,
                                   MergedColumnParallelLinearWithLoRA,
                                   MergedQKVParallelLinearWithLora,
                                   ModulesToSaveWrapper,
                                   QKVParallelLinearWithLora,
                                   ReplicatedLinearWithLoRA,
                                   RowParallelLinearWithLoRA,
                                   VocabParallelEmbeddingWithLoRA)
# yapf: enable
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead

_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
    VocabParallelEmbeddingWithLoRA,
    ColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    QKVParallelLinearWithLora,
    MergedQKVParallelLinearWithLora,
    RowParallelLinearWithLoRA,
    ReplicatedLinearWithLoRA,
    LogitsProcessorWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
    QKVParallelLinearWithShardedLora,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithShardedLora,
    RowParallelLinearWithShardedLoRA,
    LinearScalingRotaryEmbeddingWithLora,
    ModulesToSaveWrapper,
}


def from_layer(layer: nn.Module,
               max_loras: int,
               lora_config: LoRAConfig,
               packed_modules_list: List,
               model_config: Optional[PretrainedConfig] = None) -> nn.Module:
    """Wrap `layer` with the first registered LoRA class that can replace it.

    Returns the LoRA-wrapped layer with its LoRA weights created, or the
    original layer unchanged if no registered class can replace it.
    """
    for lora_cls in _all_lora_classes:
        # specifying kwargs so they can be easily accessed in decorator
        if lora_cls.can_replace_layer(source_layer=layer,
                                      lora_config=lora_config,
                                      packed_modules_list=packed_modules_list,
                                      model_config=model_config):
            ret = lora_cls(layer)
            ret.create_lora_weights(max_loras, lora_config, model_config)
            return ret
    return layer
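
# Illustrative usage (a sketch, not part of this module): a LoRA model manager
# would typically walk the model and swap supported layers in place, using
# replace_submodule (defined below). `max_loras`, `lora_config`, and
# `model_config` stand in for values taken from the engine configuration; an
# empty packed_modules_list is shown only for brevity.
#
#     for module_name, module in model.named_modules():
#         new_module = from_layer(module, max_loras, lora_config,
#                                 packed_modules_list=[],
#                                 model_config=model_config)
#         if new_module is not module:
#             replace_submodule(model, module_name, new_module)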


def from_layer_logits_processor(
    layer: LogitsProcessor,
    lm_head: ParallelLMHead,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA:
    """Wrap a LogitsProcessor with LoRA support, tied to the given LM head."""
    ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
                                  lm_head.weight.dtype, lm_head.weight.device,
                                  lm_head.get_sharded_to_full_mapping())
    ret.create_lora_weights(max_loras, lora_config, model_config)
    return ret
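
# Illustrative usage (a sketch, with placeholder attribute names): a model
# whose LoRA adapter also adapts the LM head would wrap its logits processor
# the same way, e.g.
#
#     model.logits_processor = from_layer_logits_processor(
#         model.logits_processor, model.lm_head, max_loras, lora_config,
#         model_config)
#
# where `model.logits_processor` and `model.lm_head` stand in for wherever the
# concrete model keeps these modules.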


def replace_submodule(model: nn.Module, module_name: str,
                      new_module: nn.Module) -> nn.Module:
    """Replace a submodule in a model with a new module."""
    parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
    target_name = module_name.split(".")[-1]
    setattr(parent, target_name, new_module)
    return new_module


def parse_fine_tuned_lora_name(
    name: str,
    enable_lora_modules_to_save: bool = False
) -> Tuple[str, Optional[bool]]:
    """Parse the name of a fine-tuned LoRA weight tensor.

    Args:
        name: the name of the fine-tuned LoRA tensor, e.g.
            base_model.model.dense1.weight

    Returns:
        Tuple (module_name, is_lora_a):
            module_name: the name of the module, e.g. model.dense1
            is_lora_a: True if the tensor is lora_A, False if it is lora_B,
                or None if the tensor belongs to a ModulesToSaveWrapper.
    """
    parts = name.split(".")
    if len(parts) >= 2 and parts[0] == "base_model" and parts[1] == "model":
        if parts[-1] == "weight":
            if parts[-2] == "lora_A" or parts[-2] == "lora_B":
                return ".".join(parts[2:-2]), parts[-2] == "lora_A"
            if parts[-2] in ModulesToSaveWrapper.implemented_layers:
                if not enable_lora_modules_to_save:
                    raise ValueError(
                        "enable_lora_modules_to_save is False, but found "
                        f"tensor name {name} in the LoRA checkpoint. Set "
                        "enable_lora_modules_to_save=True to process lm_head "
                        "and embed_tokens as fully trained tensors.")
                return ".".join(parts[2:-1]), None
        elif parts[-1] in ("lora_embedding_A", "lora_embedding_B"):
            return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"
    raise ValueError(f"{name} is an unsupported LoRA weight name")
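
# Worked examples (illustrative): a tensor named
# "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight" parses to
# ("model.layers.0.self_attn.q_proj", True), and the matching lora_B tensor
# yields the same module name with False. An embedding tensor such as
# "base_model.model.model.embed_tokens.lora_embedding_A" parses to
# ("model.embed_tokens", True). With enable_lora_modules_to_save=True, a
# fully trained tensor like "base_model.model.lm_head.weight" parses to
# ("lm_head", None), assuming lm_head is listed in
# ModulesToSaveWrapper.implemented_layers.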


def get_adapter_absolute_path(lora_path: str) -> str:
    """Resolve the given lora_path to an absolute local path.

    If the lora_path is identified as a Hugging Face model identifier,
    download the model and return the local snapshot path. Otherwise,
    treat the lora_path as a local file path and convert it to an
    absolute path.

    Parameters:
        lora_path (str): The path to the LoRA adapter, which can be an
            absolute path, a relative path, or a Hugging Face model
            identifier.

    Returns:
        str: The resolved absolute local path to the LoRA adapter.
    """
    # If the path is absolute, return it whether or not it exists.
    if os.path.isabs(lora_path):
        return lora_path

    # If the path starts with ~, expand the user home directory.
    if lora_path.startswith('~'):
        return os.path.expanduser(lora_path)

    # Check if the relative path exists locally.
    if os.path.exists(lora_path):
        return os.path.abspath(lora_path)

    # If the path does not exist locally, assume it's a Hugging Face repo.
    try:
        local_snapshot_path = huggingface_hub.snapshot_download(
            repo_id=lora_path)
    except (HfHubHTTPError, RepositoryNotFoundError, EntryNotFoundError,
            HFValidationError):
        # Errors during the download are not fatal: log them and return the
        # original path instead of raising.
        logger.exception("Error downloading the HuggingFace model")
        return lora_path

    return local_snapshot_path
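
# Worked examples (illustrative; the repo and directory names are made up):
# "/data/adapters/alpaca-lora" is returned unchanged (absolute path),
# "~/adapters/alpaca-lora" is expanded to the user's home directory, and
# "./alpaca-lora" is returned as an absolute path if it exists locally.
# A name like "some-org/alpaca-lora-7b" that does not exist locally is treated
# as a Hugging Face repo id and resolved to the downloaded snapshot directory;
# if the download fails, the original string is returned as-is.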