weight_utils.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514
  1. """Utilities for downloading and initializing model weights."""
  2. import fnmatch
  3. import glob
  4. import hashlib
  5. import json
  6. import os
  7. import tempfile
  8. from collections import defaultdict
  9. from typing import Any, Generator, Iterable, List, Optional, Tuple
  10. import filelock
  11. import huggingface_hub.constants
  12. import numpy as np
  13. import torch
  14. from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
  15. from loguru import logger
  16. from safetensors.torch import load_file, safe_open, save_file
  17. from tqdm.auto import tqdm
  18. from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
  19. from aphrodite.common.config import LoadConfig, ModelConfig
  20. from aphrodite.common.utils import print_warning_once
  21. from aphrodite.quantization import QuantizationConfig, get_quantization_config
  22. from aphrodite.quantization.schema import QuantParamSchema
  23. # use system-level temp directory for file locks, so that multiple users
  24. # can share the same lock without error.
  25. # lock files in the temp directory will be automatically deleted when the
  26. # system reboots, so users will not complain about annoying lock files
  27. temp_dir = tempfile.gettempdir()
  28. def enable_hf_transfer():
  29. """automatically activates hf_transfer
  30. """
  31. if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
  32. try:
  33. # enable hf hub transfer if available
  34. import hf_transfer # type: ignore # noqa
  35. huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
  36. except ImportError:
  37. pass
  38. enable_hf_transfer()
  39. class DisabledTqdm(tqdm):
  40. def __init__(self, *args, **kwargs):
  41. super().__init__(*args, **kwargs, disable=True)
  42. def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
  43. lock_dir = cache_dir or temp_dir
  44. os.makedirs(os.path.dirname(lock_dir), exist_ok=True)
  45. model_name = model_name_or_path.replace("/", "-")
  46. hash_name = hashlib.sha256(model_name.encode()).hexdigest()
  47. # add hash to avoid conflict with old users' lock files
  48. lock_file_name = hash_name + model_name + ".lock"
  49. # mode 0o666 is required for the filelock to be shared across users
  50. lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name),
  51. mode=0o666)
  52. return lock
  53. def _shared_pointers(tensors):
  54. ptrs = defaultdict(list)
  55. for k, v in tensors.items():
  56. ptrs[v.data_ptr()].append(k)
  57. failing = []
  58. for _, names in ptrs.items():
  59. if len(names) > 1:
  60. failing.append(names)
  61. return failing
  62. def convert_bin_to_safetensor_file(
  63. pt_filename: str,
  64. sf_filename: str,
  65. ) -> None:
  66. loaded = torch.load(pt_filename, map_location="cpu")
  67. if "state_dict" in loaded:
  68. loaded = loaded["state_dict"]
  69. shared = _shared_pointers(loaded)
  70. for shared_weights in shared:
  71. for name in shared_weights[1:]:
  72. loaded.pop(name)
  73. # For tensors to be contiguous
  74. loaded = {k: v.contiguous() for k, v in loaded.items()}
  75. dirname = os.path.dirname(sf_filename)
  76. os.makedirs(dirname, exist_ok=True)
  77. save_file(loaded, sf_filename, metadata={"format": "pt"})
  78. # check file size
  79. sf_size = os.stat(sf_filename).st_size
  80. pt_size = os.stat(pt_filename).st_size
  81. if (sf_size - pt_size) / pt_size > 0.01:
  82. raise RuntimeError(f"""The file size different is more than 1%:
  83. - {sf_filename}: {sf_size}
  84. - {pt_filename}: {pt_size}
  85. """)
  86. # check if the tensors are the same
  87. reloaded = load_file(sf_filename)
  88. for k in loaded:
  89. pt_tensor = loaded[k]
  90. sf_tensor = reloaded[k]
  91. if not torch.equal(pt_tensor, sf_tensor):
  92. raise RuntimeError(f"The output tensors do not match for key {k}")
  93. # TODO: Move this to other place.
  94. def get_quant_config(model_config: ModelConfig,
  95. load_config: LoadConfig) -> QuantizationConfig:
  96. quant_cls = get_quantization_config(model_config.quantization)
  97. # Read the quantization config from the HF model config, if available.
  98. hf_quant_config = getattr(model_config.hf_config, "quantization_config",
  99. None)
  100. if hf_quant_config is None:
  101. # compressed-tensors uses a compressions_config
  102. hf_quant_config = getattr(model_config.hf_config, "compression_config",
  103. None)
  104. if hf_quant_config is not None:
  105. return quant_cls.from_config(hf_quant_config)
  106. # In case of bitsandbytes/QLoRA, get quant config from the adapter model.
  107. if model_config.quantization == "bitsandbytes":
  108. if (not load_config.model_loader_extra_config
  109. or "qlora_adapter_name_or_path"
  110. not in load_config.model_loader_extra_config):
  111. return quant_cls.from_config({"adapter_name_or_path": ""})
  112. model_name_or_path = load_config.model_loader_extra_config[
  113. "qlora_adapter_name_or_path"]
  114. else:
  115. model_name_or_path = model_config.model
  116. is_local = os.path.isdir(model_name_or_path)
  117. if not is_local:
  118. # Download the config files.
  119. with get_lock(model_name_or_path, load_config.download_dir):
  120. hf_folder = snapshot_download(
  121. model_name_or_path,
  122. revision=model_config.revision,
  123. allow_patterns="*.json",
  124. cache_dir=load_config.download_dir,
  125. local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
  126. tqdm_class=DisabledTqdm,
  127. )
  128. else:
  129. hf_folder = model_name_or_path
  130. possible_config_filenames = quant_cls.get_config_filenames()
  131. # If the quantization config is not found, use the default config.
  132. if not possible_config_filenames:
  133. return quant_cls()
  134. config_files = glob.glob(os.path.join(hf_folder, "*.json"))
  135. quant_config_files = [
  136. f for f in config_files if any(
  137. f.endswith(x) for x in possible_config_filenames)
  138. ]
  139. if len(quant_config_files) == 0:
  140. raise ValueError(
  141. f"Cannot find the config file for {model_config.quantization}")
  142. if len(quant_config_files) > 1:
  143. raise ValueError(
  144. f"Found multiple config files for {model_config.quantization}: "
  145. f"{quant_config_files}")
  146. quant_config_file = quant_config_files[0]
  147. with open(quant_config_file, "r") as f:
  148. config = json.load(f)
  149. if model_config.quantization == "bitsandbytes":
  150. config["adapter_name_or_path"] = model_name_or_path
  151. return quant_cls.from_config(config)
  152. def download_weights_from_hf(
  153. model_name_or_path: str,
  154. cache_dir: Optional[str],
  155. allow_patterns: List[str],
  156. revision: Optional[str] = None,
  157. ) -> str:
  158. """Download model weights from Hugging Face Hub.
  159. Args:
  160. model_name_or_path (str): The model name or path.
  161. cache_dir (Optional[str]): The cache directory to store the model
  162. weights. If None, will use HF defaults.
  163. allow_patterns (List[str]): The allowed patterns for the
  164. weight files. Files matched by any of the patterns will be
  165. downloaded.
  166. revision (Optional[str]): The revision of the model.
  167. Returns:
  168. str: The path to the downloaded model weights.
  169. """
  170. if not huggingface_hub.constants.HF_HUB_OFFLINE:
  171. # Before we download we look at that is available:
  172. fs = HfFileSystem()
  173. file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
  174. # depending on what is available we download different things
  175. for pattern in allow_patterns:
  176. matching = fnmatch.filter(file_list, pattern)
  177. if len(matching) > 0:
  178. allow_patterns = [pattern]
  179. break
  180. logger.info(f"Using model weights format {allow_patterns}")
  181. # Use file lock to prevent multiple processes from
  182. # downloading the same model weights at the same time.
  183. with get_lock(model_name_or_path, cache_dir):
  184. hf_folder = snapshot_download(
  185. model_name_or_path,
  186. allow_patterns=allow_patterns,
  187. cache_dir=cache_dir,
  188. tqdm_class=DisabledTqdm,
  189. revision=revision,
  190. local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
  191. )
  192. return hf_folder
  193. def download_safetensors_index_file_from_hf(
  194. model_name_or_path: str,
  195. cache_dir: Optional[str],
  196. revision: Optional[str] = None,
  197. ) -> None:
  198. """Download hf safetensors index file from Hugging Face Hub.
  199. Args:
  200. model_name_or_path (str): The model name or path.
  201. cache_dir (Optional[str]): The cache directory to store the model
  202. weights. If None, will use HF defaults.
  203. revision (Optional[str]): The revision of the model.
  204. """
  205. # Use file lock to prevent multiple processes from
  206. # downloading the same model weights at the same time.
  207. with get_lock(model_name_or_path, cache_dir):
  208. try:
  209. # Download the safetensors index file.
  210. hf_hub_download(
  211. repo_id=model_name_or_path,
  212. filename=SAFE_WEIGHTS_INDEX_NAME,
  213. cache_dir=cache_dir,
  214. revision=revision,
  215. local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
  216. )
  217. # If file not found on remote or locally, we should not fail since
  218. # only some models will have SAFE_WEIGHTS_INDEX_NAME.
  219. except huggingface_hub.utils.EntryNotFoundError:
  220. logger.info(f"No {SAFE_WEIGHTS_INDEX_NAME} found in remote.")
  221. except huggingface_hub.utils.LocalEntryNotFoundError:
  222. logger.info(f"No {SAFE_WEIGHTS_INDEX_NAME} found in local cache.")
  223. # For models like Mistral-7B-v0.3, there are both sharded
  224. # safetensors files and a consolidated safetensors file.
  225. # Passing both of these to the weight loader functionality breaks.
  226. # So, we use the SAFE_WEIGHTS_INDEX_NAME to
  227. # look up which safetensors files should be used.
  228. def filter_duplicate_safetensors_files(hf_weights_files: List[str],
  229. hf_folder: str) -> List[str]:
  230. # model.safetensors.index.json is a mapping from keys in the
  231. # torch state_dict to safetensors file holding that weight.
  232. index_file_name = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME)
  233. if not os.path.isfile(index_file_name):
  234. return hf_weights_files
  235. # Iterate through the weight_map (weight_name: safetensors files)
  236. # to identify weights that we should use.
  237. with open(index_file_name) as index_file:
  238. weight_map = json.load(index_file)["weight_map"]
  239. weight_files_in_index = set()
  240. for weight_name in weight_map:
  241. weight_files_in_index.add(
  242. os.path.join(hf_folder, weight_map[weight_name]))
  243. # Filter out any fields that are not found in the index file.
  244. hf_weights_files = [
  245. f for f in hf_weights_files if f in weight_files_in_index
  246. ]
  247. return hf_weights_files
  248. def filter_files_not_needed_for_inference(
  249. hf_weights_files: List[str]) -> List[str]:
  250. """
  251. Exclude files that are not needed for inference.
  252. See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
  253. """
  254. blacklist = [
  255. "training_args.bin",
  256. "optimizer.bin",
  257. "optimizer.pt",
  258. "scheduler.pt",
  259. "scaler.pt",
  260. ]
  261. hf_weights_files = [
  262. f for f in hf_weights_files
  263. if not any(f.endswith(x) for x in blacklist)
  264. ]
  265. return hf_weights_files
  266. def np_cache_weights_iterator(
  267. model_name_or_path: str, cache_dir: Optional[str], hf_folder: str,
  268. hf_weights_files: List[str]
  269. ) -> Generator[Tuple[str, torch.Tensor], None, None]:
  270. """Iterate over the weights in the model np files.
  271. Will dump the model weights to numpy files if they are not already dumped.
  272. """
  273. # Convert the model weights from torch tensors to numpy arrays for
  274. # faster loading.
  275. np_folder = os.path.join(hf_folder, "np")
  276. os.makedirs(np_folder, exist_ok=True)
  277. weight_names_file = os.path.join(np_folder, "weight_names.json")
  278. # Use file lock to prevent multiple processes from
  279. # dumping the same model weights to numpy at the same time.
  280. with get_lock(model_name_or_path, cache_dir):
  281. if not os.path.exists(weight_names_file):
  282. weight_names = []
  283. for bin_file in hf_weights_files:
  284. state = torch.load(bin_file, map_location="cpu")
  285. for name, param in state.items():
  286. param_path = os.path.join(np_folder, name)
  287. with open(param_path, "wb") as f:
  288. np.save(f, param.cpu().detach().numpy())
  289. weight_names.append(name)
  290. with open(weight_names_file, "w") as f:
  291. json.dump(weight_names, f)
  292. with open(weight_names_file, "r") as f:
  293. weight_names = json.load(f)
  294. for name in weight_names:
  295. param_path = os.path.join(np_folder, name)
  296. with open(param_path, "rb") as f:
  297. param = np.load(f)
  298. yield name, torch.from_numpy(param)
  299. def safetensors_weights_iterator(
  300. hf_weights_files: List[str]
  301. ) -> Generator[Tuple[str, torch.Tensor], None, None]:
  302. """Iterate over the weights in the model safetensor files."""
  303. for st_file in hf_weights_files:
  304. with safe_open(st_file, framework="pt") as f:
  305. for name in f.keys(): # noqa: SIM118
  306. param = f.get_tensor(name)
  307. yield name, param
  308. def pt_weights_iterator(
  309. hf_weights_files: List[str]
  310. ) -> Generator[Tuple[str, torch.Tensor], None, None]:
  311. """Iterate over the weights in the model bin/pt files."""
  312. for bin_file in hf_weights_files:
  313. state = torch.load(bin_file, map_location="cpu")
  314. for name, param in state.items():
  315. yield name, param
  316. del state
  317. torch.cuda.empty_cache()
  318. def kv_cache_scales_loader(
  319. filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int,
  320. model_type: Optional[str]) -> Iterable[Tuple[int, float]]:
  321. """
  322. A simple utility to read in KV cache scaling factors that have been
  323. previously serialized to disk. Used by the model to populate the appropriate
  324. KV cache scaling factors. The serialization should represent a dictionary
  325. whose keys are the TP ranks and values are another dictionary mapping layers
  326. to their KV cache scaling factors.
  327. Keep this function in sync with the output of examples/fp8/extract_scales.py
  328. """
  329. try:
  330. with open(filename) as f:
  331. context = {
  332. "model_type": model_type,
  333. "num_hidden_layers": num_hidden_layers,
  334. "tp_rank": tp_rank,
  335. "tp_size": tp_size,
  336. }
  337. schema_dct = json.load(f)
  338. schema = QuantParamSchema.model_validate(schema_dct,
  339. context=context)
  340. layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
  341. return layer_scales_map.items()
  342. except FileNotFoundError:
  343. logger.error(f"File or directory '{filename}' not found.")
  344. except json.JSONDecodeError:
  345. logger.error(f"Error decoding JSON in file '{filename}'.")
  346. except Exception as e:
  347. logger.error(f"An error occurred while reading '{filename}': {e}")
  348. # This section is reached if and only if any of the excepts are hit
  349. # Return an empty iterable (list) => no KV cache scales are loaded
  350. # which ultimately defaults to 1.0 scales
  351. logger.warning("Defaulting to KV cache scaling factors = 1.0 "
  352. f"for all layers in TP rank {tp_rank} "
  353. "as an error occurred during loading.")
  354. return []
  355. def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
  356. """convert PySafeSlice object from safetensors to torch.Tensor
  357. PySafeSlice object supports indexing, which is done before loading the
  358. actual tensor and can reduce the amount of memory being read into the
  359. memory. However, it does not support more advanced functionalities
  360. like `.view()` or `.t()`. Therefore, if we need to modify the loaded
  361. tensor with these more complicated operators, we need to convert to
  362. tensor first.
  363. """
  364. if not isinstance(x, torch.Tensor):
  365. x = x[:]
  366. return x
  367. def default_weight_loader(param: torch.Tensor,
  368. loaded_weight: torch.Tensor) -> None:
  369. """Default weight loader."""
  370. assert param.size() == loaded_weight.size()
  371. param.data.copy_(loaded_weight)
  372. def initialize_dummy_weights(
  373. model: torch.nn.Module,
  374. low: float = -1e-3,
  375. high: float = 1e-3,
  376. seed: int = 1234,
  377. ) -> None:
  378. """Initialize model weights with random values.
  379. The model weights must be randomly initialized for accurate performance
  380. measurements. Additionally, the model weights should not cause NaNs in the
  381. forward pass. We empirically found that initializing the weights with
  382. values between -1e-3 and 1e-3 works well for most models.
  383. We use per-parameter random seed, so that dummy weights are consistent,
  384. even if the model is partitioned across multiple devices. When the seed
  385. is fixed, the random values generated by this function only depends on
  386. the parameter's number of elements and its data type.
  387. """
  388. for param in model.state_dict().values():
  389. if torch.is_floating_point(param):
  390. generator = torch.Generator(device=param.data.device)
  391. generator.manual_seed(seed)
  392. if torch.finfo(param.data.dtype).bits < 16:
  393. # uniform_ doesn't support < 16-bit datatypes (FP8)
  394. dtype = param.data.dtype
  395. tmp_param = param.data.to(torch.float16)
  396. tmp_param = tmp_param.uniform_(low, high,
  397. generator=generator).to(dtype)
  398. tmp_param = tmp_param.uniform_(low, high).to(dtype)
  399. param.data.copy_(tmp_param)
  400. else:
  401. param.uniform_(low, high, generator=generator)
  402. def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
  403. """Remap the name of FP8 k/v_scale parameters.
  404. This function handles the remapping of FP8 k/v_scale parameter names.
  405. It detects if the given name ends with a suffix and attempts to remap
  406. it to the expected name format in the model. If the remapped name is not
  407. found in the params_dict, a warning is printed and None is returned.
  408. Args:
  409. name (str): The original loaded checkpoint parameter name.
  410. params_dict (dict): Dictionary containing the model's named parameters.
  411. Returns:
  412. str: The remapped parameter name if successful, or the original name
  413. if no remapping is needed.
  414. None: If the remapped name is not found in params_dict.
  415. """
  416. if name.endswith(".kv_scale"):
  417. print_warning_once(
  418. "DEPRECATED. Found kv_scale in the checkpoint. "
  419. "This format is deprecated in favor of separate k_scale and "
  420. "v_scale tensors and will be removed in a future release. "
  421. "Functionally, we will remap kv_scale to k_scale and duplicate "
  422. "k_scale to v_scale")
  423. # NOTE: we remap the deprecated kv_scale to k_scale
  424. remapped_name = name.replace(".kv_scale", ".attn.k_scale")
  425. if remapped_name not in params_dict:
  426. print_warning_once(
  427. f"Found kv_scale in the checkpoint (e.g. {name}), "
  428. "but not found the expected name in the model "
  429. f"(e.g. {remapped_name}). kv_scale is "
  430. "not loaded.")
  431. return None
  432. return remapped_name
  433. possible_scale_names = [".k_scale", ".v_scale"]
  434. for scale_name in possible_scale_names:
  435. if name.endswith(scale_name):
  436. remapped_name = name.replace(scale_name, f".attn{scale_name}")
  437. if remapped_name not in params_dict:
  438. print_warning_once(
  439. f"Found {scale_name} in the checkpoint (e.g. {name}), "
  440. "but not found the expected name in the model "
  441. f"(e.g. {remapped_name}). {scale_name} is "
  442. "not loaded.")
  443. return None
  444. return remapped_name
  445. # If there were no matches, return the untouched param name
  446. return name