hf_downloader.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. """Utilities for downloading and initializing model weights."""
  2. import filelock
  3. import glob
  4. import fnmatch
  5. import json
  6. import os
  7. from collections import defaultdict
  8. from typing import Any, Iterator, List, Optional, Tuple
  9. from huggingface_hub import snapshot_download, HfFileSystem
  10. import numpy as np
  11. from safetensors.torch import load_file, save_file, safe_open
  12. import torch
  13. from transformers import PretrainedConfig
  14. from tqdm.auto import tqdm
  15. from aphrodite.common.logger import init_logger
  16. from aphrodite.modeling.layers.quantization import (get_quantization_config,
  17. QuantizationConfig)
# Module-level logger obtained from the project's ``init_logger`` helper.
logger = init_logger(__name__)
  19. class Disabledtqdm(tqdm): # pylint: disable=inconsistent-mro
  20. def __init__(self, *args, **kwargs):
  21. super().__init__(*args, **kwargs, disable=True)
  22. def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
  23. lock_dir = cache_dir if cache_dir is not None else "/tmp"
  24. lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
  25. lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
  26. return lock
  27. def _shared_pointers(tensors):
  28. ptrs = defaultdict(list)
  29. for k, v in tensors.items():
  30. ptrs[v.data_ptr()].append(k)
  31. failing = []
  32. for _, names in ptrs.items():
  33. if len(names) > 1:
  34. failing.append(names)
  35. return failing
  36. def convert_bin_to_safetensor_file(
  37. pt_filename: str,
  38. sf_filename: str,
  39. ) -> None:
  40. loaded = torch.load(pt_filename, map_location="cpu")
  41. if "state_dict" in loaded:
  42. loaded = loaded["state_dict"]
  43. shared = _shared_pointers(loaded)
  44. for shared_weights in shared:
  45. for name in shared_weights[1:]:
  46. loaded.pop(name)
  47. # For tensors to be contiguous
  48. loaded = {k: v.contiguous() for k, v in loaded.items()}
  49. dirname = os.path.dirname(sf_filename)
  50. os.makedirs(dirname, exist_ok=True)
  51. save_file(loaded, sf_filename, metadata={"format": "pt"})
  52. # check file size
  53. sf_size = os.stat(sf_filename).st_size
  54. pt_size = os.stat(pt_filename).st_size
  55. if (sf_size - pt_size) / pt_size > 0.01:
  56. raise RuntimeError(f"""The file size different is more than 1%:
  57. - {sf_filename}: {sf_size}
  58. - {pt_filename}: {pt_size}
  59. """)
  60. # check if the tensors are the same
  61. reloaded = load_file(sf_filename)
  62. for k in loaded:
  63. pt_tensor = loaded[k]
  64. sf_tensor = reloaded[k]
  65. if not torch.equal(pt_tensor, sf_tensor):
  66. raise RuntimeError(f"The output tensors do not match for key {k}")
  67. # TODO: Move this to another place.
  68. def get_quant_config(
  69. quantization: str,
  70. model_name_or_path: str,
  71. hf_config: PretrainedConfig,
  72. cache_dir: Optional[str] = None,
  73. ) -> QuantizationConfig:
  74. quant_cls = get_quantization_config(quantization)
  75. # No need for extra config
  76. if quantization == "gguf":
  77. return quant_cls()
  78. # Read the quantization config from the HF model config, if available.
  79. hf_quant_config = getattr(hf_config, "quantization_config", None)
  80. if hf_quant_config is not None:
  81. return quant_cls.from_config(hf_quant_config)
  82. is_local = os.path.isdir(model_name_or_path)
  83. if not is_local:
  84. # Download the config files.
  85. with get_lock(model_name_or_path, cache_dir):
  86. hf_folder = snapshot_download(model_name_or_path,
  87. allow_patterns="*.json",
  88. cache_dir=cache_dir,
  89. tqdm_class=Disabledtqdm)
  90. else:
  91. hf_folder = model_name_or_path
  92. config_files = glob.glob(os.path.join(hf_folder, "*.json"))
  93. quant_config_files = [
  94. f for f in config_files if any(
  95. f.endswith(x) for x in quant_cls.get_config_filenames())
  96. ]
  97. if len(quant_config_files) == 0:
  98. raise ValueError(f"Cannot find the config file for {quantization}")
  99. if len(quant_config_files) > 1:
  100. raise ValueError(f"Found multiple config files for {quantization}: "
  101. f"{quant_config_files}")
  102. quant_config_file = quant_config_files[0]
  103. with open(quant_config_file, "r") as f:
  104. config = json.load(f)
  105. return quant_cls.from_config(config)
  106. def prepare_hf_model_weights(
  107. model_name_or_path: str,
  108. cache_dir: Optional[str] = None,
  109. load_format: str = "auto",
  110. fall_back_to_pt: bool = True,
  111. revision: Optional[str] = None,
  112. ) -> Tuple[str, List[str], bool]:
  113. # Download model weights from huggingface.
  114. is_local = os.path.isdir(model_name_or_path)
  115. use_safetensors = False
  116. # Some quantized models use .pt files for storing the weights.
  117. if load_format == "auto":
  118. allow_patterns = ["*.safetensors", "*.bin"]
  119. elif load_format == "safetensors":
  120. use_safetensors = True
  121. allow_patterns = ["*.safetensors"]
  122. elif load_format == "pt":
  123. allow_patterns = ["*.pt"]
  124. elif load_format == "npcache":
  125. allow_patterns = ["*.bin"]
  126. else:
  127. raise ValueError(f"Unknown load_format: {load_format}")
  128. if fall_back_to_pt:
  129. allow_patterns += ["*.pt"]
  130. if not is_local:
  131. # Before we download we look at that is available:
  132. fs = HfFileSystem()
  133. file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
  134. # depending on what is available we download different things
  135. for pattern in allow_patterns:
  136. matching = fnmatch.filter(file_list, pattern)
  137. if len(matching) > 0:
  138. allow_patterns = [pattern]
  139. break
  140. logger.info(f"Downloading model weights {allow_patterns}")
  141. # Use file lock to prevent multiple processes from
  142. # downloading the same model weights at the same time.
  143. with get_lock(model_name_or_path, cache_dir):
  144. hf_folder = snapshot_download(model_name_or_path,
  145. allow_patterns=allow_patterns,
  146. cache_dir=cache_dir,
  147. tqdm_class=Disabledtqdm,
  148. revision=revision)
  149. else:
  150. hf_folder = model_name_or_path
  151. hf_weights_files: List[str] = []
  152. for pattern in allow_patterns:
  153. hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
  154. if len(hf_weights_files) > 0:
  155. if pattern == "*.safetensors":
  156. use_safetensors = True
  157. break
  158. if not use_safetensors:
  159. # Exclude files that are not needed for inference.
  160. # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
  161. blacklist = [
  162. "training_args.bin",
  163. "optimizer.bin",
  164. "optimizer.pt",
  165. "scheduler.pt",
  166. "scaler.pt",
  167. "trainer_state.json",
  168. ]
  169. hf_weights_files = [
  170. f for f in hf_weights_files
  171. if not any(f.endswith(x) for x in blacklist)
  172. ]
  173. if len(hf_weights_files) == 0:
  174. raise RuntimeError(
  175. f"Cannot find any model weights with `{model_name_or_path}`")
  176. return hf_folder, hf_weights_files, use_safetensors
  177. def hf_model_weights_iterator(
  178. model_name_or_path: str,
  179. cache_dir: Optional[str] = None,
  180. load_format: str = "auto",
  181. revision: Optional[str] = None,
  182. fall_back_to_pt: Optional[bool] = True,
  183. ) -> Iterator[Tuple[str, torch.Tensor]]:
  184. hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
  185. model_name_or_path,
  186. cache_dir=cache_dir,
  187. load_format=load_format,
  188. fall_back_to_pt=fall_back_to_pt,
  189. revision=revision)
  190. if load_format == "npcache":
  191. # Currently np_cache only support *.bin checkpoints
  192. assert use_safetensors is False
  193. # Convert the model weights from torch tensors to numpy arrays for
  194. # faster loading.
  195. np_folder = os.path.join(hf_folder, "np")
  196. os.makedirs(np_folder, exist_ok=True)
  197. weight_names_file = os.path.join(np_folder, "weight_names.json")
  198. # Use file lock to prevent multiple processes from
  199. # dumping the same model weights to numpy at the same time.
  200. with get_lock(model_name_or_path, cache_dir):
  201. if not os.path.exists(weight_names_file):
  202. weight_names = []
  203. for bin_file in hf_weights_files:
  204. state = torch.load(bin_file, map_location="cpu")
  205. for name, param in state.items():
  206. param_path = os.path.join(np_folder, name)
  207. with open(param_path, "wb") as f:
  208. np.save(f, param.cpu().detach().numpy())
  209. weight_names.append(name)
  210. with open(weight_names_file, "w") as f:
  211. json.dump(weight_names, f)
  212. with open(weight_names_file, "r") as f:
  213. weight_names = json.load(f)
  214. for name in weight_names:
  215. param_path = os.path.join(np_folder, name)
  216. with open(param_path, "rb") as f:
  217. param = np.load(f)
  218. yield name, torch.from_numpy(param)
  219. elif use_safetensors:
  220. for st_file in hf_weights_files:
  221. with safe_open(st_file, framework="pt") as f:
  222. for name in f.keys(): # noqa: SIM118
  223. param = f.get_tensor(name)
  224. yield name, param
  225. else:
  226. for bin_file in hf_weights_files:
  227. state = torch.load(bin_file, map_location="cpu")
  228. for name, param in state.items():
  229. yield name, param
  230. del state
  231. torch.cuda.empty_cache()
  232. def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
  233. """convert PySafeSlice object from safetensors to torch.Tensor
  234. PySafeSlice object supports indexing, which is done before loading the
  235. actual tensor and can reduce the amount of memory being read into the
  236. memory. However, it does not support more advanced functionalities
  237. like `.view()` or `.t()`. Therefore, if we need to modify the loaded
  238. tensor with these more complicated operators, we need to convert to
  239. tensor first.
  240. """
  241. if not isinstance(x, torch.Tensor):
  242. x = x[:]
  243. return x
  244. def default_weight_loader(param: torch.Tensor,
  245. loaded_weight: torch.Tensor) -> None:
  246. """Default weight loader."""
  247. if isinstance(param, torch.nn.parameter.UninitializedParameter):
  248. param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
  249. assert param.size() == loaded_weight.size()
  250. param.data.copy_(loaded_weight)
  251. def initialize_dummy_weights(
  252. model: torch.nn.Module,
  253. low: float = -1e-3,
  254. high: float = 1e-3,
  255. ) -> None:
  256. """Initialize model weights with random values.
  257. The model weights must be randomly initialized for accurate performance
  258. measurements. Additionally, the model weights should not cause NaNs in the
  259. forward pass. We empirically found that initializing the weights with
  260. values between -1e-3 and 1e-3 works well for most models.
  261. """
  262. for param in model.state_dict().values():
  263. if torch.is_floating_point(param):
  264. param.data.uniform_(low, high)