1
0

hf_downloader.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390
  1. """Utilities for downloading and initializing model weights."""
  2. import filelock
  3. import glob
  4. import fnmatch
  5. import json
  6. import os
  7. from collections import defaultdict
  8. from typing import Any, Iterator, List, Optional, Tuple
  9. from loguru import logger
  10. from huggingface_hub import snapshot_download, HfFileSystem
  11. import numpy as np
  12. from safetensors.torch import load_file, save_file, safe_open
  13. import torch
  14. from transformers import PretrainedConfig
  15. from tqdm.auto import tqdm
  16. from aphrodite.common.config import ModelConfig
  17. from aphrodite.common.logger import get_loading_progress_bar
  18. from aphrodite.common.gguf import GGUFReader
  19. from aphrodite.modeling.layers.quantization import (get_quantization_config,
  20. QuantizationConfig)
  21. class Disabledtqdm(tqdm): # pylint: disable=inconsistent-mro
  22. def __init__(self, *args, **kwargs):
  23. super().__init__(*args, **kwargs, disable=True)
  24. def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
  25. lock_dir = cache_dir if cache_dir is not None else "/tmp"
  26. lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
  27. lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
  28. return lock
  29. def _shared_pointers(tensors):
  30. ptrs = defaultdict(list)
  31. for k, v in tensors.items():
  32. ptrs[v.data_ptr()].append(k)
  33. failing = []
  34. for _, names in ptrs.items():
  35. if len(names) > 1:
  36. failing.append(names)
  37. return failing
  38. def convert_bin_to_safetensor_file(
  39. pt_filename: str,
  40. sf_filename: str,
  41. ) -> None:
  42. loaded = torch.load(pt_filename, map_location="cpu")
  43. if "state_dict" in loaded:
  44. loaded = loaded["state_dict"]
  45. shared = _shared_pointers(loaded)
  46. for shared_weights in shared:
  47. for name in shared_weights[1:]:
  48. loaded.pop(name)
  49. # For tensors to be contiguous
  50. loaded = {k: v.contiguous() for k, v in loaded.items()}
  51. dirname = os.path.dirname(sf_filename)
  52. os.makedirs(dirname, exist_ok=True)
  53. save_file(loaded, sf_filename, metadata={"format": "pt"})
  54. # check file size
  55. sf_size = os.stat(sf_filename).st_size
  56. pt_size = os.stat(pt_filename).st_size
  57. if (sf_size - pt_size) / pt_size > 0.01:
  58. raise RuntimeError(f"""The file size different is more than 1%:
  59. - {sf_filename}: {sf_size}
  60. - {pt_filename}: {pt_size}
  61. """)
  62. # check if the tensors are the same
  63. reloaded = load_file(sf_filename)
  64. for k in loaded:
  65. pt_tensor = loaded[k]
  66. sf_tensor = reloaded[k]
  67. if not torch.equal(pt_tensor, sf_tensor):
  68. raise RuntimeError(f"The output tensors do not match for key {k}")
  69. # TODO: Move this to another place.
  70. def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
  71. quant_cls = get_quantization_config(model_config.quantization)
  72. # Read the quantization config from the HF model config, if available.
  73. # if the quantization if "gguf", we skip and return quant_cls()
  74. if model_config.quantization in ["exl2", "gguf"]:
  75. return quant_cls()
  76. hf_quant_config = getattr(model_config.hf_config, "quantization_config",
  77. None)
  78. if hf_quant_config is not None:
  79. return quant_cls.from_config(hf_quant_config)
  80. model_name_or_path = model_config.model
  81. is_local = os.path.isdir(model_name_or_path)
  82. if not is_local:
  83. # Download the config files.
  84. with get_lock(model_name_or_path, model_config.download_dir):
  85. hf_folder = snapshot_download(model_name_or_path,
  86. revision=model_config.revision,
  87. allow_patterns="*.json",
  88. cache_dir=model_config.download_dir,
  89. tqdm_class=Disabledtqdm)
  90. else:
  91. hf_folder = model_name_or_path
  92. config_files = glob.glob(os.path.join(hf_folder, "*.json"))
  93. quant_config_files = [
  94. f for f in config_files if any(
  95. f.endswith(x) for x in quant_cls.get_config_filenames())
  96. ]
  97. if len(quant_config_files) == 0:
  98. raise ValueError(
  99. f"Cannot find the config file for {model_config.quantization}")
  100. if len(quant_config_files) > 1:
  101. raise ValueError(
  102. f"Found multiple config files for {model_config.quantization}: "
  103. f"{quant_config_files}")
  104. quant_config_file = quant_config_files[0]
  105. with open(quant_config_file, "r") as f:
  106. config = json.load(f)
  107. return quant_cls.from_config(config)
  108. def prepare_hf_model_weights(
  109. model_name_or_path: str,
  110. cache_dir: Optional[str] = None,
  111. load_format: str = "auto",
  112. fall_back_to_pt: bool = True,
  113. revision: Optional[str] = None,
  114. ) -> Tuple[str, List[str], bool]:
  115. # Download model weights from huggingface.
  116. is_local = os.path.isdir(model_name_or_path)
  117. use_safetensors = False
  118. # Some quantized models use .pt files for storing the weights.
  119. if load_format == "auto":
  120. allow_patterns = ["*.safetensors", "*.bin"]
  121. elif load_format == "safetensors":
  122. use_safetensors = True
  123. allow_patterns = ["*.safetensors"]
  124. elif load_format == "pt":
  125. allow_patterns = ["*.pt"]
  126. elif load_format == "npcache":
  127. allow_patterns = ["*.bin"]
  128. else:
  129. raise ValueError(f"Unknown load_format: {load_format}")
  130. if fall_back_to_pt:
  131. allow_patterns += ["*.pt"]
  132. if not is_local:
  133. # Before we download we look at that is available:
  134. fs = HfFileSystem()
  135. file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
  136. # depending on what is available we download different things
  137. for pattern in allow_patterns:
  138. matching = fnmatch.filter(file_list, pattern)
  139. if len(matching) > 0:
  140. allow_patterns = [pattern]
  141. break
  142. logger.info(f"Downloading model weights {allow_patterns}")
  143. # Use file lock to prevent multiple processes from
  144. # downloading the same model weights at the same time.
  145. with get_lock(model_name_or_path, cache_dir):
  146. hf_folder = snapshot_download(model_name_or_path,
  147. allow_patterns=allow_patterns,
  148. cache_dir=cache_dir,
  149. tqdm_class=Disabledtqdm,
  150. revision=revision)
  151. else:
  152. hf_folder = model_name_or_path
  153. hf_weights_files: List[str] = []
  154. for pattern in allow_patterns:
  155. hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
  156. if len(hf_weights_files) > 0:
  157. if pattern == "*.safetensors":
  158. use_safetensors = True
  159. break
  160. if not use_safetensors:
  161. # Exclude files that are not needed for inference.
  162. # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
  163. blacklist = [
  164. "training_args.bin",
  165. "optimizer.bin",
  166. "optimizer.pt",
  167. "scheduler.pt",
  168. "scaler.pt",
  169. "trainer_state.json",
  170. ]
  171. hf_weights_files = [
  172. f for f in hf_weights_files
  173. if not any(f.endswith(x) for x in blacklist)
  174. ]
  175. if len(hf_weights_files) == 0:
  176. raise RuntimeError(
  177. f"Cannot find any model weights with `{model_name_or_path}`")
  178. return hf_folder, hf_weights_files, use_safetensors
  179. def convert_gguf_to_state_dict(checkpoint, config):
  180. if not os.path.isfile(checkpoint):
  181. raise RuntimeError(
  182. f"Cannot find any model weights with `{checkpoint}`")
  183. result = GGUFReader(checkpoint)
  184. # write tensor
  185. kv_dim = (config.hidden_size // config.num_attention_heads *
  186. config.num_key_value_heads)
  187. tensor_mapping = {
  188. "token_embd": ("model.embed_tokens", config.vocab_size),
  189. "output": ("lm_head", config.vocab_size),
  190. "output_norm": ("model.norm", -1),
  191. "blk.{bid}.attn_norm": ("model.layers.{bid}.input_layernorm", -1),
  192. "blk.{bid}.attn_q": ("model.layers.{bid}.self_attn.q_proj",
  193. config.hidden_size),
  194. "blk.{bid}.attn_k": ("model.layers.{bid}.self_attn.k_proj", kv_dim),
  195. "blk.{bid}.attn_v": ("model.layers.{bid}.self_attn.v_proj", kv_dim),
  196. "blk.{bid}.attn_output": ("model.layers.{bid}.self_attn.o_proj",
  197. config.hidden_size),
  198. "blk.{bid}.attn_rot_embd":
  199. ("model.layers.{bid}.self_attn.rotary_emb.inv_freq", -1),
  200. "blk.{bid}.ffn_norm": ("model.layers.{bid}.post_attention_layernorm",
  201. -1),
  202. "blk.{bid}.ffn_up": ("model.layers.{bid}.mlp.up_proj",
  203. config.intermediate_size),
  204. "blk.{bid}.ffn_down": ("model.layers.{bid}.mlp.down_proj",
  205. config.hidden_size),
  206. "blk.{bid}.ffn_gate": ("model.layers.{bid}.mlp.gate_proj",
  207. config.intermediate_size),
  208. "blk.{bid}.ffn_up.{xid}":
  209. ("model.layers.{bid}.block_sparse_moe.experts.{xid}.w3",
  210. config.intermediate_size),
  211. "blk.{bid}.ffn_down.{xid}":
  212. ("model.layers.{bid}.block_sparse_moe.experts.{xid}.w2",
  213. config.hidden_size),
  214. "blk.{bid}.ffn_gate.{xid}":
  215. ("model.layers.{bid}.block_sparse_moe.experts.{xid}.w1",
  216. config.intermediate_size),
  217. "blk.{bid}.ffn_gate_inp": ("model.layers.{bid}.block_sparse_moe.gate",
  218. config.num_local_experts if hasattr(
  219. config, "num_local_experts") else -1),
  220. }
  221. mapping = {}
  222. # This is how llama.cpp handles name mapping,
  223. # it's better to use regex match instead doe
  224. max_block_num = 200
  225. max_expert_num = 8
  226. for k, v in tensor_mapping.items():
  227. for i in range(max_block_num):
  228. for j in range(max_expert_num):
  229. fk = k.format(bid=i, xid=j)
  230. fv = v[0].format(bid=i, xid=j)
  231. if k not in mapping:
  232. mapping[fk] = (fv, v[1])
  233. state_dict = {}
  234. with get_loading_progress_bar() as progress:
  235. task = progress.add_task("[cyan]Converting GGUF tensors to PyTorch...",
  236. total=len(result.tensors))
  237. for ts in result.tensors:
  238. weight_type = torch.tensor(int(ts.tensor_type), dtype=torch.int)
  239. layer, suffix = ts.name.rsplit(".", 1)
  240. new_key, output_dim = mapping[layer]
  241. new_key += f".{suffix}"
  242. data = torch.tensor(ts.data)
  243. if output_dim != -1:
  244. data = data.view(output_dim, -1)
  245. if weight_type > 1:
  246. state_dict[new_key.replace("weight",
  247. "weight_type")] = weight_type
  248. state_dict[new_key] = data
  249. progress.update(task, advance=1)
  250. return state_dict
  251. def hf_model_weights_iterator(
  252. model_name_or_path: str,
  253. cache_dir: Optional[str] = None,
  254. load_format: str = "auto",
  255. revision: Optional[str] = None,
  256. config: Optional[PretrainedConfig] = None,
  257. fall_back_to_pt: Optional[bool] = True,
  258. ) -> Iterator[Tuple[str, torch.Tensor]]:
  259. if model_name_or_path.endswith("gguf"):
  260. for name, param in convert_gguf_to_state_dict(model_name_or_path,
  261. config).items():
  262. yield name, param
  263. return
  264. hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
  265. model_name_or_path,
  266. cache_dir=cache_dir,
  267. load_format=load_format,
  268. fall_back_to_pt=fall_back_to_pt,
  269. revision=revision)
  270. if load_format == "npcache":
  271. # Currently np_cache only support *.bin checkpoints
  272. assert use_safetensors is False
  273. # Convert the model weights from torch tensors to numpy arrays for
  274. # faster loading.
  275. np_folder = os.path.join(hf_folder, "np")
  276. os.makedirs(np_folder, exist_ok=True)
  277. weight_names_file = os.path.join(np_folder, "weight_names.json")
  278. # Use file lock to prevent multiple processes from
  279. # dumping the same model weights to numpy at the same time.
  280. with get_lock(model_name_or_path, cache_dir):
  281. if not os.path.exists(weight_names_file):
  282. weight_names = []
  283. for bin_file in hf_weights_files:
  284. state = torch.load(bin_file, map_location="cpu")
  285. for name, param in state.items():
  286. param_path = os.path.join(np_folder, name)
  287. with open(param_path, "wb") as f:
  288. np.save(f, param.cpu().detach().numpy())
  289. weight_names.append(name)
  290. with open(weight_names_file, "w") as f:
  291. json.dump(weight_names, f)
  292. with open(weight_names_file, "r") as f:
  293. weight_names = json.load(f)
  294. for name in weight_names:
  295. param_path = os.path.join(np_folder, name)
  296. with open(param_path, "rb") as f:
  297. param = np.load(f)
  298. yield name, torch.from_numpy(param)
  299. elif use_safetensors:
  300. for st_file in hf_weights_files:
  301. with safe_open(st_file, framework="pt") as f:
  302. for name in f.keys(): # noqa: SIM118
  303. param = f.get_tensor(name)
  304. yield name, param
  305. else:
  306. for bin_file in hf_weights_files:
  307. state = torch.load(bin_file, map_location="cpu")
  308. for name, param in state.items():
  309. yield name, param
  310. del state
  311. torch.cuda.empty_cache()
  312. def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
  313. """convert PySafeSlice object from safetensors to torch.Tensor
  314. PySafeSlice object supports indexing, which is done before loading the
  315. actual tensor and can reduce the amount of memory being read into the
  316. memory. However, it does not support more advanced functionalities
  317. like `.view()` or `.t()`. Therefore, if we need to modify the loaded
  318. tensor with these more complicated operators, we need to convert to
  319. tensor first.
  320. """
  321. if not isinstance(x, torch.Tensor):
  322. x = x[:]
  323. return x
  324. def default_weight_loader(param: torch.Tensor,
  325. loaded_weight: torch.Tensor) -> None:
  326. """Default weight loader."""
  327. if isinstance(param, torch.nn.parameter.UninitializedParameter):
  328. param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
  329. assert param.size() == loaded_weight.size()
  330. param.data.copy_(loaded_weight)
  331. def initialize_dummy_weights(
  332. model: torch.nn.Module,
  333. low: float = -1e-3,
  334. high: float = 1e-3,
  335. ) -> None:
  336. """Initialize model weights with random values.
  337. The model weights must be randomly initialized for accurate performance
  338. measurements. Additionally, the model weights should not cause NaNs in the
  339. forward pass. We empirically found that initializing the weights with
  340. values between -1e-3 and 1e-3 works well for most models.
  341. """
  342. for param in model.state_dict().values():
  343. if torch.is_floating_point(param):
  344. param.data.uniform_(low, high)