1
0

weight_utils.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624
  1. """Utilities for downloading and initializing model weights."""
  2. import fnmatch
  3. import glob
  4. import hashlib
  5. import json
  6. import os
  7. import tempfile
  8. from collections import defaultdict
  9. from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union
  10. import filelock
  11. import gguf
  12. import huggingface_hub.constants
  13. import numpy as np
  14. import torch
  15. from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
  16. from loguru import logger
  17. from safetensors.torch import load_file, safe_open, save_file
  18. from tqdm.auto import tqdm
  19. from aphrodite.common.config import LoadConfig, ModelConfig
  20. from aphrodite.common.utils import print_warning_once
  21. from aphrodite.distributed import get_tensor_model_parallel_rank
  22. from aphrodite.platforms import current_platform
  23. from aphrodite.quantization import QuantizationConfig, get_quantization_config
  24. from aphrodite.quantization.schema import QuantParamSchema
  25. # use system-level temp directory for file locks, so that multiple users
  26. # can share the same lock without error.
  27. # lock files in the temp directory will be automatically deleted when the
  28. # system reboots, so users will not complain about annoying lock files
  29. temp_dir = tempfile.gettempdir()
  30. def enable_hf_transfer():
  31. """automatically activates hf_transfer
  32. """
  33. if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
  34. try:
  35. # enable hf hub transfer if available
  36. import hf_transfer # type: ignore # noqa
  37. huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
  38. except ImportError:
  39. pass
  40. enable_hf_transfer()
  41. class DisabledTqdm(tqdm):
  42. def __init__(self, *args, **kwargs):
  43. super().__init__(*args, **kwargs, disable=True)
  44. def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
  45. lock_dir = cache_dir or temp_dir
  46. os.makedirs(os.path.dirname(lock_dir), exist_ok=True)
  47. model_name = model_name_or_path.replace("/", "-")
  48. hash_name = hashlib.sha256(model_name.encode()).hexdigest()
  49. # add hash to avoid conflict with old users' lock files
  50. lock_file_name = hash_name + model_name + ".lock"
  51. # mode 0o666 is required for the filelock to be shared across users
  52. lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name),
  53. mode=0o666)
  54. return lock
  55. def _shared_pointers(tensors):
  56. ptrs = defaultdict(list)
  57. for k, v in tensors.items():
  58. ptrs[v.data_ptr()].append(k)
  59. failing = []
  60. for _, names in ptrs.items():
  61. if len(names) > 1:
  62. failing.append(names)
  63. return failing
  64. def convert_bin_to_safetensor_file(
  65. pt_filename: str,
  66. sf_filename: str,
  67. ) -> None:
  68. loaded = torch.load(pt_filename, map_location="cpu")
  69. if "state_dict" in loaded:
  70. loaded = loaded["state_dict"]
  71. shared = _shared_pointers(loaded)
  72. for shared_weights in shared:
  73. for name in shared_weights[1:]:
  74. loaded.pop(name)
  75. # For tensors to be contiguous
  76. loaded = {k: v.contiguous() for k, v in loaded.items()}
  77. dirname = os.path.dirname(sf_filename)
  78. os.makedirs(dirname, exist_ok=True)
  79. save_file(loaded, sf_filename, metadata={"format": "pt"})
  80. # check file size
  81. sf_size = os.stat(sf_filename).st_size
  82. pt_size = os.stat(pt_filename).st_size
  83. if (sf_size - pt_size) / pt_size > 0.01:
  84. raise RuntimeError(f"""The file size different is more than 1%:
  85. - {sf_filename}: {sf_size}
  86. - {pt_filename}: {pt_size}
  87. """)
  88. # check if the tensors are the same
  89. reloaded = load_file(sf_filename)
  90. for k in loaded:
  91. pt_tensor = loaded[k]
  92. sf_tensor = reloaded[k]
  93. if not torch.equal(pt_tensor, sf_tensor):
  94. raise RuntimeError(f"The output tensors do not match for key {k}")
  95. # TODO: Move this to other place.
  96. def get_quant_config(model_config: ModelConfig,
  97. load_config: LoadConfig) -> QuantizationConfig:
  98. quant_cls = get_quantization_config(model_config.quantization)
  99. # GGUF doesn't have config file
  100. if model_config.quantization == "gguf":
  101. return quant_cls.from_config({})
  102. # Read the quantization config from the HF model config, if available.
  103. hf_quant_config = getattr(model_config.hf_config, "quantization_config",
  104. None)
  105. # some vision model may keep quantization_config in their text_config
  106. hf_text_config = getattr(model_config.hf_config, "text_config", None)
  107. if hf_quant_config is None and hf_text_config is not None:
  108. hf_quant_config = getattr(hf_text_config, "quantization_config", None)
  109. if hf_quant_config is None:
  110. # compressed-tensors uses a compressions_config
  111. hf_quant_config = getattr(model_config.hf_config, "compression_config",
  112. None)
  113. if hf_quant_config is not None:
  114. return quant_cls.from_config(hf_quant_config)
  115. # In case of bitsandbytes/QLoRA, get quant config from the adapter model.
  116. if model_config.quantization == "bitsandbytes":
  117. if (not load_config.model_loader_extra_config
  118. or "qlora_adapter_name_or_path"
  119. not in load_config.model_loader_extra_config):
  120. return quant_cls.from_config({"adapter_name_or_path": ""})
  121. model_name_or_path = load_config.model_loader_extra_config[
  122. "qlora_adapter_name_or_path"]
  123. else:
  124. model_name_or_path = model_config.model
  125. is_local = os.path.isdir(model_name_or_path)
  126. if not is_local:
  127. # Download the config files.
  128. with get_lock(model_name_or_path, load_config.download_dir):
  129. hf_folder = snapshot_download(
  130. model_name_or_path,
  131. revision=model_config.revision,
  132. allow_patterns="*.json",
  133. cache_dir=load_config.download_dir,
  134. local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
  135. tqdm_class=DisabledTqdm,
  136. )
  137. else:
  138. hf_folder = model_name_or_path
  139. possible_config_filenames = quant_cls.get_config_filenames()
  140. # If the quantization config is not found, use the default config.
  141. if not possible_config_filenames:
  142. return quant_cls()
  143. config_files = glob.glob(os.path.join(hf_folder, "*.json"))
  144. quant_config_files = [
  145. f for f in config_files if any(
  146. f.endswith(x) for x in possible_config_filenames)
  147. ]
  148. if len(quant_config_files) == 0:
  149. raise ValueError(
  150. f"Cannot find the config file for {model_config.quantization}")
  151. if len(quant_config_files) > 1:
  152. raise ValueError(
  153. f"Found multiple config files for {model_config.quantization}: "
  154. f"{quant_config_files}")
  155. quant_config_file = quant_config_files[0]
  156. with open(quant_config_file, "r") as f:
  157. config = json.load(f)
  158. if model_config.quantization == "bitsandbytes":
  159. config["adapter_name_or_path"] = model_name_or_path
  160. return quant_cls.from_config(config)
  161. def download_weights_from_hf(
  162. model_name_or_path: str,
  163. cache_dir: Optional[str],
  164. allow_patterns: List[str],
  165. revision: Optional[str] = None,
  166. ignore_patterns: Optional[Union[str, List[str]]] = None,
  167. ) -> str:
  168. """Download model weights from Hugging Face Hub.
  169. Args:
  170. model_name_or_path (str): The model name or path.
  171. cache_dir (Optional[str]): The cache directory to store the model
  172. weights. If None, will use HF defaults.
  173. allow_patterns (List[str]): The allowed patterns for the
  174. weight files. Files matched by any of the patterns will be
  175. downloaded.
  176. revision (Optional[str]): The revision of the model.
  177. ignore_patterns (Optional[Union[str, List[str]]]): The patterns to
  178. filter out the weight files. Files matched by any of the patterns
  179. will be ignored.
  180. Returns:
  181. str: The path to the downloaded model weights.
  182. """
  183. if not huggingface_hub.constants.HF_HUB_OFFLINE:
  184. # Before we download we look at that is available:
  185. fs = HfFileSystem()
  186. file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
  187. # depending on what is available we download different things
  188. for pattern in allow_patterns:
  189. matching = fnmatch.filter(file_list, pattern)
  190. if len(matching) > 0:
  191. allow_patterns = [pattern]
  192. break
  193. rank = get_tensor_model_parallel_rank()
  194. if rank == 0:
  195. logger.info(f"Using model weights format {allow_patterns}")
  196. # Use file lock to prevent multiple processes from
  197. # downloading the same model weights at the same time.
  198. with get_lock(model_name_or_path, cache_dir):
  199. hf_folder = snapshot_download(
  200. model_name_or_path,
  201. allow_patterns=allow_patterns,
  202. ignore_patterns=ignore_patterns,
  203. cache_dir=cache_dir,
  204. tqdm_class=DisabledTqdm,
  205. revision=revision,
  206. local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
  207. )
  208. return hf_folder
  209. def download_safetensors_index_file_from_hf(
  210. model_name_or_path: str,
  211. index_file: str,
  212. cache_dir: Optional[str],
  213. revision: Optional[str] = None,
  214. ) -> None:
  215. """Download hf safetensors index file from Hugging Face Hub.
  216. Args:
  217. model_name_or_path (str): The model name or path.
  218. cache_dir (Optional[str]): The cache directory to store the model
  219. weights. If None, will use HF defaults.
  220. revision (Optional[str]): The revision of the model.
  221. """
  222. # Use file lock to prevent multiple processes from
  223. # downloading the same model weights at the same time.
  224. with get_lock(model_name_or_path, cache_dir):
  225. try:
  226. # Download the safetensors index file.
  227. hf_hub_download(
  228. repo_id=model_name_or_path,
  229. filename=index_file,
  230. cache_dir=cache_dir,
  231. revision=revision,
  232. local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
  233. )
  234. # If file not found on remote or locally, we should not fail since
  235. # only some models will have index_file.
  236. except huggingface_hub.utils.EntryNotFoundError:
  237. logger.info(f"No {index_file} found in remote.")
  238. except huggingface_hub.utils.LocalEntryNotFoundError:
  239. logger.info(f"No {index_file} found in local cache.")
  240. # For models like Mistral-7B-v0.3, there are both sharded
  241. # safetensors files and a consolidated safetensors file.
  242. # Passing both of these to the weight loader functionality breaks.
  243. # So, we use the index_file to
  244. # look up which safetensors files should be used.
  245. def filter_duplicate_safetensors_files(hf_weights_files: List[str],
  246. hf_folder: str,
  247. index_file: str) -> List[str]:
  248. # model.safetensors.index.json is a mapping from keys in the
  249. # torch state_dict to safetensors file holding that weight.
  250. index_file_name = os.path.join(hf_folder, index_file)
  251. if not os.path.isfile(index_file_name):
  252. return hf_weights_files
  253. # Iterate through the weight_map (weight_name: safetensors files)
  254. # to identify weights that we should use.
  255. with open(index_file_name, "r") as f:
  256. weight_map = json.load(f)["weight_map"]
  257. weight_files_in_index = set()
  258. for weight_name in weight_map:
  259. weight_files_in_index.add(
  260. os.path.join(hf_folder, weight_map[weight_name]))
  261. # Filter out any fields that are not found in the index file.
  262. hf_weights_files = [
  263. f for f in hf_weights_files if f in weight_files_in_index
  264. ]
  265. return hf_weights_files
  266. def filter_files_not_needed_for_inference(
  267. hf_weights_files: List[str]) -> List[str]:
  268. """
  269. Exclude files that are not needed for inference.
  270. See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
  271. """
  272. blacklist = [
  273. "training_args.bin",
  274. "optimizer.bin",
  275. "optimizer.pt",
  276. "scheduler.pt",
  277. "scaler.pt",
  278. ]
  279. hf_weights_files = [
  280. f for f in hf_weights_files
  281. if not any(f.endswith(x) for x in blacklist)
  282. ]
  283. return hf_weights_files
  284. def np_cache_weights_iterator(
  285. model_name_or_path: str, cache_dir: Optional[str], hf_folder: str,
  286. hf_weights_files: List[str]
  287. ) -> Generator[Tuple[str, torch.Tensor], None, None]:
  288. """Iterate over the weights in the model np files.
  289. Will dump the model weights to numpy files if they are not already dumped.
  290. """
  291. enable_tqdm = False #not torch.distributed.is_initialized(
  292. #) or torch.distributed.get_rank() == 0
  293. # Convert the model weights from torch tensors to numpy arrays for
  294. # faster loading.
  295. np_folder = os.path.join(hf_folder, "np")
  296. os.makedirs(np_folder, exist_ok=True)
  297. weight_names_file = os.path.join(np_folder, "weight_names.json")
  298. # Use file lock to prevent multiple processes from
  299. # dumping the same model weights to numpy at the same time.
  300. with get_lock(model_name_or_path, cache_dir):
  301. if not os.path.exists(weight_names_file):
  302. weight_names = []
  303. for bin_file in tqdm(
  304. hf_weights_files,
  305. desc="Loading np_cache checkpoint shards",
  306. disable=not enable_tqdm,
  307. ):
  308. state = torch.load(bin_file, map_location="cpu")
  309. for name, param in state.items():
  310. param_path = os.path.join(np_folder, name)
  311. with open(param_path, "wb") as f:
  312. np.save(f, param.cpu().detach().numpy())
  313. weight_names.append(name)
  314. with open(weight_names_file, "w") as f:
  315. json.dump(weight_names, f)
  316. with open(weight_names_file, "r") as f:
  317. weight_names = json.load(f)
  318. for name in weight_names:
  319. param_path = os.path.join(np_folder, name)
  320. with open(param_path, "rb") as f:
  321. param = np.load(f)
  322. yield name, torch.from_numpy(param)
  323. def safetensors_weights_iterator(
  324. hf_weights_files: List[str]
  325. ) -> Generator[Tuple[str, torch.Tensor], None, None]:
  326. """Iterate over the weights in the model safetensor files."""
  327. enable_tqdm = False #not torch.distributed.is_initialized(
  328. #) or torch.distributed.get_rank() == 0
  329. for st_file in tqdm(
  330. hf_weights_files,
  331. desc="Loading safetensors checkpoint shards",
  332. disable=not enable_tqdm,
  333. ):
  334. with safe_open(st_file, framework="pt") as f:
  335. for name in f.keys(): # noqa: SIM118
  336. param = f.get_tensor(name)
  337. yield name, param
  338. def pt_weights_iterator(
  339. hf_weights_files: List[str]
  340. ) -> Generator[Tuple[str, torch.Tensor], None, None]:
  341. """Iterate over the weights in the model bin/pt files."""
  342. enable_tqdm = False #not torch.distributed.is_initialized(
  343. #) or torch.distributed.get_rank() == 0
  344. for bin_file in tqdm(
  345. hf_weights_files,
  346. desc="Loading pt checkpoint shards",
  347. disable=not enable_tqdm,
  348. ):
  349. state = torch.load(bin_file, map_location="cpu")
  350. for name, param in state.items():
  351. yield name, param
  352. del state
  353. torch.cuda.empty_cache()
  354. def get_gguf_extra_tensor_names(
  355. gguf_file: str, gguf_to_hf_name_map: Dict[str, str]) -> List[str]:
  356. reader = gguf.GGUFReader(gguf_file)
  357. expected_gguf_keys = set(gguf_to_hf_name_map.keys())
  358. exact_gguf_keys = set([tensor.name for tensor in reader.tensors])
  359. extra_keys = expected_gguf_keys - exact_gguf_keys
  360. return [gguf_to_hf_name_map[key] for key in extra_keys]
  361. def gguf_quant_weights_iterator(
  362. gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
  363. ) -> Generator[Tuple[str, torch.Tensor], None, None]:
  364. """
  365. Iterate over the quant weights in the model gguf files and convert
  366. them to torch tensors
  367. """
  368. reader = gguf.GGUFReader(gguf_file)
  369. for tensor in reader.tensors:
  370. if tensor.name in gguf_to_hf_name_map:
  371. weight_type = tensor.tensor_type
  372. name = gguf_to_hf_name_map[tensor.name]
  373. if weight_type.name != "F32":
  374. weight_type_name = name.replace("weight", "qweight_type")
  375. weight_type = torch.tensor(weight_type)
  376. yield weight_type_name, weight_type
  377. for tensor in reader.tensors:
  378. if tensor.name in gguf_to_hf_name_map:
  379. weight = tensor.data
  380. weight_type = tensor.tensor_type
  381. name = gguf_to_hf_name_map[tensor.name]
  382. if weight_type.name != "F32":
  383. name = name.replace("weight", "qweight")
  384. param = torch.tensor(weight)
  385. yield name, param
  386. def kv_cache_scales_loader(
  387. filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int,
  388. model_type: Optional[str]) -> Iterable[Tuple[int, float]]:
  389. """
  390. A simple utility to read in KV cache scaling factors that have been
  391. previously serialized to disk. Used by the model to populate the appropriate
  392. KV cache scaling factors. The serialization should represent a dictionary
  393. whose keys are the TP ranks and values are another dictionary mapping layers
  394. to their KV cache scaling factors.
  395. Keep this function in sync with the output of examples/fp8/extract_scales.py
  396. """
  397. try:
  398. with open(filename) as f:
  399. context = {
  400. "model_type": model_type,
  401. "num_hidden_layers": num_hidden_layers,
  402. "tp_rank": tp_rank,
  403. "tp_size": tp_size,
  404. }
  405. schema_dct = json.load(f)
  406. schema = QuantParamSchema.model_validate(schema_dct,
  407. context=context)
  408. layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
  409. return layer_scales_map.items()
  410. except FileNotFoundError:
  411. logger.error(f"File or directory '{filename}' not found.")
  412. except json.JSONDecodeError:
  413. logger.error(f"Error decoding JSON in file '{filename}'.")
  414. except Exception as e:
  415. logger.error(f"An error occurred while reading '{filename}': {e}")
  416. # This section is reached if and only if any of the excepts are hit
  417. # Return an empty iterable (list) => no KV cache scales are loaded
  418. # which ultimately defaults to 1.0 scales
  419. logger.warning("Defaulting to KV cache scaling factors = 1.0 "
  420. f"for all layers in TP rank {tp_rank} "
  421. "as an error occurred during loading.")
  422. return []
  423. def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
  424. """convert PySafeSlice object from safetensors to torch.Tensor
  425. PySafeSlice object supports indexing, which is done before loading the
  426. actual tensor and can reduce the amount of memory being read into the
  427. memory. However, it does not support more advanced functionalities
  428. like `.view()` or `.t()`. Therefore, if we need to modify the loaded
  429. tensor with these more complicated operators, we need to convert to
  430. tensor first.
  431. """
  432. if not isinstance(x, torch.Tensor):
  433. x = x[:]
  434. return x
  435. def default_weight_loader(param: torch.Tensor,
  436. loaded_weight: torch.Tensor) -> None:
  437. """Default weight loader."""
  438. try:
  439. if param.numel() == 1 and loaded_weight.numel() == 1:
  440. # Sometimes scalar values aren't considered tensors with shapes
  441. # so if both param and loaded_weight are a scalar,
  442. # "broadcast" instead of copy
  443. param.data.fill_(loaded_weight.item())
  444. else:
  445. assert param.size() == loaded_weight.size(), (
  446. f"Attempted to load weight ({loaded_weight.size()}) "
  447. f"into parameter ({param.size()})")
  448. param.data.copy_(loaded_weight)
  449. except Exception:
  450. # NOTE: This exception is added for the purpose of setting breakpoint to
  451. # debug weight loading issues.
  452. raise
  453. def row_parallel_weight_loader(param: torch.Tensor,
  454. loaded_weight: torch.Tensor) -> None:
  455. """Load weights that are row-parallelized."""
  456. tp_rank = get_tensor_model_parallel_rank()
  457. shard_dim = 0 if param.dim() != 1 else None
  458. if shard_dim is not None:
  459. shard_size = param.data.shape[shard_dim]
  460. start_idx = tp_rank * shard_size
  461. loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size)
  462. return default_weight_loader(param, loaded_weight)
  463. def initialize_dummy_weights(
  464. model: torch.nn.Module,
  465. low: float = -1e-3,
  466. high: float = 1e-3,
  467. seed: int = 1234,
  468. ) -> None:
  469. """Initialize model weights with random values.
  470. The model weights must be randomly initialized for accurate performance
  471. measurements. Additionally, the model weights should not cause NaNs in the
  472. forward pass. We empirically found that initializing the weights with
  473. values between -1e-3 and 1e-3 works well for most models.
  474. We use per-parameter random seed, so that dummy weights are consistent,
  475. even if the model is partitioned across multiple devices. When the seed
  476. is fixed, the random values generated by this function only depends on
  477. the parameter's number of elements and its data type.
  478. """
  479. for param in model.state_dict().values():
  480. if torch.is_floating_point(param):
  481. if current_platform.is_tpu():
  482. # XLA device does not support torch.Generator()
  483. param.uniform_(low, high)
  484. continue
  485. generator = torch.Generator(device=param.data.device)
  486. generator.manual_seed(seed)
  487. if torch.finfo(param.data.dtype).bits < 16:
  488. # uniform_ doesn't support < 16-bit datatypes (FP8)
  489. dtype = param.data.dtype
  490. tmp_param = param.data.to(torch.float16)
  491. tmp_param = tmp_param.uniform_(low, high,
  492. generator=generator).to(dtype)
  493. tmp_param = tmp_param.uniform_(low, high).to(dtype)
  494. param.data.copy_(tmp_param)
  495. else:
  496. param.uniform_(low, high, generator=generator)
  497. def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
  498. """Remap the name of FP8 k/v_scale parameters.
  499. This function handles the remapping of FP8 k/v_scale parameter names.
  500. It detects if the given name ends with a suffix and attempts to remap
  501. it to the expected name format in the model. If the remapped name is not
  502. found in the params_dict, a warning is printed and None is returned.
  503. Args:
  504. name (str): The original loaded checkpoint parameter name.
  505. params_dict (dict): Dictionary containing the model's named parameters.
  506. Returns:
  507. str: The remapped parameter name if successful, or the original name
  508. if no remapping is needed.
  509. None: If the remapped name is not found in params_dict.
  510. """
  511. if name.endswith(".kv_scale"):
  512. print_warning_once(
  513. "DEPRECATED. Found kv_scale in the checkpoint. "
  514. "This format is deprecated in favor of separate k_scale and "
  515. "v_scale tensors and will be removed in a future release. "
  516. "Functionally, we will remap kv_scale to k_scale and duplicate "
  517. "k_scale to v_scale")
  518. # NOTE: we remap the deprecated kv_scale to k_scale
  519. remapped_name = name.replace(".kv_scale", ".attn.k_scale")
  520. if remapped_name not in params_dict:
  521. print_warning_once(
  522. f"Found kv_scale in the checkpoint (e.g. {name}), "
  523. "but not found the expected name in the model "
  524. f"(e.g. {remapped_name}). kv_scale is "
  525. "not loaded.")
  526. return None
  527. return remapped_name
  528. possible_scale_names = [".k_scale", ".v_scale"]
  529. for scale_name in possible_scale_names:
  530. if name.endswith(scale_name):
  531. remapped_name = name.replace(scale_name, f".attn{scale_name}")
  532. if remapped_name not in params_dict:
  533. print_warning_once(
  534. f"Found {scale_name} in the checkpoint (e.g. {name}), "
  535. "but not found the expected name in the model "
  536. f"(e.g. {remapped_name}). {scale_name} is "
  537. "not loaded.")
  538. return None
  539. return remapped_name
  540. # If there were no matches, return the untouched param name
  541. return name