- """Utilities for downloading and initializing model weights."""
- import fnmatch
- import glob
- import hashlib
- import json
- import os
- import tempfile
- from collections import defaultdict
- from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union
- import filelock
- import gguf
- import huggingface_hub.constants
- import numpy as np
- import torch
- from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
- from loguru import logger
- from safetensors.torch import load_file, safe_open, save_file
- from tqdm.auto import tqdm
- from aphrodite.common.config import LoadConfig, ModelConfig
- from aphrodite.common.utils import print_warning_once
- from aphrodite.distributed import get_tensor_model_parallel_rank
- from aphrodite.platforms import current_platform
- from aphrodite.quantization import QuantizationConfig, get_quantization_config
- from aphrodite.quantization.schema import QuantParamSchema
- # use system-level temp directory for file locks, so that multiple users
- # can share the same lock without error.
- # lock files in the temp directory will be automatically deleted when the
- # system reboots, so users will not complain about annoying lock files
- temp_dir = tempfile.gettempdir()
- def enable_hf_transfer():
- """Automatically activate hf_transfer-accelerated downloads if the optional
- hf_transfer package is installed."""
- if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
- try:
- # enable hf hub transfer if available
- import hf_transfer # type: ignore # noqa
- huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
- except ImportError:
- pass
- enable_hf_transfer()
- class DisabledTqdm(tqdm):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs, disable=True)
- def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
- lock_dir = cache_dir or temp_dir
- # Create the lock directory itself (not just its parent) so that FileLock
- # can place the lock file inside it.
- os.makedirs(lock_dir, exist_ok=True)
- model_name = model_name_or_path.replace("/", "-")
- hash_name = hashlib.sha256(model_name.encode()).hexdigest()
- # add hash to avoid conflict with old users' lock files
- lock_file_name = hash_name + model_name + ".lock"
- # mode 0o666 is required for the filelock to be shared across users
- lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name),
- mode=0o666)
- return lock
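- # Illustrative usage sketch (the repository id below is a placeholder, not
- # something this module refers to):
- #
- #     with get_lock("some-org/some-model", cache_dir=None):
- #         ...  # download or convert weights without racing other processes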
- def _shared_pointers(tensors):
- ptrs = defaultdict(list)
- for k, v in tensors.items():
- ptrs[v.data_ptr()].append(k)
- failing = []
- for _, names in ptrs.items():
- if len(names) > 1:
- failing.append(names)
- return failing
- def convert_bin_to_safetensor_file(
- pt_filename: str,
- sf_filename: str,
- ) -> None:
- loaded = torch.load(pt_filename, map_location="cpu")
- if "state_dict" in loaded:
- loaded = loaded["state_dict"]
- shared = _shared_pointers(loaded)
- for shared_weights in shared:
- for name in shared_weights[1:]:
- loaded.pop(name)
- # Tensors must be contiguous before being saved with safetensors.
- loaded = {k: v.contiguous() for k, v in loaded.items()}
- dirname = os.path.dirname(sf_filename)
- os.makedirs(dirname, exist_ok=True)
- save_file(loaded, sf_filename, metadata={"format": "pt"})
- # check file size
- sf_size = os.stat(sf_filename).st_size
- pt_size = os.stat(pt_filename).st_size
- if (sf_size - pt_size) / pt_size > 0.01:
- raise RuntimeError(f"""The file size difference is more than 1%:
- - {sf_filename}: {sf_size}
- - {pt_filename}: {pt_size}
- """)
- # check if the tensors are the same
- reloaded = load_file(sf_filename)
- for k in loaded:
- pt_tensor = loaded[k]
- sf_tensor = reloaded[k]
- if not torch.equal(pt_tensor, sf_tensor):
- raise RuntimeError(f"The output tensors do not match for key {k}")
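- # Hedged usage sketch with hypothetical local paths:
- #
- #     convert_bin_to_safetensor_file(
- #         "/path/to/pytorch_model.bin",
- #         "/path/to/model.safetensors",
- #     )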
- # TODO: Move this to another place.
- def get_quant_config(model_config: ModelConfig,
- load_config: LoadConfig) -> QuantizationConfig:
- quant_cls = get_quantization_config(model_config.quantization)
- # GGUF doesn't have config file
- if model_config.quantization == "gguf":
- return quant_cls.from_config({})
- # Read the quantization config from the HF model config, if available.
- hf_quant_config = getattr(model_config.hf_config, "quantization_config",
- None)
- # Some vision models keep quantization_config in their text_config.
- hf_text_config = getattr(model_config.hf_config, "text_config", None)
- if hf_quant_config is None and hf_text_config is not None:
- hf_quant_config = getattr(hf_text_config, "quantization_config", None)
- if hf_quant_config is None:
- # compressed-tensors uses a compression_config
- hf_quant_config = getattr(model_config.hf_config, "compression_config",
- None)
- if hf_quant_config is not None:
- return quant_cls.from_config(hf_quant_config)
- # In case of bitsandbytes/QLoRA, get quant config from the adapter model.
- if model_config.quantization == "bitsandbytes":
- if (not load_config.model_loader_extra_config
- or "qlora_adapter_name_or_path"
- not in load_config.model_loader_extra_config):
- return quant_cls.from_config({"adapter_name_or_path": ""})
- model_name_or_path = load_config.model_loader_extra_config[
- "qlora_adapter_name_or_path"]
- else:
- model_name_or_path = model_config.model
- is_local = os.path.isdir(model_name_or_path)
- if not is_local:
- # Download the config files.
- with get_lock(model_name_or_path, load_config.download_dir):
- hf_folder = snapshot_download(
- model_name_or_path,
- revision=model_config.revision,
- allow_patterns="*.json",
- cache_dir=load_config.download_dir,
- local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
- tqdm_class=DisabledTqdm,
- )
- else:
- hf_folder = model_name_or_path
- possible_config_filenames = quant_cls.get_config_filenames()
- # If the quantization config is not found, use the default config.
- if not possible_config_filenames:
- return quant_cls()
- config_files = glob.glob(os.path.join(hf_folder, "*.json"))
- quant_config_files = [
- f for f in config_files if any(
- f.endswith(x) for x in possible_config_filenames)
- ]
- if len(quant_config_files) == 0:
- raise ValueError(
- f"Cannot find the config file for {model_config.quantization}")
- if len(quant_config_files) > 1:
- raise ValueError(
- f"Found multiple config files for {model_config.quantization}: "
- f"{quant_config_files}")
- quant_config_file = quant_config_files[0]
- with open(quant_config_file, "r") as f:
- config = json.load(f)
- if model_config.quantization == "bitsandbytes":
- config["adapter_name_or_path"] = model_name_or_path
- elif model_config.quantization == "modelopt":
- if config["producer"]["name"] == "modelopt":
- return quant_cls.from_config(config)
- else:
- raise ValueError(
- "Unsupported quantization config"
- f" found for {model_config.quantization} in {quant_config_file}.")
- return quant_cls.from_config(config)
- def download_weights_from_hf(
- model_name_or_path: str,
- cache_dir: Optional[str],
- allow_patterns: List[str],
- revision: Optional[str] = None,
- ignore_patterns: Optional[Union[str, List[str]]] = None,
- ) -> str:
- """Download model weights from Hugging Face Hub.
-
- Args:
- model_name_or_path (str): The model name or path.
- cache_dir (Optional[str]): The cache directory to store the model
- weights. If None, will use HF defaults.
- allow_patterns (List[str]): The allowed patterns for the
- weight files. Files matched by any of the patterns will be
- downloaded.
- revision (Optional[str]): The revision of the model.
- ignore_patterns (Optional[Union[str, List[str]]]): The patterns to
- filter out the weight files. Files matched by any of the patterns
- will be ignored.
- Returns:
- str: The path to the downloaded model weights.
- """
- if not huggingface_hub.constants.HF_HUB_OFFLINE:
- # Before we download, check which files are actually available:
- fs = HfFileSystem()
- file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
- # Use the first pattern that matches any remote file, so that only one
- # weight format is downloaded.
- for pattern in allow_patterns:
- matching = fnmatch.filter(file_list, pattern)
- if len(matching) > 0:
- allow_patterns = [pattern]
- break
- rank = (get_tensor_model_parallel_rank() if
- torch.distributed.is_initialized() else 0)
- if rank == 0:
- logger.info(f"Using model weights format {allow_patterns}")
- # Use file lock to prevent multiple processes from
- # downloading the same model weights at the same time.
- with get_lock(model_name_or_path, cache_dir):
- hf_folder = snapshot_download(
- model_name_or_path,
- allow_patterns=allow_patterns,
- ignore_patterns=ignore_patterns,
- cache_dir=cache_dir,
- tqdm_class=DisabledTqdm,
- revision=revision,
- local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
- )
- return hf_folder
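- # Usage sketch (the repo id is only illustrative): because the repo listing
- # above is checked first, patterns earlier in allow_patterns win, so
- # safetensors are preferred over .bin when both exist.
- #
- #     folder = download_weights_from_hf(
- #         "facebook/opt-125m",
- #         cache_dir=None,
- #         allow_patterns=["*.safetensors", "*.bin"],
- #     )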
- def download_safetensors_index_file_from_hf(
- model_name_or_path: str,
- index_file: str,
- cache_dir: Optional[str],
- revision: Optional[str] = None,
- ) -> None:
- """Download hf safetensors index file from Hugging Face Hub.
- Args:
- model_name_or_path (str): The model name or path.
- cache_dir (Optional[str]): The cache directory to store the model
- weights. If None, will use HF defaults.
- revision (Optional[str]): The revision of the model.
- """
- # Use file lock to prevent multiple processes from
- # downloading the same model weights at the same time.
- with get_lock(model_name_or_path, cache_dir):
- try:
- # Download the safetensors index file.
- hf_hub_download(
- repo_id=model_name_or_path,
- filename=index_file,
- cache_dir=cache_dir,
- revision=revision,
- local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
- )
- # If file not found on remote or locally, we should not fail since
- # only some models will have index_file.
- except huggingface_hub.utils.EntryNotFoundError:
- logger.info(f"No {index_file} found in remote.")
- except huggingface_hub.utils.LocalEntryNotFoundError:
- logger.info(f"No {index_file} found in local cache.")
- # For models like Mistral-7B-v0.3, there are both sharded
- # safetensors files and a consolidated safetensors file.
- # Passing both of these to the weight loader functionality breaks.
- # So, we use the index_file to
- # look up which safetensors files should be used.
- def filter_duplicate_safetensors_files(hf_weights_files: List[str],
- hf_folder: str,
- index_file: str) -> List[str]:
- # model.safetensors.index.json is a mapping from keys in the
- # torch state_dict to safetensors file holding that weight.
- index_file_name = os.path.join(hf_folder, index_file)
- if not os.path.isfile(index_file_name):
- return hf_weights_files
- # Iterate through the weight_map (weight_name: safetensors files)
- # to identify weights that we should use.
- with open(index_file_name, "r") as f:
- weight_map = json.load(f)["weight_map"]
- weight_files_in_index = set()
- for weight_name in weight_map:
- weight_files_in_index.add(
- os.path.join(hf_folder, weight_map[weight_name]))
- # Filter out any files that are not referenced by the index file.
- hf_weights_files = [
- f for f in hf_weights_files if f in weight_files_in_index
- ]
- return hf_weights_files
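- # Sketch of the typical call pattern (paths are illustrative):
- #
- #     files = glob.glob(os.path.join(hf_folder, "*.safetensors"))
- #     files = filter_duplicate_safetensors_files(
- #         files, hf_folder, "model.safetensors.index.json")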
- def filter_files_not_needed_for_inference(
- hf_weights_files: List[str]) -> List[str]:
- """
- Exclude files that are not needed for inference.
- See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
- """
- blacklist = [
- "training_args.bin",
- "optimizer.bin",
- "optimizer.pt",
- "scheduler.pt",
- "scaler.pt",
- ]
- hf_weights_files = [
- f for f in hf_weights_files
- if not any(f.endswith(x) for x in blacklist)
- ]
- return hf_weights_files
- def np_cache_weights_iterator(
- model_name_or_path: str, cache_dir: Optional[str], hf_folder: str,
- hf_weights_files: List[str]
- ) -> Generator[Tuple[str, torch.Tensor], None, None]:
- """Iterate over the weights in the model np files.
- Will dump the model weights to numpy files if they are not already dumped.
- """
- # Progress bar is intentionally disabled; the rank-0-only alternative is:
- # not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
- enable_tqdm = False
- # Convert the model weights from torch tensors to numpy arrays for
- # faster loading.
- np_folder = os.path.join(hf_folder, "np")
- os.makedirs(np_folder, exist_ok=True)
- weight_names_file = os.path.join(np_folder, "weight_names.json")
- # Use file lock to prevent multiple processes from
- # dumping the same model weights to numpy at the same time.
- with get_lock(model_name_or_path, cache_dir):
- if not os.path.exists(weight_names_file):
- weight_names = []
- for bin_file in tqdm(
- hf_weights_files,
- desc="Loading np_cache checkpoint shards",
- disable=not enable_tqdm,
- ):
- state = torch.load(bin_file, map_location="cpu")
- for name, param in state.items():
- param_path = os.path.join(np_folder, name)
- with open(param_path, "wb") as f:
- np.save(f, param.cpu().detach().numpy())
- weight_names.append(name)
- with open(weight_names_file, "w") as f:
- json.dump(weight_names, f)
- with open(weight_names_file, "r") as f:
- weight_names = json.load(f)
- for name in weight_names:
- param_path = os.path.join(np_folder, name)
- with open(param_path, "rb") as f:
- param = np.load(f)
- yield name, torch.from_numpy(param)
- def safetensors_weights_iterator(
- hf_weights_files: List[str]
- ) -> Generator[Tuple[str, torch.Tensor], None, None]:
- """Iterate over the weights in the model safetensor files."""
- # Progress bar is intentionally disabled; the rank-0-only alternative is:
- # not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
- enable_tqdm = False
- for st_file in tqdm(
- hf_weights_files,
- desc="Loading safetensors checkpoint shards",
- disable=not enable_tqdm,
- ):
- with safe_open(st_file, framework="pt") as f:
- for name in f.keys(): # noqa: SIM118
- param = f.get_tensor(name)
- yield name, param
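- # Example consumption of the iterator (assuming `files` is a list of local
- # .safetensors paths):
- #
- #     for name, tensor in safetensors_weights_iterator(files):
- #         print(name, tuple(tensor.shape))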
- def pt_weights_iterator(
- hf_weights_files: List[str]
- ) -> Generator[Tuple[str, torch.Tensor], None, None]:
- """Iterate over the weights in the model bin/pt files."""
- # Progress bar is intentionally disabled; the rank-0-only alternative is:
- # not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
- enable_tqdm = False
- for bin_file in tqdm(
- hf_weights_files,
- desc="Loading pt checkpoint shards",
- disable=not enable_tqdm,
- ):
- state = torch.load(bin_file, map_location="cpu")
- for name, param in state.items():
- yield name, param
- del state
- torch.cuda.empty_cache()
- def get_model_config_yaml(
- model_name_or_path: str,
- cache_dir: Optional[str] = None) -> Optional[dict]:
- """Look for aphrodite_config.yaml in model directory or HF repo.
- Args:
- model_name_or_path: Local path or HF model name
- cache_dir: Optional cache directory for HF downloads
- Returns:
- Dict containing the config if found, None otherwise
- """
- is_local = os.path.isdir(model_name_or_path)
- config_path = None
- if is_local:
- config_path = os.path.join(model_name_or_path, "aphrodite_config.yaml")
- if not os.path.exists(config_path):
- return None
- else:
- valid_names = ["aphrodite_config.yaml",
- "aphrodite_config.yml"]
- config_path = None
- with get_lock(model_name_or_path, cache_dir):
- for name in valid_names:
- try:
- config_path = hf_hub_download(
- model_name_or_path,
- filename=name,
- cache_dir=cache_dir,
- local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
- )
- break
- except (huggingface_hub.utils.EntryNotFoundError,
- huggingface_hub.utils.LocalEntryNotFoundError):
- # Fall through to the next accepted filename.
- continue
- if config_path is None:
- return None
- try:
- import yaml
- with open(config_path, 'r') as f:
- config = yaml.safe_load(f)
- return config
- except Exception as e:
- logger.warning(f"Failed to load aphrodite_config.yaml: {e}")
- return None
- def get_gguf_extra_tensor_names(
- gguf_file: str, gguf_to_hf_name_map: Dict[str, str]) -> List[str]:
- reader = gguf.GGUFReader(gguf_file)
- expected_gguf_keys = set(gguf_to_hf_name_map.keys())
- exact_gguf_keys = {tensor.name for tensor in reader.tensors}
- extra_keys = expected_gguf_keys - exact_gguf_keys
- return [gguf_to_hf_name_map[key] for key in extra_keys]
- def gguf_quant_weights_iterator(
- gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
- ) -> Generator[Tuple[str, torch.Tensor], None, None]:
- """
- Iterate over the quant weights in the model gguf files and convert
- them to torch tensors
- """
- reader = gguf.GGUFReader(gguf_file)
- for tensor in reader.tensors:
- if tensor.name in gguf_to_hf_name_map:
- weight_type = tensor.tensor_type
- name = gguf_to_hf_name_map[tensor.name]
- if weight_type.name != "F32":
- weight_type_name = name.replace("weight", "qweight_type")
- weight_type = torch.tensor(weight_type)
- yield weight_type_name, weight_type
- for tensor in reader.tensors:
- if tensor.name in gguf_to_hf_name_map:
- weight = tensor.data
- weight_type = tensor.tensor_type
- name = gguf_to_hf_name_map[tensor.name]
- if weight_type.name != "F32":
- name = name.replace("weight", "qweight")
- param = torch.tensor(weight)
- yield name, param
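- # Note: the first pass yields auxiliary "qweight_type" tensors and the second
- # pass yields the (possibly quantized) weights, so consumers see the
- # quantization type before the data. Illustrative sketch with a hypothetical
- # one-entry name map:
- #
- #     name_map = {"blk.0.attn_q.weight":
- #                 "model.layers.0.self_attn.q_proj.weight"}
- #     for name, tensor in gguf_quant_weights_iterator("model.gguf", name_map):
- #         ...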
- def kv_cache_scales_loader(
- filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int,
- model_type: Optional[str]) -> Iterable[Tuple[int, float]]:
- """
- A simple utility to read in KV cache scaling factors that have been
- previously serialized to disk. Used by the model to populate the appropriate
- KV cache scaling factors. The serialization should represent a dictionary
- whose keys are the TP ranks and values are another dictionary mapping layers
- to their KV cache scaling factors.
- Keep this function in sync with the output of examples/fp8/extract_scales.py
- """
- try:
- with open(filename) as f:
- context = {
- "model_type": model_type,
- "num_hidden_layers": num_hidden_layers,
- "tp_rank": tp_rank,
- "tp_size": tp_size,
- }
- schema_dct = json.load(f)
- schema = QuantParamSchema.model_validate(schema_dct,
- context=context)
- layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
- return layer_scales_map.items()
- except FileNotFoundError:
- logger.error(f"File or directory '{filename}' not found.")
- except json.JSONDecodeError:
- logger.error(f"Error decoding JSON in file '{filename}'.")
- except Exception as e:
- logger.error(f"An error occurred while reading '{filename}': {e}")
- # This section is reached if and only if any of the excepts are hit
- # Return an empty iterable (list) => no KV cache scales are loaded
- # which ultimately defaults to 1.0 scales
- logger.warning("Defaulting to KV cache scaling factors = 1.0 "
- f"for all layers in TP rank {tp_rank} "
- "as an error occurred during loading.")
- return []
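- # Usage sketch (filename and layer count are placeholders); on any error the
- # function returns an empty list, which downstream code treats as 1.0 scales:
- #
- #     for layer_idx, scale in kv_cache_scales_loader(
- #             "kv_cache_scales.json", tp_rank=0, tp_size=1,
- #             num_hidden_layers=32, model_type="llama"):
- #         ...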
- def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
- """Convert a PySafeSlice object from safetensors to a torch.Tensor.
- PySafeSlice objects support indexing, which is applied before the actual
- tensor is loaded and therefore reduces the amount of data read into memory.
- However, they do not support more advanced operations such as `.view()` or
- `.t()`, so a tensor that needs to be modified with such operators must be
- converted to a torch.Tensor first.
- """
- if not isinstance(x, torch.Tensor):
- x = x[:]
- return x
- def default_weight_loader(param: torch.Tensor,
- loaded_weight: torch.Tensor) -> None:
- """Default weight loader."""
- try:
- if param.numel() == 1 and loaded_weight.numel() == 1:
- # Sometimes scalar values aren't considered tensors with shapes
- # so if both param and loaded_weight are a scalar,
- # "broadcast" instead of copy
- param.data.fill_(loaded_weight.item())
- else:
- assert param.size() == loaded_weight.size(), (
- f"Attempted to load weight ({loaded_weight.size()}) "
- f"into parameter ({param.size()})")
- param.data.copy_(loaded_weight)
- except Exception:
- # NOTE: This except clause exists only as a convenient place to set a
- # breakpoint when debugging weight-loading issues.
- raise
- def row_parallel_weight_loader(param: torch.Tensor,
- loaded_weight: torch.Tensor) -> None:
- """Load weights that are row-parallelized."""
- tp_rank = get_tensor_model_parallel_rank()
- shard_dim = 0 if param.dim() != 1 else None
- if shard_dim is not None:
- shard_size = param.data.shape[shard_dim]
- start_idx = tp_rank * shard_size
- loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size)
- return default_weight_loader(param, loaded_weight)
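- # Worked example of the sharding above (numbers are illustrative): for a full
- # weight of shape (1024, 4096) with tp_size=4, each rank's parameter shard is
- # (256, 4096); rank 2 therefore loads rows [512, 768) via
- # loaded_weight.narrow(0, 2 * 256, 256).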
- def initialize_dummy_weights(
- model: torch.nn.Module,
- low: float = -1e-3,
- high: float = 1e-3,
- seed: int = 1234,
- ) -> None:
- """Initialize model weights with random values.
- The model weights must be randomly initialized for accurate performance
- measurements. Additionally, the model weights should not cause NaNs in the
- forward pass. We empirically found that initializing the weights with
- values between -1e-3 and 1e-3 works well for most models.
- We use a per-parameter random seed, so that the dummy weights are
- consistent even if the model is partitioned across multiple devices. When
- the seed is fixed, the random values generated by this function depend only
- on the parameter's number of elements and its data type.
- """
- for param in model.state_dict().values():
- if torch.is_floating_point(param):
- if current_platform.is_tpu():
- # XLA device does not support torch.Generator()
- param.uniform_(low, high)
- continue
- generator = torch.Generator(device=param.data.device)
- generator.manual_seed(seed)
- if torch.finfo(param.data.dtype).bits < 16:
- # uniform_ doesn't support < 16-bit datatypes (FP8)
- dtype = param.data.dtype
- tmp_param = param.data.to(torch.float16)
- tmp_param = tmp_param.uniform_(low, high,
- generator=generator).to(dtype)
- param.data.copy_(tmp_param)
- else:
- param.uniform_(low, high, generator=generator)
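- # Hedged usage sketch: `model` is any torch.nn.Module already placed on its
- # target device; only floating-point parameters are re-initialized.
- #
- #     initialize_dummy_weights(model, low=-1e-3, high=1e-3, seed=1234)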
- def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
- """Remap the name of FP8 k/v_scale parameters.
- This function handles the remapping of FP8 k/v_scale parameter names.
- It detects whether the given name ends with a known k/v scale suffix
- (including the deprecated ".kv_scale") and attempts to remap it to the
- expected name format in the model. If the remapped name is not found in
- params_dict, a warning is printed and None is returned.
- Args:
- name (str): The original loaded checkpoint parameter name.
- params_dict (dict): Dictionary containing the model's named parameters.
- Returns:
- str: The remapped parameter name if successful, or the original name
- if no remapping is needed.
- None: If the remapped name is not found in params_dict.
- """
- if name.endswith(".kv_scale"):
- print_warning_once(
- "DEPRECATED. Found kv_scale in the checkpoint. "
- "This format is deprecated in favor of separate k_scale and "
- "v_scale tensors and will be removed in a future release. "
- "Functionally, we will remap kv_scale to k_scale and duplicate "
- "k_scale to v_scale")
- # NOTE: we remap the deprecated kv_scale to k_scale
- remapped_name = name.replace(".kv_scale", ".attn.k_scale")
- if remapped_name not in params_dict:
- print_warning_once(
- f"Found kv_scale in the checkpoint (e.g. {name}), "
- "but could not find the expected name in the model "
- f"(e.g. {remapped_name}). kv_scale is "
- "not loaded.")
- return None
- return remapped_name
- possible_scale_names = [".k_scale", ".v_scale"]
- for scale_name in possible_scale_names:
- if name.endswith(scale_name):
- remapped_name = name.replace(scale_name, f".attn{scale_name}")
- if remapped_name not in params_dict:
- print_warning_once(
- f"Found {scale_name} in the checkpoint (e.g. {name}), "
- "but could not find the expected name in the model "
- f"(e.g. {remapped_name}). {scale_name} is "
- "not loaded.")
- return None
- return remapped_name
- # If there were no matches, return the untouched param name
- return name
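- # Examples of the remapping behavior (parameter names are illustrative):
- #     "model.layers.0.self_attn.kv_scale" -> "model.layers.0.self_attn.attn.k_scale"
- #     "model.layers.0.self_attn.k_scale"  -> "model.layers.0.self_attn.attn.k_scale"
- # None is returned whenever the remapped name is absent from params_dict.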