hf_downloader.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505
  1. """Utilities for downloading and initializing model weights."""
  2. import fnmatch
  3. import glob
  4. import json
  5. import os
  6. from collections import defaultdict
  7. from typing import Any, Iterable, Iterator, List, Optional, Tuple
  8. import filelock
  9. import huggingface_hub.constants
  10. import numpy as np
  11. import torch
  12. from huggingface_hub import HfFileSystem, snapshot_download
  13. from loguru import logger
  14. from safetensors.torch import load_file, safe_open, save_file
  15. from tqdm.auto import tqdm
  16. from transformers import PretrainedConfig, AutoModelForCausalLM
  17. from aphrodite.common.config import ModelConfig
  18. from aphrodite.common.gguf import (GGUFReader, get_tensor_name_map,
  19. MODEL_ARCH_NAMES)
  20. from aphrodite.common.logger import get_loading_progress_bar
  21. from aphrodite.modeling.layers.quantization import (QuantizationConfig,
  22. get_quantization_config)
  23. from aphrodite.modeling.layers.quantization.schema import QuantParamSchema
  24. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  25. get_tensor_model_parallel_world_size)
  26. _xdg_cache_home = os.getenv("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
  27. _aphrodite_filelocks_path = os.path.join(_xdg_cache_home, "aphrodite/locks/")
  28. def enable_hf_transfer():
  29. """automatically activates hf_transfer
  30. """
  31. if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
  32. try:
  33. # enable hf hub transfer if available
  34. import hf_transfer # type: ignore # noqa
  35. huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
  36. except ImportError:
  37. pass
  38. enable_hf_transfer()
  39. class Disabledtqdm(tqdm):
  40. def __init__(self, *args, **kwargs):
  41. super().__init__(*args, **kwargs, disable=True)
  42. def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
  43. lock_dir = cache_dir if cache_dir is not None else _aphrodite_filelocks_path
  44. os.makedirs(os.path.dirname(lock_dir), exist_ok=True)
  45. lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
  46. lock = filelock.SoftFileLock(os.path.join(lock_dir, lock_file_name))
  47. return lock
  48. def _shared_pointers(tensors):
  49. ptrs = defaultdict(list)
  50. for k, v in tensors.items():
  51. ptrs[v.data_ptr()].append(k)
  52. failing = []
  53. for _, names in ptrs.items():
  54. if len(names) > 1:
  55. failing.append(names)
  56. return failing
  57. def convert_bin_to_safetensor_file(
  58. pt_filename: str,
  59. sf_filename: str,
  60. ) -> None:
  61. loaded = torch.load(pt_filename, map_location="cpu")
  62. if "state_dict" in loaded:
  63. loaded = loaded["state_dict"]
  64. shared = _shared_pointers(loaded)
  65. for shared_weights in shared:
  66. for name in shared_weights[1:]:
  67. loaded.pop(name)
  68. # For tensors to be contiguous
  69. loaded = {k: v.contiguous() for k, v in loaded.items()}
  70. dirname = os.path.dirname(sf_filename)
  71. os.makedirs(dirname, exist_ok=True)
  72. save_file(loaded, sf_filename, metadata={"format": "pt"})
  73. # check file size
  74. sf_size = os.stat(sf_filename).st_size
  75. pt_size = os.stat(pt_filename).st_size
  76. if (sf_size - pt_size) / pt_size > 0.01:
  77. raise RuntimeError(f"""The file size different is more than 1%:
  78. - {sf_filename}: {sf_size}
  79. - {pt_filename}: {pt_size}
  80. """)
  81. # check if the tensors are the same
  82. reloaded = load_file(sf_filename)
  83. for k in loaded:
  84. pt_tensor = loaded[k]
  85. sf_tensor = reloaded[k]
  86. if not torch.equal(pt_tensor, sf_tensor):
  87. raise RuntimeError(f"The output tensors do not match for key {k}")
  88. # TODO: Move this to another place.
  89. def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
  90. quant_cls = get_quantization_config(model_config.quantization)
  91. # Read the quantization config from the HF model config, if available.
  92. # if the quantization if "gguf", we skip and return quant_cls()
  93. if model_config.quantization in ["exl2", "gguf"]:
  94. return quant_cls()
  95. hf_quant_config = getattr(model_config.hf_config, "quantization_config",
  96. None)
  97. if hf_quant_config is not None:
  98. return quant_cls.from_config(hf_quant_config)
  99. model_name_or_path = model_config.model
  100. is_local = os.path.isdir(model_name_or_path)
  101. if not is_local:
  102. # Download the config files.
  103. with get_lock(model_name_or_path, model_config.download_dir):
  104. hf_folder = snapshot_download(
  105. model_name_or_path,
  106. revision=model_config.revision,
  107. allow_patterns="*.json",
  108. cache_dir=model_config.download_dir,
  109. tqdm_class=Disabledtqdm,
  110. )
  111. else:
  112. hf_folder = model_name_or_path
  113. config_files = glob.glob(os.path.join(hf_folder, "*.json"))
  114. quant_config_files = [
  115. f for f in config_files if any(
  116. f.endswith(x) for x in quant_cls.get_config_filenames())
  117. ]
  118. if len(quant_config_files) == 0:
  119. raise ValueError(
  120. f"Cannot find the config file for {model_config.quantization}")
  121. if len(quant_config_files) > 1:
  122. raise ValueError(
  123. f"Found multiple config files for {model_config.quantization}: "
  124. f"{quant_config_files}")
  125. quant_config_file = quant_config_files[0]
  126. with open(quant_config_file, "r") as f:
  127. config = json.load(f)
  128. return quant_cls.from_config(config)
  129. def prepare_hf_model_weights(
  130. model_name_or_path: str,
  131. cache_dir: Optional[str] = None,
  132. load_format: str = "auto",
  133. fall_back_to_pt: bool = True,
  134. revision: Optional[str] = None,
  135. ) -> Tuple[str, List[str], bool]:
  136. # Download model weights from huggingface.
  137. is_local = os.path.isdir(model_name_or_path)
  138. use_safetensors = False
  139. # Some quantized models use .pt files for storing the weights.
  140. if load_format == "auto":
  141. allow_patterns = ["*.safetensors", "*.bin"]
  142. elif load_format == "safetensors":
  143. use_safetensors = True
  144. allow_patterns = ["*.safetensors"]
  145. elif load_format == "pt":
  146. allow_patterns = ["*.pt"]
  147. elif load_format == "npcache":
  148. allow_patterns = ["*.bin"]
  149. else:
  150. raise ValueError(f"Unknown load_format: {load_format}")
  151. if fall_back_to_pt:
  152. allow_patterns += ["*.pt"]
  153. if not is_local:
  154. # Before we download we look at that is available:
  155. fs = HfFileSystem()
  156. file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
  157. # depending on what is available we download different things
  158. for pattern in allow_patterns:
  159. matching = fnmatch.filter(file_list, pattern)
  160. if len(matching) > 0:
  161. allow_patterns = [pattern]
  162. break
  163. logger.info(f"Using model weights format {allow_patterns}")
  164. # Use file lock to prevent multiple processes from
  165. # downloading the same model weights at the same time.
  166. with get_lock(model_name_or_path, cache_dir):
  167. hf_folder = snapshot_download(
  168. model_name_or_path,
  169. allow_patterns=allow_patterns,
  170. cache_dir=cache_dir,
  171. tqdm_class=Disabledtqdm,
  172. revision=revision,
  173. )
  174. else:
  175. hf_folder = model_name_or_path
  176. hf_weights_files: List[str] = []
  177. for pattern in allow_patterns:
  178. hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
  179. if len(hf_weights_files) > 0:
  180. if pattern == "*.safetensors":
  181. use_safetensors = True
  182. break
  183. if not use_safetensors:
  184. # Exclude files that are not needed for inference.
  185. # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
  186. blacklist = [
  187. "training_args.bin",
  188. "optimizer.bin",
  189. "optimizer.pt",
  190. "scheduler.pt",
  191. "scaler.pt",
  192. "trainer_state.json",
  193. "hidden_states.safetensors", # exllamav2
  194. ]
  195. hf_weights_files = [
  196. f for f in hf_weights_files
  197. if not any(f.endswith(x) for x in blacklist)
  198. ]
  199. if len(hf_weights_files) == 0:
  200. raise RuntimeError(
  201. f"Cannot find any model weights with `{model_name_or_path}`")
  202. return hf_folder, hf_weights_files, use_safetensors
  203. def convert_gguf_to_state_dict(checkpoint, config):
  204. if not os.path.isfile(checkpoint):
  205. raise RuntimeError(
  206. f"Cannot find any model weights with `{checkpoint}`")
  207. model_type = config.model_type
  208. # hack: ggufs have a different name than transformers
  209. if model_type == "cohere":
  210. model_type = "command-r"
  211. arch = None
  212. for key, value in MODEL_ARCH_NAMES.items():
  213. if value == model_type:
  214. arch = key
  215. break
  216. if arch is None:
  217. raise RuntimeError(f"Unknown model_type: {model_type}")
  218. num_layers = config.num_hidden_layers
  219. name_map = get_tensor_name_map(arch, num_layers)
  220. with torch.device("meta"):
  221. dummy_model = AutoModelForCausalLM.from_config(config)
  222. state_dict = dummy_model.state_dict()
  223. gguf_to_hf_name_map = {}
  224. keys_to_remove = []
  225. for hf_name in state_dict:
  226. name, suffix = hf_name.rsplit(".", 1)
  227. gguf_name = name_map.get_name(name)
  228. if gguf_name:
  229. gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name
  230. elif name == "lm_head":
  231. keys_to_remove.append(hf_name)
  232. logger.warning(
  233. f"GGUF tensor name for {hf_name} not found, "
  234. "this is normal if the model uses tie word embeddings.")
  235. else:
  236. logger.warning(
  237. f"GGUF tensor name for {hf_name} in hf state_dict not found.")
  238. for key in keys_to_remove:
  239. state_dict.pop(key)
  240. result = GGUFReader(checkpoint)
  241. with get_loading_progress_bar() as progress:
  242. task = progress.add_task(
  243. "[cyan]Converting GGUF tensors to PyTorch...",
  244. total=len(result.tensors),
  245. )
  246. for ts in result.tensors:
  247. try:
  248. hf_name = gguf_to_hf_name_map[ts.name]
  249. except KeyError:
  250. logger.warning(
  251. f"hf tensor name for {ts.name} in GGUF not found.")
  252. continue
  253. data = torch.tensor(ts.data)
  254. if state_dict[hf_name].dim() == 2:
  255. data = data.view(state_dict[hf_name].shape[0], -1)
  256. state_dict[hf_name] = data
  257. weight_type = torch.tensor(int(ts.tensor_type), dtype=torch.int)
  258. if weight_type > 1:
  259. state_dict[hf_name.replace("weight",
  260. "weight_type")] = weight_type
  261. progress.update(task, advance=1)
  262. return state_dict
  263. def hf_model_weights_iterator(
  264. model_name_or_path: str,
  265. cache_dir: Optional[str] = None,
  266. load_format: str = "auto",
  267. revision: Optional[str] = None,
  268. config: Optional[PretrainedConfig] = None,
  269. fall_back_to_pt: Optional[bool] = True,
  270. ) -> Iterator[Tuple[str, torch.Tensor]]:
  271. if model_name_or_path.endswith("gguf"):
  272. for name, param in convert_gguf_to_state_dict(model_name_or_path,
  273. config).items():
  274. yield name, param
  275. return
  276. hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
  277. model_name_or_path,
  278. cache_dir=cache_dir,
  279. load_format=load_format,
  280. fall_back_to_pt=fall_back_to_pt,
  281. revision=revision,
  282. )
  283. if load_format == "npcache":
  284. # Currently np_cache only support *.bin checkpoints
  285. assert use_safetensors is False
  286. # Convert the model weights from torch tensors to numpy arrays for
  287. # faster loading.
  288. np_folder = os.path.join(hf_folder, "np")
  289. os.makedirs(np_folder, exist_ok=True)
  290. weight_names_file = os.path.join(np_folder, "weight_names.json")
  291. # Use file lock to prevent multiple processes from
  292. # dumping the same model weights to numpy at the same time.
  293. with get_lock(model_name_or_path, cache_dir):
  294. if not os.path.exists(weight_names_file):
  295. weight_names = []
  296. for bin_file in hf_weights_files:
  297. state = torch.load(bin_file, map_location="cpu")
  298. for name, param in state.items():
  299. param_path = os.path.join(np_folder, name)
  300. with open(param_path, "wb") as f:
  301. np.save(f, param.cpu().detach().numpy())
  302. weight_names.append(name)
  303. with open(weight_names_file, "w") as f:
  304. json.dump(weight_names, f)
  305. with open(weight_names_file, "r") as f:
  306. weight_names = json.load(f)
  307. for name in weight_names:
  308. param_path = os.path.join(np_folder, name)
  309. with open(param_path, "rb") as f:
  310. param = np.load(f)
  311. yield name, torch.from_numpy(param)
  312. elif use_safetensors:
  313. for st_file in hf_weights_files:
  314. with safe_open(st_file, framework="pt") as f:
  315. for name in f.keys(): # noqa: SIM118
  316. param = f.get_tensor(name)
  317. yield name, param
  318. else:
  319. for bin_file in hf_weights_files:
  320. state = torch.load(bin_file, map_location="cpu")
  321. for name, param in state.items():
  322. yield name, param
  323. del state
  324. torch.cuda.empty_cache()
  325. def kv_cache_scales_loader(
  326. filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int,
  327. model_type: Optional[str]) -> Iterable[Tuple[int, float]]:
  328. """
  329. A simple utility to read in KV cache scaling factors that have been
  330. previously serialized to disk. Used by the model to populate the appropriate
  331. KV cache scaling factors. The serialization should represent a dictionary
  332. whose keys are the TP ranks and values are another dictionary mapping layers
  333. to their KV cache scaling factors.
  334. Keep this function in sync with the output of examples/fp8/extract_scales.py
  335. """
  336. try:
  337. with open(filename) as f:
  338. context = {
  339. "model_type": model_type,
  340. "num_hidden_layers": num_hidden_layers,
  341. "tp_rank": tp_rank,
  342. "tp_size": tp_size,
  343. }
  344. schema_dct = json.load(f)
  345. schema = QuantParamSchema.model_validate(schema_dct,
  346. context=context)
  347. layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
  348. return layer_scales_map.items()
  349. except FileNotFoundError:
  350. logger.error(f"File or directory '{filename}' not found.")
  351. except json.JSONDecodeError:
  352. logger.error(f"Error decoding JSON in file '{filename}'.")
  353. except Exception as e:
  354. logger.error(f"An error occurred while reading '{filename}': {e}")
  355. # This section is reached if and only if any of the excepts are hit
  356. # Return an empty iterable (list) => no KV cache scales are loaded
  357. # which ultimately defaults to 1.0 scales
  358. logger.warning("Defaulting to KV cache scaling factors = 1.0 "
  359. f"for all layers in TP rank {tp_rank} "
  360. "as an error occurred during loading.")
  361. return []
  362. def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
  363. """convert PySafeSlice object from safetensors to torch.Tensor
  364. PySafeSlice object supports indexing, which is done before loading the
  365. actual tensor and can reduce the amount of memory being read into the
  366. memory. However, it does not support more advanced functionalities
  367. like `.view()` or `.t()`. Therefore, if we need to modify the loaded
  368. tensor with these more complicated operators, we need to convert to
  369. tensor first.
  370. """
  371. if not isinstance(x, torch.Tensor):
  372. x = x[:]
  373. return x
  374. def default_weight_loader(param: torch.Tensor,
  375. loaded_weight: torch.Tensor) -> None:
  376. """Default weight loader."""
  377. if isinstance(param, torch.nn.parameter.UninitializedParameter):
  378. param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
  379. assert param.size() == loaded_weight.size()
  380. param.data.copy_(loaded_weight)
  381. def initialize_dummy_weights(
  382. model: torch.nn.Module,
  383. low: float = -1e-3,
  384. high: float = 1e-3,
  385. ) -> None:
  386. """Initialize model weights with random values.
  387. The model weights must be randomly initialized for accurate performance
  388. measurements. Additionally, the model weights should not cause NaNs in the
  389. forward pass. We empirically found that initializing the weights with
  390. values between -1e-3 and 1e-3 works well for most models.
  391. """
  392. for param in model.state_dict().values():
  393. if torch.is_floating_point(param):
  394. param.data.uniform_(low, high)
  395. # Split the exl2 weight at group boundary
  396. def post_init_exl2(layer):
  397. q_groups = layer.q_groups
  398. num_groups = q_groups.shape[0] // 2
  399. input_size = layer.q_invperm.shape[0]
  400. tp_rank = get_tensor_model_parallel_rank()
  401. tp_size = get_tensor_model_parallel_world_size()
  402. rows = 0
  403. # row_number, qrow_number, group_number
  404. splits = [(0, 0, 0)]
  405. index = 1
  406. for i in range(num_groups - 1):
  407. bits = q_groups[i * 2].item()
  408. qrows = (q_groups[i * 2 + 3] - q_groups[i * 2 + 1]).item()
  409. rows += qrows * 32 // bits
  410. if rows >= input_size // tp_size * index:
  411. splits.append((rows, q_groups[i * 2 + 3].item(), i + 1))
  412. index += 1
  413. splits.append((input_size, layer.q_weight.shape[0], num_groups))
  414. shard_qweight = torch.nn.Parameter(
  415. layer.q_weight[splits[tp_rank][1]:splits[tp_rank + 1][1]].clone(),
  416. requires_grad=False,
  417. )
  418. del layer.q_weight
  419. layer.linear_weights["q_weight"] = shard_qweight
  420. shard_qgroups = torch.nn.Parameter(
  421. layer.q_groups[splits[tp_rank][2] * 2:splits[tp_rank + 1][2] *
  422. 2].clone(),
  423. requires_grad=False,
  424. )
  425. shard_qgroups[1::2] -= int(shard_qgroups[1])
  426. del layer.q_groups
  427. layer.linear_weights["q_groups"] = shard_qgroups
  428. shard_qscale = torch.nn.Parameter(
  429. layer.q_scale[splits[tp_rank][2]:splits[tp_rank + 1][2]].clone(),
  430. requires_grad=False,
  431. )
  432. del layer.q_scale
  433. layer.linear_weights["q_scale"] = shard_qscale
  434. shard_qscale_max = torch.nn.Parameter(
  435. layer.q_scale_max[splits[tp_rank][2]:splits[tp_rank + 1][2]].clone(),
  436. requires_grad=False,
  437. )
  438. del layer.q_scale_max
  439. layer.linear_weights["q_scale_max"] = shard_qscale_max
  440. q_perm = torch.argsort(layer.q_invperm).to(torch.short)
  441. layer.linear_weights["q_perm"] = q_perm[splits[tp_rank][0]:splits[tp_rank +
  442. 1][0]]