  1. """Utilities for downloading and initializing model weights."""
  2. import filelock
  3. import glob
  4. import json
  5. import os
  6. from collections import defaultdict
  7. from typing import Iterator, List, Optional, Tuple, Any
  8. from huggingface_hub import snapshot_download
  9. from safetensors.torch import load_file, save_file, safe_open
  10. import numpy as np
  11. import torch
  12. from tqdm.auto import tqdm
  13. from transformers import PretrainedConfig
  14. from aphrodite.common.logger import init_logger
  15. from aphrodite.modeling.quantization_utils import get_quant_class
  16. from aphrodite.modeling.quantization_utils.base import QuantizationConfig
  17. logger = init_logger(__name__)


class Disabledtqdm(tqdm):  # pylint: disable=inconsistent-mro

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, disable=True)


def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
    lock_dir = cache_dir if cache_dir is not None else "/tmp"
    lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
    return lock
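
# Example (sketch; the repo id is illustrative): serialize downloads of the
# same repo across processes.
#
#   with get_lock("facebook/opt-125m"):
#       snapshot_download("facebook/opt-125m", allow_patterns=["*.json"])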


def _shared_pointers(tensors):
    ptrs = defaultdict(list)
    for k, v in tensors.items():
        ptrs[v.data_ptr()].append(k)
    failing = []
    for _, names in ptrs.items():
        if len(names) > 1:
            failing.append(names)
    return failing
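
# Example (sketch): names that alias the same storage are grouped, so all but
# the first of each group can be dropped before serialization.
#
#   w = torch.zeros(4)
#   _shared_pointers({"a": w, "b": w, "c": torch.zeros(4)})  # [["a", "b"]]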


def convert_bin_to_safetensor_file(
    pt_filename: str,
    sf_filename: str,
) -> None:
    loaded = torch.load(pt_filename, map_location="cpu")
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]
    shared = _shared_pointers(loaded)
    for shared_weights in shared:
        for name in shared_weights[1:]:
            loaded.pop(name)

    # Force tensors to be contiguous.
    loaded = {k: v.contiguous() for k, v in loaded.items()}

    dirname = os.path.dirname(sf_filename)
    os.makedirs(dirname, exist_ok=True)
    save_file(loaded, sf_filename, metadata={"format": "pt"})

    # Check the file size.
    sf_size = os.stat(sf_filename).st_size
    pt_size = os.stat(pt_filename).st_size
    if (sf_size - pt_size) / pt_size > 0.01:
        raise RuntimeError(f"""The file size difference is more than 1%:
         - {sf_filename}: {sf_size}
         - {pt_filename}: {pt_size}
         """)

    # Check that the tensors are the same.
    reloaded = load_file(sf_filename)
    for k in loaded:
        pt_tensor = loaded[k]
        sf_tensor = reloaded[k]
        if not torch.equal(pt_tensor, sf_tensor):
            raise RuntimeError(f"The output tensors do not match for key {k}")
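
# Example (sketch; the paths are illustrative):
#
#   convert_bin_to_safetensor_file("ckpt/pytorch_model.bin",
#                                  "ckpt/model.safetensors")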


# TODO: Move this to another place.
def get_quant_config(
    quantization: str,
    model_name_or_path: str,
    hf_config: PretrainedConfig,
    cache_dir: Optional[str] = None,
) -> QuantizationConfig:
    if quantization == "gptq" and hasattr(hf_config, "quantization_config"):
        config = hf_config.quantization_config
        return get_quant_class(quantization).from_config(config)
    is_local = os.path.isdir(model_name_or_path)
    if not is_local:
        # Download the config files.
        with get_lock(model_name_or_path, cache_dir):
            hf_folder = snapshot_download(model_name_or_path,
                                          allow_patterns="*.json",
                                          cache_dir=cache_dir,
                                          tqdm_class=Disabledtqdm)
    else:
        hf_folder = model_name_or_path
    config_files = glob.glob(os.path.join(hf_folder, "*.json"))

    quant_cls = get_quant_class(quantization)
    quant_config_files = [
        f for f in config_files if any(
            f.endswith(x) for x in quant_cls.get_config_filenames())
    ]
    if len(quant_config_files) == 0:
        raise ValueError(f"Cannot find the config file for {quantization}")
    if len(quant_config_files) > 1:
        raise ValueError(f"Found multiple config files for {quantization}: "
                         f"{quant_config_files}")

    quant_config_file = quant_config_files[0]
    with open(quant_config_file, "r") as f:
        config = json.load(f)
    return quant_cls.from_config(config)
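
# Example (sketch; the repo id and quantization scheme are illustrative):
#
#   quant_config = get_quant_config("awq", "some-org/some-awq-model",
#                                   hf_config, cache_dir=None)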


def prepare_hf_model_weights(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    use_safetensors: bool = False,
    fall_back_to_pt: bool = True,
    revision: Optional[str] = None,
) -> Tuple[str, List[str], bool]:
    # Download model weights from huggingface.
    is_local = os.path.isdir(model_name_or_path)
    if use_safetensors:
        allow_patterns = ["*.safetensors"]
    else:
        # Some quantized models use .pt files for storing the weights.
        allow_patterns = ["*.bin", "*.pt"]
    if not is_local:
        # Use file lock to prevent multiple processes from
        # downloading the same model weights at the same time.
        with get_lock(model_name_or_path, cache_dir):
            hf_folder = snapshot_download(model_name_or_path,
                                          allow_patterns=allow_patterns,
                                          cache_dir=cache_dir,
                                          tqdm_class=Disabledtqdm,
                                          revision=revision)
    else:
        hf_folder = model_name_or_path
    hf_weights_files: List[str] = []
    for pattern in allow_patterns:
        hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
    if not use_safetensors:
        # Exclude files that are not needed for inference.
        blacklist = [
            "training_args.bin",
            "optimizer.bin",
            "optimizer.pt",
            "scheduler.pt",
            "scaler.pt",
            "trainer_state.json",
        ]
        hf_weights_files = [
            f for f in hf_weights_files
            if not any(f.endswith(x) for x in blacklist)
        ]

    if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt:
        return prepare_hf_model_weights(model_name_or_path,
                                        cache_dir=cache_dir,
                                        use_safetensors=False,
                                        fall_back_to_pt=False,
                                        revision=revision)

    if len(hf_weights_files) == 0:
        raise RuntimeError(
            f"Cannot find any model weights with `{model_name_or_path}`")

    return hf_folder, hf_weights_files, use_safetensors
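
# Example (sketch; the repo id is illustrative). With the defaults, the
# helper downloads the *.bin/*.pt shards and returns use_safetensors=False:
#
#   folder, files, use_st = prepare_hf_model_weights("facebook/opt-125m")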


def hf_model_weights_iterator(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    load_format: str = "auto",
    revision: Optional[str] = None,
) -> Iterator[Tuple[str, torch.Tensor]]:
    use_safetensors = False
    use_np_cache = False
    fall_back_to_pt = False
    if load_format == "auto":
        use_safetensors = True
        fall_back_to_pt = True
    elif load_format == "safetensors":
        use_safetensors = True
    elif load_format == "pt":
        pass
    elif load_format == "npcache":
        use_np_cache = True
    else:
        raise ValueError(f"Unknown load_format: {load_format}")

    hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
        model_name_or_path,
        cache_dir=cache_dir,
        use_safetensors=use_safetensors,
        fall_back_to_pt=fall_back_to_pt,
        revision=revision)

    if use_np_cache:
        # Currently np_cache only supports *.bin checkpoints.
        assert use_safetensors is False

        # Convert the model weights from torch tensors to numpy arrays for
        # faster loading.
        np_folder = os.path.join(hf_folder, "np")
        os.makedirs(np_folder, exist_ok=True)
        weight_names_file = os.path.join(np_folder, "weight_names.json")
        # Use file lock to prevent multiple processes from
        # dumping the same model weights to numpy at the same time.
        with get_lock(model_name_or_path, cache_dir):
            if not os.path.exists(weight_names_file):
                weight_names = []
                for bin_file in hf_weights_files:
                    state = torch.load(bin_file, map_location="cpu")
                    for name, param in state.items():
                        param_path = os.path.join(np_folder, name)
                        with open(param_path, "wb") as f:
                            np.save(f, param.cpu().detach().numpy())
                        weight_names.append(name)
                with open(weight_names_file, "w") as f:
                    json.dump(weight_names, f)

        with open(weight_names_file, "r") as f:
            weight_names = json.load(f)

        for name in weight_names:
            param_path = os.path.join(np_folder, name)
            with open(param_path, "rb") as f:
                param = np.load(f)
            yield name, torch.from_numpy(param)
    elif use_safetensors:
        for st_file in hf_weights_files:
            with safe_open(st_file, framework="pt") as f:
                for name in f.keys():
                    param = f.get_slice(name)
                    yield name, param
    else:
        for bin_file in hf_weights_files:
            state = torch.load(bin_file, map_location="cpu")
            for name, param in state.items():
                yield name, param
            del state
            torch.cuda.empty_cache()
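
# Example (sketch; the repo id is illustrative): stream weights one tensor at
# a time instead of materializing whole shards. With safetensors, each item
# is a lazy PySafeSlice; see convert_pyslice_to_tensor below.
#
#   for name, weight in hf_model_weights_iterator("facebook/opt-125m"):
#       tensor = convert_pyslice_to_tensor(weight)
#       print(name, tuple(tensor.shape))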


def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
    """Convert a PySafeSlice object from safetensors to a torch.Tensor.

    A PySafeSlice supports indexing, which is performed before the actual
    tensor is loaded and thus reduces the amount of data read into memory.
    However, it does not support more advanced functionality such as
    `.view()` or `.t()`, so the slice must be converted to a tensor before
    applying such operators.
    """
    if not isinstance(x, torch.Tensor):
        x = x[:]
    return x
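
# Example (sketch; the file and tensor names are illustrative):
#
#   with safe_open("model.safetensors", framework="pt") as f:
#       piece = f.get_slice("lm_head.weight")      # no data read yet
#       tensor = convert_pyslice_to_tensor(piece)  # materializes the tensor
#       transposed = tensor.t()                    # now regular ops work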


def load_padded_tensor_parallel_vocab(
    param: torch.Tensor,
    loaded_weight: Any,  # `torch.Tensor` or `PySafeSlice`
    tensor_model_parallel_rank: int,
) -> None:
    shard_size = param.shape[0]
    start_idx = tensor_model_parallel_rank * shard_size
    end_idx = (tensor_model_parallel_rank + 1) * shard_size
    loaded_weight = loaded_weight[start_idx:end_idx]
    loaded_weight = convert_pyslice_to_tensor(loaded_weight)
    param.data[:loaded_weight.shape[0]].copy_(loaded_weight)


def load_tensor_parallel_weights(
    param: torch.Tensor,
    loaded_weight: Any,  # `torch.Tensor` or `PySafeSlice`
    param_name: str,
    column_parallel_weight_names: List[str],
    row_parallel_weight_names: List[str],
    tensor_model_parallel_rank: int,
) -> None:
    for p in column_parallel_weight_names:
        if p in param_name:
            shard_size = param.shape[0]
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            loaded_weight = loaded_weight[start_idx:end_idx]
            break
    for p in row_parallel_weight_names:
        if p in param_name:
            shard_size = param.shape[-1]
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            if isinstance(loaded_weight, torch.Tensor):
                loaded_weight = loaded_weight[..., start_idx:end_idx]
            else:
                # PySafeSlice indexing does not take `...`, so build an
                # explicit per-dimension index that slices only the last dim.
                index = [slice(None)] * (len(loaded_weight.get_shape()) -
                                         1) + [slice(start_idx, end_idx)]
                loaded_weight = loaded_weight[index]
            break

    loaded_weight = convert_pyslice_to_tensor(loaded_weight)
    assert param.shape == loaded_weight.shape, (
        f"{param_name} shape mismatch between model and checkpoint: "
        f"{param.shape} != {loaded_weight.shape}")
    param.data.copy_(loaded_weight)
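
# Example (sketch): for a column-parallel weight, rank r receives rows
# [r * shard_size, (r + 1) * shard_size) of the checkpoint tensor.
#
#   full = torch.arange(8.0).reshape(4, 2)  # full checkpoint tensor
#   shard = torch.empty(2, 2)               # per-rank parameter (rank 1)
#   load_tensor_parallel_weights(shard, full, "layer.weight",
#                                ["layer.weight"], [],
#                                tensor_model_parallel_rank=1)
#   # `shard` now holds rows 2..3 of `full`.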


def initialize_dummy_weights(
    model: torch.nn.Module,
    low: float = -1e-3,
    high: float = 1e-3,
) -> None:
    """Initialize model weights with random values.

    The model weights must be randomly initialized for accurate performance
    measurements. Additionally, the model weights should not cause NaNs in
    the forward pass. We empirically found that initializing the weights
    with values between -1e-3 and 1e-3 works well for most models.
    """
    for param in model.state_dict().values():
        if torch.is_floating_point(param):
            param.data.uniform_(low, high)
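
# Example (sketch; the model construction is illustrative): fill a freshly
# built model with small random values before a profiling run.
#
#   model = SomeModelClass(config)  # hypothetical model
#   initialize_dummy_weights(model)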


def get_parallel_weight(model: torch.nn.Module):
    if model.quant_config is None:
        column_weight_suffixes = ["weight", "bias"]
        row_weight_suffixes = ["weight"]
        ignore_weight_suffixes = []
    else:
        column_weight_suffixes = (
            model.quant_config.get_column_tp_tensor_names())
        row_weight_suffixes = model.quant_config.get_row_tp_tensor_names()
        ignore_weight_suffixes = model.quant_config.get_ignore_tensor_names()

    column_parallel_weights: List[str] = []
    for layer in model.column_parallel_layers:
        for suffix in column_weight_suffixes:
            column_parallel_weights.append(f"{layer}.{suffix}")
    row_parallel_weights: List[str] = []
    for layer in model.row_parallel_layers:
        for suffix in row_weight_suffixes:
            row_parallel_weights.append(f"{layer}.{suffix}")

    if hasattr(model, "parallel_vocab_layers"):
        for layer in model.parallel_vocab_layers:
            for suffix in ["weight", "bias"]:
                column_parallel_weights.append(f"{layer}.{suffix}")
    return (column_parallel_weights, row_parallel_weights,
            ignore_weight_suffixes)