weight_utils.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443
  1. """Utilities for downloading and initializing model weights."""
  2. import fnmatch
  3. import glob
  4. import hashlib
  5. import json
  6. import os
  7. import tempfile
  8. from collections import defaultdict
  9. from typing import Any, Generator, Iterable, List, Optional, Tuple
  10. import filelock
  11. import huggingface_hub.constants
  12. import numpy as np
  13. import torch
  14. from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
  15. from loguru import logger
  16. from safetensors.torch import load_file, safe_open, save_file
  17. from tqdm.auto import tqdm
  18. from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
  19. from aphrodite.common.config import LoadConfig, ModelConfig
  20. from aphrodite.quantization import QuantizationConfig, get_quantization_config
  21. from aphrodite.quantization.schema import QuantParamSchema
  22. # use system-level temp directory for file locks, so that multiple users
  23. # can share the same lock without error.
  24. # lock files in the temp directory will be automatically deleted when the
  25. # system reboots, so users will not complain about annoying lock files
  26. temp_dir = tempfile.gettempdir()
  27. def enable_hf_transfer():
  28. """automatically activates hf_transfer
  29. """
  30. if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
  31. try:
  32. # enable hf hub transfer if available
  33. import hf_transfer # type: ignore # noqa
  34. huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
  35. except ImportError:
  36. pass
  37. enable_hf_transfer()
  38. class DisabledTqdm(tqdm):
  39. def __init__(self, *args, **kwargs):
  40. super().__init__(*args, **kwargs, disable=True)
  41. def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
  42. lock_dir = cache_dir or temp_dir
  43. os.makedirs(os.path.dirname(lock_dir), exist_ok=True)
  44. model_name = model_name_or_path.replace("/", "-")
  45. hash_name = hashlib.sha256(model_name.encode()).hexdigest()
  46. # add hash to avoid conflict with old users' lock files
  47. lock_file_name = hash_name + model_name + ".lock"
  48. # mode 0o666 is required for the filelock to be shared across users
  49. lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name),
  50. mode=0o666)
  51. return lock
  52. def _shared_pointers(tensors):
  53. ptrs = defaultdict(list)
  54. for k, v in tensors.items():
  55. ptrs[v.data_ptr()].append(k)
  56. failing = []
  57. for _, names in ptrs.items():
  58. if len(names) > 1:
  59. failing.append(names)
  60. return failing
  61. def convert_bin_to_safetensor_file(
  62. pt_filename: str,
  63. sf_filename: str,
  64. ) -> None:
  65. loaded = torch.load(pt_filename, map_location="cpu")
  66. if "state_dict" in loaded:
  67. loaded = loaded["state_dict"]
  68. shared = _shared_pointers(loaded)
  69. for shared_weights in shared:
  70. for name in shared_weights[1:]:
  71. loaded.pop(name)
  72. # For tensors to be contiguous
  73. loaded = {k: v.contiguous() for k, v in loaded.items()}
  74. dirname = os.path.dirname(sf_filename)
  75. os.makedirs(dirname, exist_ok=True)
  76. save_file(loaded, sf_filename, metadata={"format": "pt"})
  77. # check file size
  78. sf_size = os.stat(sf_filename).st_size
  79. pt_size = os.stat(pt_filename).st_size
  80. if (sf_size - pt_size) / pt_size > 0.01:
  81. raise RuntimeError(f"""The file size different is more than 1%:
  82. - {sf_filename}: {sf_size}
  83. - {pt_filename}: {pt_size}
  84. """)
  85. # check if the tensors are the same
  86. reloaded = load_file(sf_filename)
  87. for k in loaded:
  88. pt_tensor = loaded[k]
  89. sf_tensor = reloaded[k]
  90. if not torch.equal(pt_tensor, sf_tensor):
  91. raise RuntimeError(f"The output tensors do not match for key {k}")
  92. # TODO: Move this to other place.
  93. def get_quant_config(model_config: ModelConfig,
  94. load_config: LoadConfig) -> QuantizationConfig:
  95. quant_cls = get_quantization_config(model_config.quantization)
  96. # Read the quantization config from the HF model config, if available.
  97. hf_quant_config = getattr(model_config.hf_config, "quantization_config",
  98. None)
  99. if hf_quant_config is None:
  100. compression_config = getattr(model_config.hf_config,
  101. "compression_config", None)
  102. if compression_config is not None:
  103. hf_quant_config = compression_config.get("quantization_config",
  104. None)
  105. if hf_quant_config is not None:
  106. return quant_cls.from_config(hf_quant_config)
  107. model_name_or_path = model_config.model
  108. is_local = os.path.isdir(model_name_or_path)
  109. if not is_local:
  110. # Download the config files.
  111. with get_lock(model_name_or_path, load_config.download_dir):
  112. hf_folder = snapshot_download(
  113. model_name_or_path,
  114. revision=model_config.revision,
  115. allow_patterns="*.json",
  116. cache_dir=load_config.download_dir,
  117. local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
  118. tqdm_class=DisabledTqdm,
  119. )
  120. else:
  121. hf_folder = model_name_or_path
  122. possible_config_filenames = quant_cls.get_config_filenames()
  123. # If the quantization config is not found, use the default config.
  124. if not possible_config_filenames:
  125. return quant_cls()
  126. config_files = glob.glob(os.path.join(hf_folder, "*.json"))
  127. quant_config_files = [
  128. f for f in config_files if any(
  129. f.endswith(x) for x in possible_config_filenames)
  130. ]
  131. if len(quant_config_files) == 0:
  132. raise ValueError(
  133. f"Cannot find the config file for {model_config.quantization}")
  134. if len(quant_config_files) > 1:
  135. raise ValueError(
  136. f"Found multiple config files for {model_config.quantization}: "
  137. f"{quant_config_files}")
  138. quant_config_file = quant_config_files[0]
  139. with open(quant_config_file, "r") as f:
  140. config = json.load(f)
  141. return quant_cls.from_config(config)
  142. def download_weights_from_hf(
  143. model_name_or_path: str,
  144. cache_dir: Optional[str],
  145. allow_patterns: List[str],
  146. revision: Optional[str] = None,
  147. ) -> str:
  148. """Download model weights from Hugging Face Hub.
  149. Args:
  150. model_name_or_path (str): The model name or path.
  151. cache_dir (Optional[str]): The cache directory to store the model
  152. weights. If None, will use HF defaults.
  153. allow_patterns (List[str]): The allowed patterns for the
  154. weight files. Files matched by any of the patterns will be
  155. downloaded.
  156. revision (Optional[str]): The revision of the model.
  157. Returns:
  158. str: The path to the downloaded model weights.
  159. """
  160. if not huggingface_hub.constants.HF_HUB_OFFLINE:
  161. # Before we download we look at that is available:
  162. fs = HfFileSystem()
  163. file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
  164. # depending on what is available we download different things
  165. for pattern in allow_patterns:
  166. matching = fnmatch.filter(file_list, pattern)
  167. if len(matching) > 0:
  168. allow_patterns = [pattern]
  169. break
  170. logger.info(f"Using model weights format {allow_patterns}")
  171. # Use file lock to prevent multiple processes from
  172. # downloading the same model weights at the same time.
  173. with get_lock(model_name_or_path, cache_dir):
  174. hf_folder = snapshot_download(
  175. model_name_or_path,
  176. allow_patterns=allow_patterns,
  177. cache_dir=cache_dir,
  178. tqdm_class=DisabledTqdm,
  179. revision=revision,
  180. local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
  181. )
  182. return hf_folder
  183. def download_safetensors_index_file_from_hf(
  184. model_name_or_path: str,
  185. cache_dir: Optional[str],
  186. revision: Optional[str] = None,
  187. ) -> None:
  188. """Download hf safetensors index file from Hugging Face Hub.
  189. Args:
  190. model_name_or_path (str): The model name or path.
  191. cache_dir (Optional[str]): The cache directory to store the model
  192. weights. If None, will use HF defaults.
  193. revision (Optional[str]): The revision of the model.
  194. """
  195. # Use file lock to prevent multiple processes from
  196. # downloading the same model weights at the same time.
  197. with get_lock(model_name_or_path, cache_dir):
  198. try:
  199. # Download the safetensors index file.
  200. hf_hub_download(
  201. repo_id=model_name_or_path,
  202. filename=SAFE_WEIGHTS_INDEX_NAME,
  203. cache_dir=cache_dir,
  204. revision=revision,
  205. local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
  206. )
  207. # If file not found on remote or locally, we should not fail since
  208. # only some models will have SAFE_WEIGHTS_INDEX_NAME.
  209. except huggingface_hub.utils.EntryNotFoundError:
  210. logger.info(f"No {SAFE_WEIGHTS_INDEX_NAME} found in remote.")
  211. except huggingface_hub.utils.LocalEntryNotFoundError:
  212. logger.info(f"No {SAFE_WEIGHTS_INDEX_NAME} found in local cache.")
  213. # For models like Mistral-7B-v0.3, there are both sharded
  214. # safetensors files and a consolidated safetensors file.
  215. # Passing both of these to the weight loader functionality breaks.
  216. # So, we use the SAFE_WEIGHTS_INDEX_NAME to
  217. # look up which safetensors files should be used.
  218. def filter_duplicate_safetensors_files(hf_weights_files: List[str],
  219. hf_folder: str) -> List[str]:
  220. # model.safetensors.index.json is a mapping from keys in the
  221. # torch state_dict to safetensors file holding that weight.
  222. index_file_name = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME)
  223. if not os.path.isfile(index_file_name):
  224. return hf_weights_files
  225. # Iterate through the weight_map (weight_name: safetensors files)
  226. # to identify weights that we should use.
  227. with open(index_file_name) as index_file:
  228. weight_map = json.load(index_file)["weight_map"]
  229. weight_files_in_index = set()
  230. for weight_name in weight_map:
  231. weight_files_in_index.add(
  232. os.path.join(hf_folder, weight_map[weight_name]))
  233. # Filter out any fields that are not found in the index file.
  234. hf_weights_files = [
  235. f for f in hf_weights_files if f in weight_files_in_index
  236. ]
  237. return hf_weights_files
  238. def filter_files_not_needed_for_inference(
  239. hf_weights_files: List[str]) -> List[str]:
  240. """
  241. Exclude files that are not needed for inference.
  242. See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
  243. """
  244. blacklist = [
  245. "training_args.bin",
  246. "optimizer.bin",
  247. "optimizer.pt",
  248. "scheduler.pt",
  249. "scaler.pt",
  250. ]
  251. hf_weights_files = [
  252. f for f in hf_weights_files
  253. if not any(f.endswith(x) for x in blacklist)
  254. ]
  255. return hf_weights_files
  256. def np_cache_weights_iterator(
  257. model_name_or_path: str, cache_dir: Optional[str], hf_folder: str,
  258. hf_weights_files: List[str]
  259. ) -> Generator[Tuple[str, torch.Tensor], None, None]:
  260. """Iterate over the weights in the model np files.
  261. Will dump the model weights to numpy files if they are not already dumped.
  262. """
  263. # Convert the model weights from torch tensors to numpy arrays for
  264. # faster loading.
  265. np_folder = os.path.join(hf_folder, "np")
  266. os.makedirs(np_folder, exist_ok=True)
  267. weight_names_file = os.path.join(np_folder, "weight_names.json")
  268. # Use file lock to prevent multiple processes from
  269. # dumping the same model weights to numpy at the same time.
  270. with get_lock(model_name_or_path, cache_dir):
  271. if not os.path.exists(weight_names_file):
  272. weight_names = []
  273. for bin_file in hf_weights_files:
  274. state = torch.load(bin_file, map_location="cpu")
  275. for name, param in state.items():
  276. param_path = os.path.join(np_folder, name)
  277. with open(param_path, "wb") as f:
  278. np.save(f, param.cpu().detach().numpy())
  279. weight_names.append(name)
  280. with open(weight_names_file, "w") as f:
  281. json.dump(weight_names, f)
  282. with open(weight_names_file, "r") as f:
  283. weight_names = json.load(f)
  284. for name in weight_names:
  285. param_path = os.path.join(np_folder, name)
  286. with open(param_path, "rb") as f:
  287. param = np.load(f)
  288. yield name, torch.from_numpy(param)
  289. def safetensors_weights_iterator(
  290. hf_weights_files: List[str]
  291. ) -> Generator[Tuple[str, torch.Tensor], None, None]:
  292. """Iterate over the weights in the model safetensor files."""
  293. for st_file in hf_weights_files:
  294. with safe_open(st_file, framework="pt") as f:
  295. for name in f.keys(): # noqa: SIM118
  296. param = f.get_tensor(name)
  297. yield name, param
  298. def pt_weights_iterator(
  299. hf_weights_files: List[str]
  300. ) -> Generator[Tuple[str, torch.Tensor], None, None]:
  301. """Iterate over the weights in the model bin/pt files."""
  302. for bin_file in hf_weights_files:
  303. state = torch.load(bin_file, map_location="cpu")
  304. for name, param in state.items():
  305. yield name, param
  306. del state
  307. torch.cuda.empty_cache()
  308. def kv_cache_scales_loader(
  309. filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int,
  310. model_type: Optional[str]) -> Iterable[Tuple[int, float]]:
  311. """
  312. A simple utility to read in KV cache scaling factors that have been
  313. previously serialized to disk. Used by the model to populate the appropriate
  314. KV cache scaling factors. The serialization should represent a dictionary
  315. whose keys are the TP ranks and values are another dictionary mapping layers
  316. to their KV cache scaling factors.
  317. Keep this function in sync with the output of examples/fp8/extract_scales.py
  318. """
  319. try:
  320. with open(filename) as f:
  321. context = {
  322. "model_type": model_type,
  323. "num_hidden_layers": num_hidden_layers,
  324. "tp_rank": tp_rank,
  325. "tp_size": tp_size,
  326. }
  327. schema_dct = json.load(f)
  328. schema = QuantParamSchema.model_validate(schema_dct,
  329. context=context)
  330. layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
  331. return layer_scales_map.items()
  332. except FileNotFoundError:
  333. logger.error(f"File or directory '{filename}' not found.")
  334. except json.JSONDecodeError:
  335. logger.error(f"Error decoding JSON in file '{filename}'.")
  336. except Exception as e:
  337. logger.error(f"An error occurred while reading '{filename}': {e}")
  338. # This section is reached if and only if any of the excepts are hit
  339. # Return an empty iterable (list) => no KV cache scales are loaded
  340. # which ultimately defaults to 1.0 scales
  341. logger.warning("Defaulting to KV cache scaling factors = 1.0 "
  342. f"for all layers in TP rank {tp_rank} "
  343. "as an error occurred during loading.")
  344. return []
  345. def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
  346. """convert PySafeSlice object from safetensors to torch.Tensor
  347. PySafeSlice object supports indexing, which is done before loading the
  348. actual tensor and can reduce the amount of memory being read into the
  349. memory. However, it does not support more advanced functionalities
  350. like `.view()` or `.t()`. Therefore, if we need to modify the loaded
  351. tensor with these more complicated operators, we need to convert to
  352. tensor first.
  353. """
  354. if not isinstance(x, torch.Tensor):
  355. x = x[:]
  356. return x
  357. def default_weight_loader(param: torch.Tensor,
  358. loaded_weight: torch.Tensor) -> None:
  359. """Default weight loader."""
  360. assert param.size() == loaded_weight.size()
  361. param.data.copy_(loaded_weight)
  362. def initialize_dummy_weights(
  363. model: torch.nn.Module,
  364. low: float = -1e-3,
  365. high: float = 1e-3,
  366. ) -> None:
  367. """Initialize model weights with random values.
  368. The model weights must be randomly initialized for accurate performance
  369. measurements. Additionally, the model weights should not cause NaNs in the
  370. forward pass. We empirically found that initializing the weights with
  371. values between -1e-3 and 1e-3 works well for most models.
  372. """
  373. for param in model.state_dict().values():
  374. if torch.is_floating_point(param):
  375. if torch.finfo(param.data.dtype).bits < 16:
  376. # uniform_ doesn't support < 16-bit datatypes (FP8)
  377. dtype = param.data.dtype
  378. tmp_param = param.data.to(torch.float16)
  379. tmp_param = tmp_param.uniform_(low, high).to(dtype)
  380. param.data.copy_(tmp_param)
  381. else:
  382. param.uniform_(low, high)