# hf_downloader.py

  1. """Utilities for downloading and initializing model weights."""
  2. import filelock
  3. import glob
  4. import json
  5. import os
  6. from collections import defaultdict
  7. from typing import Any, Iterator, List, Optional, Tuple
  8. from huggingface_hub import snapshot_download
  9. from safetensors.torch import load_file, save_file, safe_open
  10. import numpy as np
  11. import torch
  12. from tqdm.auto import tqdm
  13. from transformers import PretrainedConfig
  14. from aphrodite.common.logger import init_logger
  15. from aphrodite.modeling.quantization_utils import get_quant_class
  16. from aphrodite.modeling.quantization_utils.base import QuantizationConfig
  17. logger = init_logger(__name__)


class Disabledtqdm(tqdm):  # pylint: disable=inconsistent-mro

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, disable=True)


def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
    lock_dir = cache_dir if cache_dir is not None else "/tmp"
    lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
    return lock
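

# Example (a sketch; the model id is only illustrative): get_lock maps
# "facebook/opt-125m" to the lock file /tmp/facebook-opt-125m.lock, so
# concurrent processes serialize on the same path:
#
#     with get_lock("facebook/opt-125m"):
#         ...  # safe to download without racing other processes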


def _shared_pointers(tensors):
    """Find groups of tensor names that share the same underlying storage."""
    ptrs = defaultdict(list)
    for k, v in tensors.items():
        ptrs[v.data_ptr()].append(k)
    failing = []
    for _, names in ptrs.items():
        if len(names) > 1:
            failing.append(names)
    return failing
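

# A minimal sketch of what _shared_pointers reports for tied weights (the
# tensor names are hypothetical):
#
#     w = torch.zeros(4)
#     _shared_pointers({"a": w, "b": w, "c": torch.zeros(4)})
#     # -> [["a", "b"]]: "a" and "b" alias one buffer, "c" stands alone.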


def convert_bin_to_safetensor_file(
    pt_filename: str,
    sf_filename: str,
) -> None:
    loaded = torch.load(pt_filename, map_location="cpu")
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]
    shared = _shared_pointers(loaded)
    for shared_weights in shared:
        for name in shared_weights[1:]:
            loaded.pop(name)

    # Force tensors to be contiguous.
    loaded = {k: v.contiguous() for k, v in loaded.items()}

    dirname = os.path.dirname(sf_filename)
    os.makedirs(dirname, exist_ok=True)
    save_file(loaded, sf_filename, metadata={"format": "pt"})

    # Check the file size.
    sf_size = os.stat(sf_filename).st_size
    pt_size = os.stat(pt_filename).st_size
    if (sf_size - pt_size) / pt_size > 0.01:
        raise RuntimeError(f"""The file size difference is more than 1%:
         - {sf_filename}: {sf_size}
         - {pt_filename}: {pt_size}
         """)

    # Check whether the tensors are the same.
    reloaded = load_file(sf_filename)
    for k in loaded:
        pt_tensor = loaded[k]
        sf_tensor = reloaded[k]
        if not torch.equal(pt_tensor, sf_tensor):
            raise RuntimeError(f"The output tensors do not match for key {k}")
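

# Example usage (a sketch; both paths are hypothetical):
#
#     convert_bin_to_safetensor_file("ckpt/pytorch_model.bin",
#                                    "ckpt/model.safetensors")
#
# Tied tensors are dropped before saving because the safetensors format
# rejects tensors that share storage; only the first name in each shared
# group survives the conversion.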


# TODO: Move this to another place.
def get_quant_config(
    quantization: str,
    model_name_or_path: str,
    hf_config: PretrainedConfig,
    cache_dir: Optional[str] = None,
) -> QuantizationConfig:
    if quantization == "gptq" and hasattr(hf_config, "quantization_config"):
        config = hf_config.quantization_config
        return get_quant_class(quantization).from_config(config)
    is_local = os.path.isdir(model_name_or_path)
    if not is_local:
        # Download the config files.
        with get_lock(model_name_or_path, cache_dir):
            hf_folder = snapshot_download(model_name_or_path,
                                          allow_patterns="*.json",
                                          cache_dir=cache_dir,
                                          tqdm_class=Disabledtqdm)
    else:
        hf_folder = model_name_or_path
    config_files = glob.glob(os.path.join(hf_folder, "*.json"))
    quant_cls = get_quant_class(quantization)
    quant_config_files = [
        f for f in config_files if any(
            f.endswith(x) for x in quant_cls.get_config_filenames())
    ]
    if len(quant_config_files) == 0:
        raise ValueError(f"Cannot find the config file for {quantization}")
    if len(quant_config_files) > 1:
        raise ValueError(f"Found multiple config files for {quantization}: "
                         f"{quant_config_files}")
    quant_config_file = quant_config_files[0]
    with open(quant_config_file, "r") as f:
        config = json.load(f)
    return quant_cls.from_config(config)
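

# Example usage (a sketch; the model id is illustrative and assumes a GPTQ
# checkpoint whose quantization config file is present in the repo):
#
#     from transformers import AutoConfig
#     hf_config = AutoConfig.from_pretrained("TheBloke/Llama-2-7B-GPTQ")
#     quant_config = get_quant_config("gptq", "TheBloke/Llama-2-7B-GPTQ",
#                                     hf_config)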


def prepare_hf_model_weights(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    use_safetensors: bool = False,
    fall_back_to_pt: bool = True,
    revision: Optional[str] = None,
) -> Tuple[str, List[str], bool]:
    # Download model weights from huggingface.
    is_local = os.path.isdir(model_name_or_path)
    if use_safetensors:
        allow_patterns = ["*.safetensors"]
    else:
        # Some quantized models use .pt files for storing the weights.
        allow_patterns = ["*.bin", "*.pt"]
    if not is_local:
        # Use file lock to prevent multiple processes from
        # downloading the same model weights at the same time.
        with get_lock(model_name_or_path, cache_dir):
            hf_folder = snapshot_download(model_name_or_path,
                                          allow_patterns=allow_patterns,
                                          cache_dir=cache_dir,
                                          tqdm_class=Disabledtqdm,
                                          revision=revision)
    else:
        hf_folder = model_name_or_path
    hf_weights_files: List[str] = []
    for pattern in allow_patterns:
        hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
    if not use_safetensors:
        # Exclude training artifacts that would match the weight patterns.
        blacklist = [
            "training_args.bin",
            "optimizer.bin",
            "optimizer.pt",
            "scheduler.pt",
            "scaler.pt",
            "trainer_state.json",
        ]
        hf_weights_files = [
            f for f in hf_weights_files
            if not any(f.endswith(x) for x in blacklist)
        ]
    if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt:
        return prepare_hf_model_weights(model_name_or_path,
                                        cache_dir=cache_dir,
                                        use_safetensors=False,
                                        fall_back_to_pt=False,
                                        revision=revision)
    if len(hf_weights_files) == 0:
        raise RuntimeError(
            f"Cannot find any model weights with `{model_name_or_path}`")
    return hf_folder, hf_weights_files, use_safetensors
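

# Example usage (a sketch; the model id is illustrative):
#
#     folder, files, is_safetensors = prepare_hf_model_weights(
#         "facebook/opt-125m", use_safetensors=True)
#
# `files` lists the matched *.safetensors shards, or the *.bin/*.pt files if
# the pt fallback was taken; `is_safetensors` reports which case occurred.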


def hf_model_weights_iterator(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    load_format: str = "auto",
    revision: Optional[str] = None,
) -> Iterator[Tuple[str, torch.Tensor]]:
    use_safetensors = False
    use_np_cache = False
    fall_back_to_pt = False
    if load_format == "auto":
        use_safetensors = True
        fall_back_to_pt = True
    elif load_format == "safetensors":
        use_safetensors = True
    elif load_format == "pt":
        pass
    elif load_format == "npcache":
        use_np_cache = True
    else:
        raise ValueError(f"Unknown load_format: {load_format}")

    hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
        model_name_or_path,
        cache_dir=cache_dir,
        use_safetensors=use_safetensors,
        fall_back_to_pt=fall_back_to_pt,
        revision=revision)

    if use_np_cache:
        # Currently np_cache only supports *.bin checkpoints.
        assert use_safetensors is False

        # Convert the model weights from torch tensors to numpy arrays for
        # faster loading.
        np_folder = os.path.join(hf_folder, "np")
        os.makedirs(np_folder, exist_ok=True)
        weight_names_file = os.path.join(np_folder, "weight_names.json")
        # Use file lock to prevent multiple processes from
        # dumping the same model weights to numpy at the same time.
        with get_lock(model_name_or_path, cache_dir):
            if not os.path.exists(weight_names_file):
                weight_names = []
                for bin_file in hf_weights_files:
                    state = torch.load(bin_file, map_location="cpu")
                    for name, param in state.items():
                        param_path = os.path.join(np_folder, name)
                        with open(param_path, "wb") as f:
                            np.save(f, param.cpu().detach().numpy())
                        weight_names.append(name)
                with open(weight_names_file, "w") as f:
                    json.dump(weight_names, f)

        with open(weight_names_file, "r") as f:
            weight_names = json.load(f)

        for name in weight_names:
            param_path = os.path.join(np_folder, name)
            with open(param_path, "rb") as f:
                param = np.load(f)
            yield name, torch.from_numpy(param)
    elif use_safetensors:
        for st_file in hf_weights_files:
            with safe_open(st_file, framework="pt") as f:
                for name in f.keys():
                    param = f.get_slice(name)
                    yield name, param
    else:
        for bin_file in hf_weights_files:
            state = torch.load(bin_file, map_location="cpu")
            for name, param in state.items():
                yield name, param
            del state
            torch.cuda.empty_cache()
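

# Example usage (a sketch; the model id is illustrative):
#
#     for name, weight in hf_model_weights_iterator("facebook/opt-125m"):
#         # With safetensors, `weight` is a lazy PySafeSlice rather than a
#         # torch.Tensor; see convert_pyslice_to_tensor below.
#         ...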


def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
    """Convert a PySafeSlice object from safetensors to a torch.Tensor.

    A PySafeSlice object supports indexing, which is performed before the
    actual tensor is loaded and can therefore reduce the amount of data read
    into memory. However, it does not support more advanced functionality
    such as `.view()` or `.t()`. Therefore, if the loaded tensor must be
    modified with such operators, it has to be converted to a tensor first.
    """
    if not isinstance(x, torch.Tensor):
        x = x[:]
    return x
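

# For example (a sketch, given a `weight` yielded by the safetensors branch
# of hf_model_weights_iterator above): slicing reads only that shard from
# disk, and convert_pyslice_to_tensor then guarantees a real tensor (a no-op
# if slicing already produced one) so that `.t()` is legal:
#
#     shard = convert_pyslice_to_tensor(weight[0:1024]).t()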


def load_padded_tensor_parallel_vocab(
    param: torch.Tensor,
    loaded_weight: Any,  # `torch.Tensor` or `PySafeSlice`
    tensor_model_parallel_rank: int,
) -> None:
    shard_size = param.shape[0]
    start_idx = tensor_model_parallel_rank * shard_size
    end_idx = (tensor_model_parallel_rank + 1) * shard_size
    loaded_weight = loaded_weight[start_idx:end_idx]
    loaded_weight = convert_pyslice_to_tensor(loaded_weight)
    param.data[:loaded_weight.shape[0]].copy_(loaded_weight)
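

# For example (a sketch with purely illustrative numbers): with a
# 32,000-token vocab padded to 32,256 rows under 4-way tensor parallelism,
# each shard holds 8,064 rows, but the checkpoint only provides 7,808 of
# them to the last rank (rows 24,192:32,000); the `[:loaded_weight.shape[0]]`
# above copies those and leaves the 256 padding rows untouched.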


def load_tensor_parallel_weights(
    param: torch.Tensor,
    loaded_weight: Any,  # `torch.Tensor` or `PySafeSlice`
    param_name: str,
    column_parallel_weight_names: List[str],
    row_parallel_weight_names: List[str],
    tensor_model_parallel_rank: int,
) -> None:
    for p in column_parallel_weight_names:
        if p in param_name:
            shard_size = param.shape[0]
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            loaded_weight = loaded_weight[start_idx:end_idx]
            break
    for p in row_parallel_weight_names:
        if p in param_name:
            shard_size = param.shape[-1]
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            if isinstance(loaded_weight, torch.Tensor):
                loaded_weight = loaded_weight[..., start_idx:end_idx]
            else:
                # PySafeSlice does not support `...`, so build an explicit
                # index that slices only the last dimension.
                index = [slice(None)] * (len(loaded_weight.get_shape()) -
                                         1) + [slice(start_idx, end_idx)]
                loaded_weight = loaded_weight[index]
            break
    loaded_weight = convert_pyslice_to_tensor(loaded_weight)
    assert param.shape == loaded_weight.shape, (
        f"{param_name} shape mismatch between model and checkpoint: "
        f"{param.shape} != {loaded_weight.shape}")
    param.data.copy_(loaded_weight)
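

# Example usage (a sketch; the parameter name and suffix lists are
# hypothetical, typically produced by get_parallel_weight below):
#
#     load_tensor_parallel_weights(
#         param, loaded_weight, "layers.0.qkv_proj.weight",
#         column_parallel_weight_names=["qkv_proj.weight", "qkv_proj.bias"],
#         row_parallel_weight_names=["o_proj.weight"],
#         tensor_model_parallel_rank=rank)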


def initialize_dummy_weights(
    model: torch.nn.Module,
    low: float = -1e-3,
    high: float = 1e-3,
) -> None:
    """Initialize model weights with random values.

    The model weights must be randomly initialized for accurate performance
    measurements. Additionally, the model weights should not cause NaNs in
    the forward pass. We empirically found that initializing the weights with
    values between -1e-3 and 1e-3 works well for most models.
    """
    for param in model.state_dict().values():
        if torch.is_floating_point(param):
            param.data.uniform_(low, high)
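

# Example usage (a sketch; the model class is hypothetical): fill a freshly
# constructed model with small random weights for profiling, leaving
# non-floating-point (e.g. quantized integer) parameters untouched:
#
#     model = SomeAphroditeModel(config)
#     initialize_dummy_weights(model)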


def get_parallel_weight(model: torch.nn.Module):
    if model.quant_config is None:
        column_weight_suffixes = ["weight", "bias"]
        row_weight_suffixes = ["weight"]
    else:
        column_weight_suffixes = (
            model.quant_config.get_col_parallel_tensor_names())
        row_weight_suffixes = (
            model.quant_config.get_row_parallel_tensor_names())

    column_parallel_weights: List[str] = []
    for layer in model.column_parallel_layers:
        for suffix in column_weight_suffixes:
            column_parallel_weights.append(f"{layer}.{suffix}")
    row_parallel_weights: List[str] = []
    for layer in model.row_parallel_layers:
        for suffix in row_weight_suffixes:
            row_parallel_weights.append(f"{layer}.{suffix}")

    if hasattr(model, "parallel_vocab_layers"):
        for layer in model.parallel_vocab_layers:
            for suffix in ["weight", "bias"]:
                column_parallel_weights.append(f"{layer}.{suffix}")
    return column_parallel_weights, row_parallel_weights
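

# Example (a sketch; the attribute values are hypothetical): for a model with
# column_parallel_layers = ["qkv_proj"], row_parallel_layers = ["o_proj"],
# and no quant_config, this returns
#
#     (["qkv_proj.weight", "qkv_proj.bias"], ["o_proj.weight"])
#
# which can be passed directly to load_tensor_parallel_weights above.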