123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168 |
- import os
- from typing import List, Optional, Set, Tuple, Type
- import huggingface_hub
- from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
- HFValidationError, RepositoryNotFoundError)
- from loguru import logger
- from torch import nn
- from transformers import PretrainedConfig
- from aphrodite.common.config import LoRAConfig
- from aphrodite.lora.fully_sharded_layers import (
- ColumnParallelLinearWithShardedLoRA,
- MergedColumnParallelLinearWithShardedLoRA,
- MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora,
- RowParallelLinearWithShardedLoRA)
- # being imported for _all_lora_classes below
- # yapf conflicts with isort for this block
- # yapf: disable
- from aphrodite.lora.layers import (BaseLayerWithLoRA,
- ColumnParallelLinearWithLoRA,
- LinearScalingRotaryEmbeddingWithLora,
- LogitsProcessorWithLoRA,
- MergedColumnParallelLinearWithLoRA,
- MergedQKVParallelLinearWithLora,
- ModulesToSaveWrapper,
- QKVParallelLinearWithLora,
- ReplicatedLinearWithLoRA,
- RowParallelLinearWithLoRA,
- VocabParallelEmbeddingWithLoRA)
- # yapf: enable
- from aphrodite.modeling.layers.logits_processor import LogitsProcessor
- from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
# Registry of the LoRA wrapper classes that ``from_layer`` probes through
# ``can_replace_layer``. It is a set, so the order in which the classes are
# visited is not the order in which they are listed here.
_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
    # Wrappers imported from aphrodite.lora.layers.
    VocabParallelEmbeddingWithLoRA, ColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithLoRA, QKVParallelLinearWithLora,
    MergedQKVParallelLinearWithLora, RowParallelLinearWithLoRA,
    ReplicatedLinearWithLoRA, LogitsProcessorWithLoRA,
    # Wrappers imported from aphrodite.lora.fully_sharded_layers.
    ColumnParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLora,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA,
    # Remaining wrappers from aphrodite.lora.layers.
    LinearScalingRotaryEmbeddingWithLora, ModulesToSaveWrapper,
}
def from_layer(layer: nn.Module,
               max_loras: int,
               lora_config: LoRAConfig,
               packed_modules_list: List,
               model_config: Optional[PretrainedConfig] = None) -> nn.Module:
    """Wrap ``layer`` in the first registered LoRA class that accepts it.

    Args:
        layer: The module to wrap.
        max_loras: Forwarded to the wrapper's ``create_lora_weights``.
        lora_config: LoRA configuration, given to both the compatibility
            check and ``create_lora_weights``.
        packed_modules_list: Forwarded to ``can_replace_layer``.
        model_config: Optional model configuration, given to both the
            compatibility check and ``create_lora_weights``.

    Returns:
        A new wrapper built around ``layer`` with its LoRA weights created,
        or ``layer`` itself when no class in ``_all_lora_classes`` reports
        that it can replace it.
    """
    for candidate in _all_lora_classes:
        # Keyword arguments are used on purpose so that a decorator around
        # can_replace_layer can read them by name.
        accepted = candidate.can_replace_layer(
            source_layer=layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
        )
        if not accepted:
            continue
        wrapped = candidate(layer)
        wrapped.create_lora_weights(max_loras, lora_config, model_config)
        return wrapped
    # No wrapper matched: hand the layer back untouched.
    return layer
def from_layer_logits_processor(
    layer: LogitsProcessor,
    lm_head: ParallelLMHead,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA:
    """Build a ``LogitsProcessorWithLoRA`` around ``layer``.

    Args:
        layer: The logits processor to wrap.
        lm_head: Source of the embedding dimension, the weight dtype and
            device, and the sharded-to-full mapping passed to the wrapper.
        max_loras: Forwarded to ``create_lora_weights``.
        lora_config: Forwarded to ``create_lora_weights``.
        model_config: Optional model configuration, forwarded to
            ``create_lora_weights``.

    Returns:
        The wrapper, after ``create_lora_weights`` has been called on it.
    """
    hidden_size = lm_head.embedding_dim
    head_weight = lm_head.weight
    sharded_to_full = lm_head.get_sharded_to_full_mapping()
    wrapper = LogitsProcessorWithLoRA(layer, hidden_size, head_weight.dtype,
                                      head_weight.device, sharded_to_full)
    wrapper.create_lora_weights(max_loras, lora_config, model_config)
    return wrapper
def replace_submodule(model: nn.Module, module_name: str,
                      new_module: nn.Module) -> nn.Module:
    """Swap the submodule at ``module_name`` for ``new_module``.

    Args:
        model: Root module that owns the submodule.
        module_name: Dotted path of the submodule relative to ``model``.
        new_module: Module to install at that path.

    Returns:
        ``new_module``, for convenient chaining.
    """
    # Everything before the last dot addresses the parent; an undotted name
    # yields an empty parent path, which get_submodule resolves to ``model``.
    parent_path, _, leaf_name = module_name.rpartition(".")
    setattr(model.get_submodule(parent_path), leaf_name, new_module)
    return new_module
def parse_fine_tuned_lora_name(
    name: str,
    enable_lora_modules_to_save: bool = False
) -> Tuple[str, Optional[bool]]:
    """Parse the name of a tensor from a fine-tuned LoRA checkpoint.

    Args:
        name: The tensor name, which must start with ``base_model.model.``,
            e.g. ``base_model.model.model.dense1.lora_A.weight``.
        enable_lora_modules_to_save: Whether ``<layer>.weight`` tensors of
            the layers listed in ``ModulesToSaveWrapper.implemented_layers``
            are accepted as fully trained tensors.

    Returns:
        Tuple(module_name, is_lora_a):
            module_name: the name of the module with the
                ``base_model.model.`` prefix and the LoRA suffix removed,
                e.g. ``model.dense1``.
            is_lora_a: ``True`` for ``lora_A`` / ``lora_embedding_A``,
                ``False`` for ``lora_B`` / ``lora_embedding_B``, and
                ``None`` if the tensor is for ``ModulesToSaveWrapper``.

    Raises:
        ValueError: If ``name`` does not match any supported pattern, or if
            it names a ``ModulesToSaveWrapper`` tensor while
            ``enable_lora_modules_to_save`` is false.
    """
    parts = name.split(".")
    if parts[:2] == ["base_model", "model"]:
        leaf = parts[-1]
        if leaf == "weight":
            owner = parts[-2]
            if owner in ("lora_A", "lora_B"):
                return ".".join(parts[2:-2]), owner == "lora_A"
            if owner in ModulesToSaveWrapper.implemented_layers:
                if not enable_lora_modules_to_save:
                    # Built from adjacent literals so the message is one
                    # clean line instead of carrying source indentation.
                    raise ValueError(
                        "enable_lora_modules_to_save is False, but found "
                        f"tensor name {name} in LoRA checkpoint. Set "
                        "enable_lora_modules_to_save=True to process "
                        "lm_head and embed_tokens as fully trained tensors")
                return ".".join(parts[2:-1]), None
        elif leaf in ("lora_embedding_A", "lora_embedding_B"):
            return ".".join(parts[2:-1]), leaf == "lora_embedding_A"
    raise ValueError(f"{name} is unsupported LoRA weight")
def get_adapter_absolute_path(lora_path: str) -> str:
    """Turn ``lora_path`` into a local path to the LoRA adapter.

    Resolution order:

    1. An absolute path is returned unchanged, whether or not it exists.
    2. A path starting with ``~`` is returned with the user's home
       directory expanded.
    3. A relative path that exists locally is converted to an absolute
       path.
    4. Anything else is treated as a Hugging Face repo id and downloaded;
       the local snapshot path is returned.

    Parameters:
        lora_path (str): An absolute path, a relative path, or a Hugging
            Face model identifier.

    Returns:
        str: The resolved local path. If the download fails with one of the
        handled Hugging Face errors, the error is logged and ``lora_path``
        is returned unchanged.
    """
    if os.path.isabs(lora_path):
        local_path = lora_path
    elif lora_path.startswith('~'):
        local_path = os.path.expanduser(lora_path)
    elif os.path.exists(lora_path):
        local_path = os.path.abspath(lora_path)
    else:
        local_path = None

    if local_path is not None:
        return local_path

    # Nothing matched locally, so assume the value is a Hugging Face repo.
    try:
        return huggingface_hub.snapshot_download(repo_id=lora_path)
    except (HfHubHTTPError, RepositoryNotFoundError, EntryNotFoundError,
            HFValidationError):
        # Download problems are logged rather than raised; the caller gets
        # the original value back.
        logger.exception("Error downloading the HuggingFace model")
        return lora_path
|