weight_utils.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619
  1. """Utilities for downloading and initializing model weights."""
  2. import fnmatch
  3. import glob
  4. import hashlib
  5. import json
  6. import os
  7. import tempfile
  8. from collections import defaultdict
  9. from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union
  10. import filelock
  11. import gguf
  12. import huggingface_hub.constants
  13. import numpy as np
  14. import torch
  15. from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
  16. from loguru import logger
  17. from safetensors.torch import load_file, safe_open, save_file
  18. from tqdm.auto import tqdm
  19. from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
  20. from aphrodite.common.config import LoadConfig, ModelConfig
  21. from aphrodite.common.utils import print_warning_once
  22. from aphrodite.distributed import get_tensor_model_parallel_rank
  23. from aphrodite.platforms import current_platform
  24. from aphrodite.quantization import QuantizationConfig, get_quantization_config
  25. from aphrodite.quantization.schema import QuantParamSchema
  26. # use system-level temp directory for file locks, so that multiple users
  27. # can share the same lock without error.
  28. # lock files in the temp directory will be automatically deleted when the
  29. # system reboots, so users will not complain about annoying lock files
  30. temp_dir = tempfile.gettempdir()
  31. def enable_hf_transfer():
  32. """automatically activates hf_transfer
  33. """
  34. if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
  35. try:
  36. # enable hf hub transfer if available
  37. import hf_transfer # type: ignore # noqa
  38. huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
  39. except ImportError:
  40. pass
  41. enable_hf_transfer()
  42. class DisabledTqdm(tqdm):
  43. def __init__(self, *args, **kwargs):
  44. super().__init__(*args, **kwargs, disable=True)
  45. def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
  46. lock_dir = cache_dir or temp_dir
  47. os.makedirs(os.path.dirname(lock_dir), exist_ok=True)
  48. model_name = model_name_or_path.replace("/", "-")
  49. hash_name = hashlib.sha256(model_name.encode()).hexdigest()
  50. # add hash to avoid conflict with old users' lock files
  51. lock_file_name = hash_name + model_name + ".lock"
  52. # mode 0o666 is required for the filelock to be shared across users
  53. lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name),
  54. mode=0o666)
  55. return lock
  56. def _shared_pointers(tensors):
  57. ptrs = defaultdict(list)
  58. for k, v in tensors.items():
  59. ptrs[v.data_ptr()].append(k)
  60. failing = []
  61. for _, names in ptrs.items():
  62. if len(names) > 1:
  63. failing.append(names)
  64. return failing
  65. def convert_bin_to_safetensor_file(
  66. pt_filename: str,
  67. sf_filename: str,
  68. ) -> None:
  69. loaded = torch.load(pt_filename, map_location="cpu")
  70. if "state_dict" in loaded:
  71. loaded = loaded["state_dict"]
  72. shared = _shared_pointers(loaded)
  73. for shared_weights in shared:
  74. for name in shared_weights[1:]:
  75. loaded.pop(name)
  76. # For tensors to be contiguous
  77. loaded = {k: v.contiguous() for k, v in loaded.items()}
  78. dirname = os.path.dirname(sf_filename)
  79. os.makedirs(dirname, exist_ok=True)
  80. save_file(loaded, sf_filename, metadata={"format": "pt"})
  81. # check file size
  82. sf_size = os.stat(sf_filename).st_size
  83. pt_size = os.stat(pt_filename).st_size
  84. if (sf_size - pt_size) / pt_size > 0.01:
  85. raise RuntimeError(f"""The file size different is more than 1%:
  86. - {sf_filename}: {sf_size}
  87. - {pt_filename}: {pt_size}
  88. """)
  89. # check if the tensors are the same
  90. reloaded = load_file(sf_filename)
  91. for k in loaded:
  92. pt_tensor = loaded[k]
  93. sf_tensor = reloaded[k]
  94. if not torch.equal(pt_tensor, sf_tensor):
  95. raise RuntimeError(f"The output tensors do not match for key {k}")
  96. # TODO: Move this to other place.
  97. def get_quant_config(model_config: ModelConfig,
  98. load_config: LoadConfig) -> QuantizationConfig:
  99. quant_cls = get_quantization_config(model_config.quantization)
  100. # GGUF doesn't have config file
  101. if model_config.quantization == "gguf":
  102. return quant_cls.from_config({})
  103. # Read the quantization config from the HF model config, if available.
  104. hf_quant_config = getattr(model_config.hf_config, "quantization_config",
  105. None)
  106. if hf_quant_config is None:
  107. # compressed-tensors uses a compressions_config
  108. hf_quant_config = getattr(model_config.hf_config, "compression_config",
  109. None)
  110. if hf_quant_config is not None:
  111. return quant_cls.from_config(hf_quant_config)
  112. # In case of bitsandbytes/QLoRA, get quant config from the adapter model.
  113. if model_config.quantization == "bitsandbytes":
  114. if (not load_config.model_loader_extra_config
  115. or "qlora_adapter_name_or_path"
  116. not in load_config.model_loader_extra_config):
  117. return quant_cls.from_config({"adapter_name_or_path": ""})
  118. model_name_or_path = load_config.model_loader_extra_config[
  119. "qlora_adapter_name_or_path"]
  120. else:
  121. model_name_or_path = model_config.model
  122. is_local = os.path.isdir(model_name_or_path)
  123. if not is_local:
  124. # Download the config files.
  125. with get_lock(model_name_or_path, load_config.download_dir):
  126. hf_folder = snapshot_download(
  127. model_name_or_path,
  128. revision=model_config.revision,
  129. allow_patterns="*.json",
  130. cache_dir=load_config.download_dir,
  131. local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
  132. tqdm_class=DisabledTqdm,
  133. )
  134. else:
  135. hf_folder = model_name_or_path
  136. possible_config_filenames = quant_cls.get_config_filenames()
  137. # If the quantization config is not found, use the default config.
  138. if not possible_config_filenames:
  139. return quant_cls()
  140. config_files = glob.glob(os.path.join(hf_folder, "*.json"))
  141. quant_config_files = [
  142. f for f in config_files if any(
  143. f.endswith(x) for x in possible_config_filenames)
  144. ]
  145. if len(quant_config_files) == 0:
  146. raise ValueError(
  147. f"Cannot find the config file for {model_config.quantization}")
  148. if len(quant_config_files) > 1:
  149. raise ValueError(
  150. f"Found multiple config files for {model_config.quantization}: "
  151. f"{quant_config_files}")
  152. quant_config_file = quant_config_files[0]
  153. with open(quant_config_file, "r") as f:
  154. config = json.load(f)
  155. if model_config.quantization == "bitsandbytes":
  156. config["adapter_name_or_path"] = model_name_or_path
  157. return quant_cls.from_config(config)
  158. def download_weights_from_hf(
  159. model_name_or_path: str,
  160. cache_dir: Optional[str],
  161. allow_patterns: List[str],
  162. revision: Optional[str] = None,
  163. ignore_patterns: Optional[Union[str, List[str]]] = None,
  164. ) -> str:
  165. """Download model weights from Hugging Face Hub.
  166. Args:
  167. model_name_or_path (str): The model name or path.
  168. cache_dir (Optional[str]): The cache directory to store the model
  169. weights. If None, will use HF defaults.
  170. allow_patterns (List[str]): The allowed patterns for the
  171. weight files. Files matched by any of the patterns will be
  172. downloaded.
  173. revision (Optional[str]): The revision of the model.
  174. ignore_patterns (Optional[Union[str, List[str]]]): The patterns to
  175. filter out the weight files. Files matched by any of the patterns
  176. will be ignored.
  177. Returns:
  178. str: The path to the downloaded model weights.
  179. """
  180. if not huggingface_hub.constants.HF_HUB_OFFLINE:
  181. # Before we download we look at that is available:
  182. fs = HfFileSystem()
  183. file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
  184. # depending on what is available we download different things
  185. for pattern in allow_patterns:
  186. matching = fnmatch.filter(file_list, pattern)
  187. if len(matching) > 0:
  188. allow_patterns = [pattern]
  189. break
  190. rank = get_tensor_model_parallel_rank()
  191. if rank == 0:
  192. logger.info(f"Using model weights format {allow_patterns}")
  193. # Use file lock to prevent multiple processes from
  194. # downloading the same model weights at the same time.
  195. with get_lock(model_name_or_path, cache_dir):
  196. hf_folder = snapshot_download(
  197. model_name_or_path,
  198. allow_patterns=allow_patterns,
  199. ignore_patterns=ignore_patterns,
  200. cache_dir=cache_dir,
  201. tqdm_class=DisabledTqdm,
  202. revision=revision,
  203. local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
  204. )
  205. return hf_folder
  206. def download_safetensors_index_file_from_hf(
  207. model_name_or_path: str,
  208. cache_dir: Optional[str],
  209. revision: Optional[str] = None,
  210. ) -> None:
  211. """Download hf safetensors index file from Hugging Face Hub.
  212. Args:
  213. model_name_or_path (str): The model name or path.
  214. cache_dir (Optional[str]): The cache directory to store the model
  215. weights. If None, will use HF defaults.
  216. revision (Optional[str]): The revision of the model.
  217. """
  218. # Use file lock to prevent multiple processes from
  219. # downloading the same model weights at the same time.
  220. with get_lock(model_name_or_path, cache_dir):
  221. try:
  222. # Download the safetensors index file.
  223. hf_hub_download(
  224. repo_id=model_name_or_path,
  225. filename=SAFE_WEIGHTS_INDEX_NAME,
  226. cache_dir=cache_dir,
  227. revision=revision,
  228. local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
  229. )
  230. # If file not found on remote or locally, we should not fail since
  231. # only some models will have SAFE_WEIGHTS_INDEX_NAME.
  232. except huggingface_hub.utils.EntryNotFoundError:
  233. logger.info(f"No {SAFE_WEIGHTS_INDEX_NAME} found in remote.")
  234. except huggingface_hub.utils.LocalEntryNotFoundError:
  235. logger.info(f"No {SAFE_WEIGHTS_INDEX_NAME} found in local cache.")
  236. # For models like Mistral-7B-v0.3, there are both sharded
  237. # safetensors files and a consolidated safetensors file.
  238. # Passing both of these to the weight loader functionality breaks.
  239. # So, we use the SAFE_WEIGHTS_INDEX_NAME to
  240. # look up which safetensors files should be used.
  241. def filter_duplicate_safetensors_files(hf_weights_files: List[str],
  242. hf_folder: str) -> List[str]:
  243. # model.safetensors.index.json is a mapping from keys in the
  244. # torch state_dict to safetensors file holding that weight.
  245. index_file_name = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME)
  246. if not os.path.isfile(index_file_name):
  247. return hf_weights_files
  248. # Iterate through the weight_map (weight_name: safetensors files)
  249. # to identify weights that we should use.
  250. with open(index_file_name) as index_file:
  251. weight_map = json.load(index_file)["weight_map"]
  252. weight_files_in_index = set()
  253. for weight_name in weight_map:
  254. weight_files_in_index.add(
  255. os.path.join(hf_folder, weight_map[weight_name]))
  256. # Filter out any fields that are not found in the index file.
  257. hf_weights_files = [
  258. f for f in hf_weights_files if f in weight_files_in_index
  259. ]
  260. return hf_weights_files
  261. def filter_files_not_needed_for_inference(
  262. hf_weights_files: List[str]) -> List[str]:
  263. """
  264. Exclude files that are not needed for inference.
  265. See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
  266. """
  267. blacklist = [
  268. "training_args.bin",
  269. "optimizer.bin",
  270. "optimizer.pt",
  271. "scheduler.pt",
  272. "scaler.pt",
  273. ]
  274. hf_weights_files = [
  275. f for f in hf_weights_files
  276. if not any(f.endswith(x) for x in blacklist)
  277. ]
  278. return hf_weights_files
  279. def np_cache_weights_iterator(
  280. model_name_or_path: str, cache_dir: Optional[str], hf_folder: str,
  281. hf_weights_files: List[str]
  282. ) -> Generator[Tuple[str, torch.Tensor], None, None]:
  283. """Iterate over the weights in the model np files.
  284. Will dump the model weights to numpy files if they are not already dumped.
  285. """
  286. enable_tqdm = False #not torch.distributed.is_initialized(
  287. #) or torch.distributed.get_rank() == 0
  288. # Convert the model weights from torch tensors to numpy arrays for
  289. # faster loading.
  290. np_folder = os.path.join(hf_folder, "np")
  291. os.makedirs(np_folder, exist_ok=True)
  292. weight_names_file = os.path.join(np_folder, "weight_names.json")
  293. # Use file lock to prevent multiple processes from
  294. # dumping the same model weights to numpy at the same time.
  295. with get_lock(model_name_or_path, cache_dir):
  296. if not os.path.exists(weight_names_file):
  297. weight_names = []
  298. for bin_file in tqdm(
  299. hf_weights_files,
  300. desc="Loading np_cache checkpoint shards",
  301. disable=not enable_tqdm,
  302. ):
  303. state = torch.load(bin_file, map_location="cpu")
  304. for name, param in state.items():
  305. param_path = os.path.join(np_folder, name)
  306. with open(param_path, "wb") as f:
  307. np.save(f, param.cpu().detach().numpy())
  308. weight_names.append(name)
  309. with open(weight_names_file, "w") as f:
  310. json.dump(weight_names, f)
  311. with open(weight_names_file, "r") as f:
  312. weight_names = json.load(f)
  313. for name in weight_names:
  314. param_path = os.path.join(np_folder, name)
  315. with open(param_path, "rb") as f:
  316. param = np.load(f)
  317. yield name, torch.from_numpy(param)
  318. def safetensors_weights_iterator(
  319. hf_weights_files: List[str]
  320. ) -> Generator[Tuple[str, torch.Tensor], None, None]:
  321. """Iterate over the weights in the model safetensor files."""
  322. enable_tqdm = False #not torch.distributed.is_initialized(
  323. #) or torch.distributed.get_rank() == 0
  324. for st_file in tqdm(
  325. hf_weights_files,
  326. desc="Loading safetensors checkpoint shards",
  327. disable=not enable_tqdm,
  328. ):
  329. with safe_open(st_file, framework="pt") as f:
  330. for name in f.keys(): # noqa: SIM118
  331. param = f.get_tensor(name)
  332. yield name, param
  333. def pt_weights_iterator(
  334. hf_weights_files: List[str]
  335. ) -> Generator[Tuple[str, torch.Tensor], None, None]:
  336. """Iterate over the weights in the model bin/pt files."""
  337. enable_tqdm = False #not torch.distributed.is_initialized(
  338. #) or torch.distributed.get_rank() == 0
  339. for bin_file in tqdm(
  340. hf_weights_files,
  341. desc="Loading pt checkpoint shards",
  342. disable=not enable_tqdm,
  343. ):
  344. state = torch.load(bin_file, map_location="cpu")
  345. for name, param in state.items():
  346. yield name, param
  347. del state
  348. torch.cuda.empty_cache()
  349. def get_gguf_extra_tensor_names(
  350. gguf_file: str, gguf_to_hf_name_map: Dict[str, str]) -> List[str]:
  351. reader = gguf.GGUFReader(gguf_file)
  352. expected_gguf_keys = set(gguf_to_hf_name_map.keys())
  353. exact_gguf_keys = set([tensor.name for tensor in reader.tensors])
  354. extra_keys = expected_gguf_keys - exact_gguf_keys
  355. return [gguf_to_hf_name_map[key] for key in extra_keys]
  356. def gguf_quant_weights_iterator(
  357. gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
  358. ) -> Generator[Tuple[str, torch.Tensor], None, None]:
  359. """
  360. Iterate over the quant weights in the model gguf files and convert
  361. them to torch tensors
  362. """
  363. reader = gguf.GGUFReader(gguf_file)
  364. for tensor in reader.tensors:
  365. if tensor.name in gguf_to_hf_name_map:
  366. weight_type = tensor.tensor_type
  367. name = gguf_to_hf_name_map[tensor.name]
  368. if weight_type.name != "F32":
  369. weight_type_name = name.replace("weight", "qweight_type")
  370. weight_type = torch.tensor(weight_type)
  371. yield weight_type_name, weight_type
  372. for tensor in reader.tensors:
  373. if tensor.name in gguf_to_hf_name_map:
  374. weight = tensor.data
  375. weight_type = tensor.tensor_type
  376. name = gguf_to_hf_name_map[tensor.name]
  377. if weight_type.name != "F32":
  378. name = name.replace("weight", "qweight")
  379. param = torch.tensor(weight)
  380. yield name, param
  381. def kv_cache_scales_loader(
  382. filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int,
  383. model_type: Optional[str]) -> Iterable[Tuple[int, float]]:
  384. """
  385. A simple utility to read in KV cache scaling factors that have been
  386. previously serialized to disk. Used by the model to populate the appropriate
  387. KV cache scaling factors. The serialization should represent a dictionary
  388. whose keys are the TP ranks and values are another dictionary mapping layers
  389. to their KV cache scaling factors.
  390. Keep this function in sync with the output of examples/fp8/extract_scales.py
  391. """
  392. try:
  393. with open(filename) as f:
  394. context = {
  395. "model_type": model_type,
  396. "num_hidden_layers": num_hidden_layers,
  397. "tp_rank": tp_rank,
  398. "tp_size": tp_size,
  399. }
  400. schema_dct = json.load(f)
  401. schema = QuantParamSchema.model_validate(schema_dct,
  402. context=context)
  403. layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
  404. return layer_scales_map.items()
  405. except FileNotFoundError:
  406. logger.error(f"File or directory '{filename}' not found.")
  407. except json.JSONDecodeError:
  408. logger.error(f"Error decoding JSON in file '{filename}'.")
  409. except Exception as e:
  410. logger.error(f"An error occurred while reading '{filename}': {e}")
  411. # This section is reached if and only if any of the excepts are hit
  412. # Return an empty iterable (list) => no KV cache scales are loaded
  413. # which ultimately defaults to 1.0 scales
  414. logger.warning("Defaulting to KV cache scaling factors = 1.0 "
  415. f"for all layers in TP rank {tp_rank} "
  416. "as an error occurred during loading.")
  417. return []
  418. def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
  419. """convert PySafeSlice object from safetensors to torch.Tensor
  420. PySafeSlice object supports indexing, which is done before loading the
  421. actual tensor and can reduce the amount of memory being read into the
  422. memory. However, it does not support more advanced functionalities
  423. like `.view()` or `.t()`. Therefore, if we need to modify the loaded
  424. tensor with these more complicated operators, we need to convert to
  425. tensor first.
  426. """
  427. if not isinstance(x, torch.Tensor):
  428. x = x[:]
  429. return x
  430. def default_weight_loader(param: torch.Tensor,
  431. loaded_weight: torch.Tensor) -> None:
  432. """Default weight loader."""
  433. try:
  434. if param.numel() == 1 and loaded_weight.numel() == 1:
  435. # Sometimes scalar values aren't considered tensors with shapes
  436. # so if both param and loaded_weight are a scalar,
  437. # "broadcast" instead of copy
  438. param.data.fill_(loaded_weight.item())
  439. else:
  440. assert param.size() == loaded_weight.size(), (
  441. f"Attempted to load weight ({loaded_weight.size()}) "
  442. f"into parameter ({param.size()})")
  443. param.data.copy_(loaded_weight)
  444. except Exception:
  445. # NOTE: This exception is added for the purpose of setting breakpoint to
  446. # debug weight loading issues.
  447. raise
  448. def row_parallel_weight_loader(param: torch.Tensor,
  449. loaded_weight: torch.Tensor) -> None:
  450. """Load weights that are row-parallelized."""
  451. tp_rank = get_tensor_model_parallel_rank()
  452. shard_dim = 0 if param.dim() != 1 else None
  453. if shard_dim is not None:
  454. shard_size = param.data.shape[shard_dim]
  455. start_idx = tp_rank * shard_size
  456. loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size)
  457. return default_weight_loader(param, loaded_weight)
  458. def initialize_dummy_weights(
  459. model: torch.nn.Module,
  460. low: float = -1e-3,
  461. high: float = 1e-3,
  462. seed: int = 1234,
  463. ) -> None:
  464. """Initialize model weights with random values.
  465. The model weights must be randomly initialized for accurate performance
  466. measurements. Additionally, the model weights should not cause NaNs in the
  467. forward pass. We empirically found that initializing the weights with
  468. values between -1e-3 and 1e-3 works well for most models.
  469. We use per-parameter random seed, so that dummy weights are consistent,
  470. even if the model is partitioned across multiple devices. When the seed
  471. is fixed, the random values generated by this function only depends on
  472. the parameter's number of elements and its data type.
  473. """
  474. for param in model.state_dict().values():
  475. if torch.is_floating_point(param):
  476. if current_platform.is_tpu():
  477. # XLA device does not support torch.Generator()
  478. param.uniform_(low, high)
  479. continue
  480. generator = torch.Generator(device=param.data.device)
  481. generator.manual_seed(seed)
  482. if torch.finfo(param.data.dtype).bits < 16:
  483. # uniform_ doesn't support < 16-bit datatypes (FP8)
  484. dtype = param.data.dtype
  485. tmp_param = param.data.to(torch.float16)
  486. tmp_param = tmp_param.uniform_(low, high,
  487. generator=generator).to(dtype)
  488. tmp_param = tmp_param.uniform_(low, high).to(dtype)
  489. param.data.copy_(tmp_param)
  490. else:
  491. param.uniform_(low, high, generator=generator)
  492. def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
  493. """Remap the name of FP8 k/v_scale parameters.
  494. This function handles the remapping of FP8 k/v_scale parameter names.
  495. It detects if the given name ends with a suffix and attempts to remap
  496. it to the expected name format in the model. If the remapped name is not
  497. found in the params_dict, a warning is printed and None is returned.
  498. Args:
  499. name (str): The original loaded checkpoint parameter name.
  500. params_dict (dict): Dictionary containing the model's named parameters.
  501. Returns:
  502. str: The remapped parameter name if successful, or the original name
  503. if no remapping is needed.
  504. None: If the remapped name is not found in params_dict.
  505. """
  506. if name.endswith(".kv_scale"):
  507. print_warning_once(
  508. "DEPRECATED. Found kv_scale in the checkpoint. "
  509. "This format is deprecated in favor of separate k_scale and "
  510. "v_scale tensors and will be removed in a future release. "
  511. "Functionally, we will remap kv_scale to k_scale and duplicate "
  512. "k_scale to v_scale")
  513. # NOTE: we remap the deprecated kv_scale to k_scale
  514. remapped_name = name.replace(".kv_scale", ".attn.k_scale")
  515. if remapped_name not in params_dict:
  516. print_warning_once(
  517. f"Found kv_scale in the checkpoint (e.g. {name}), "
  518. "but not found the expected name in the model "
  519. f"(e.g. {remapped_name}). kv_scale is "
  520. "not loaded.")
  521. return None
  522. return remapped_name
  523. possible_scale_names = [".k_scale", ".v_scale"]
  524. for scale_name in possible_scale_names:
  525. if name.endswith(scale_name):
  526. remapped_name = name.replace(scale_name, f".attn{scale_name}")
  527. if remapped_name not in params_dict:
  528. print_warning_once(
  529. f"Found {scale_name} in the checkpoint (e.g. {name}), "
  530. "but not found the expected name in the model "
  531. f"(e.g. {remapped_name}). {scale_name} is "
  532. "not loaded.")
  533. return None
  534. return remapped_name
  535. # If there were no matches, return the untouched param name
  536. return name