import argparse
import glob
import json
import os
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
import torch
from safetensors.torch import safe_open

from aphrodite.quantization.schema import QuantParamSchema
# Adapted from aphrodite/modeling/model_loader/weight_utils.py
# The main differences are that we add the NPZ format and simplify
# its functionality drastically for our purposes (e.g. we assume that
# the quantized model exists locally and there is no need to download it)
def _prepare_hf_weights(
    quantized_model_dir: str,
    load_format: str = "auto",
    fall_back_to_pt: bool = True,
) -> Tuple[List[str], bool]:
    if not os.path.isdir(quantized_model_dir):
        raise FileNotFoundError(
            f"The quantized model directory `{quantized_model_dir}` "
            "does not exist.")
    use_safetensors = False
    # Some quantized models use .pt files for storing the weights.
    if load_format == "auto":
        allow_patterns = ["*.safetensors", "*.bin"]
    elif load_format == "safetensors":
        use_safetensors = True
        allow_patterns = ["*.safetensors"]
    elif load_format == "pt":
        allow_patterns = ["*.pt"]
    elif load_format == "npz":
        allow_patterns = ["*.npz"]
    else:
        raise ValueError(f"Unknown load_format: {load_format}")
    if fall_back_to_pt:
        allow_patterns += ["*.pt"]

    hf_weights_files: List[str] = []
    for pattern in allow_patterns:
        hf_weights_files += glob.glob(
            os.path.join(quantized_model_dir, pattern))
        if len(hf_weights_files) > 0:
            if pattern == "*.safetensors":
                use_safetensors = True
            break

    if not use_safetensors:
        # Exclude files that are not needed for inference.
        # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
        blacklist = [
            "training_args.bin",
            "optimizer.bin",
            "optimizer.pt",
            "scheduler.pt",
            "scaler.pt",
        ]
        hf_weights_files = [
            f for f in hf_weights_files
            if not any(f.endswith(x) for x in blacklist)
        ]

    if len(hf_weights_files) == 0:
        raise RuntimeError(
            f"Cannot find any model weights with `{quantized_model_dir}`")

    return hf_weights_files, use_safetensors
# Adapted from aphrodite/modeling/model_loader/weight_utils.py
def _hf_tensorfile_iterator(filename: str, load_format: str,
                            use_safetensors: bool):
    if load_format == "npz":
        assert not use_safetensors
        with np.load(filename) as data:
            for name in data.files:
                param = torch.from_numpy(data[name])
                yield name, param
    elif use_safetensors:
        with safe_open(filename, framework="pt") as f:
            for name in f.keys():  # NOQA: SIM118
                param = f.get_tensor(name)
                yield name, param
    else:
        state = torch.load(filename, map_location="cpu")
        for name, param in state.items():
            yield name, param
        del state
        torch.cuda.empty_cache()
def _kv_scales_extractor(
        hf_tensor_files: List[str],
        use_safetensors: bool,
        rank_keyword: str = "rank",
        expected_tp_size: Optional[int] = None
) -> Optional[Dict[int, Dict[int, float]]]:
    """
    Given a list of files containing tensor data, attempt to extract KV cache
    scales from these files. Intended as a helper function taking in the
    output from _prepare_hf_weights.

    Args:
        rank_keyword      Matches the number immediately after this keyword
                          in the tensor filename to determine the TP rank
                          corresponding to said tensor file.
        expected_tp_size  If specified, the TP size of the tensor files is
                          checked against this and an error is raised if
                          they don't match.

    Returns a dictionary mapping TP ranks to their relevant KV cache scales
    (or None if no scales are found). The per-rank scales are themselves
    represented as a dictionary of layer indices to the respective per-layer
    scale.
    """
    for char in rank_keyword:
        assert not char.isdecimal(
        ), f"Rank keyword {rank_keyword} contains a numeric character!"
    rank_scales_map: Dict[int, Dict[int, float]] = {}
    for tensor_file in hf_tensor_files:
        try:
            rank_idx = tensor_file.find(rank_keyword)
            if rank_idx != -1:
                start_idx = rank_idx + len(rank_keyword)
                stop_idx = start_idx
                while stop_idx < len(
                        tensor_file) and tensor_file[stop_idx].isdecimal():
                    stop_idx += 1
                if stop_idx == start_idx:
                    raise RuntimeError("Did not find rank # in filename.")
                rank = int(tensor_file[start_idx:stop_idx])
            elif len(hf_tensor_files) == 1:
                # Since there is only one tensor file, we can assume
                # that it's intended for TP rank 0
                rank = 0
            else:
                raise RuntimeError(
                    f"Filename does not contain '{rank_keyword}'.")
        except RuntimeError:
            print("Unable to determine TP rank "
                  f"corresponding to file '{tensor_file}'")
            raise

        if rank not in rank_scales_map:
            layer_scales_map: Dict[int, float] = {}
            rank_scales_map[rank] = layer_scales_map
        else:
            raise RuntimeError(
                f"Tensor file '{tensor_file}' shares TP rank {rank} "
                "with another tensor file.")

        # Note: `args` is the module-level argparse namespace created in the
        # __main__ block below.
        module_delimiter = ":" if args.load_format == "npz" else "."
        for name, param in _hf_tensorfile_iterator(tensor_file,
                                                   args.load_format,
                                                   use_safetensors):
            if "kv_cache_scaling_factor" in name:
                nums = [
                    int(s) for s in name.split(module_delimiter)
                    if s.isdecimal()
                ]
                assert len(
                    nums) == 1, f"Could not determine layer idx for {name}"
                layer_idx = nums[0]
                assert layer_idx not in layer_scales_map, (
                    "Duplicate scaling factor corresponding to layer "
                    f"{layer_idx}")
                try:
                    layer_scales_map[layer_idx] = param.item()
                except RuntimeError:
                    print(
                        "This utility supports only per-tensor scalar scales "
                        f"for now. The tensor\n {name} = {param} \nis an "
                        "invalid scale factor.")
                    raise

    if all(
            len(layer_scales_map) == 0
            for layer_scales_map in rank_scales_map.values()):
        # Note: this is true even if the rank_scales_map is empty
        print("WARNING: No KV cache scale factors found. No output saved.")
        return None
    empirical_tp_world_size = max(rank_scales_map.keys()) + 1
    if expected_tp_size is not None:
        assert expected_tp_size == empirical_tp_world_size, (
            f"User expected TP world size = {expected_tp_size}, but the tool "
            f"found TP world size = {empirical_tp_world_size} in the model "
            "instead.")
    for i in range(empirical_tp_world_size):
        assert i in rank_scales_map, (
            f"Expected TP world size = {empirical_tp_world_size} but did not "
            f"find KV cache scaling factors for TP rank {i}")
    print(f"Found TP world size = {empirical_tp_world_size} "
          "when extracting KV cache scales!")
    return rank_scales_map
def _metadata_extractor(
    quantized_model_dir: str,
    metadata_extract_fns: Dict[str, Callable[[Dict[str, Any]], Any]],
) -> Dict[str, Any]:
    """
    Given a directory containing quantized model files, this function
    aims to extract metadata from the JSON files within this directory.
    Each JSON file is expected to represent a dictionary in JSON
    format (referred to as a "JSON-dictionary"). Metadata extraction is
    defined by a dictionary called metadata_extract_fns, where each
    metadata field name is mapped to an extraction function.

    These extraction functions are designed to take a JSON-dictionary
    as their only argument and return the corresponding metadata.
    While extraction functions are permitted to raise exceptions, they
    should only raise a KeyError or ValueError if the metadata field
    cannot be extracted from the current JSON-dictionary, yet there's
    a possibility of finding it in another JSON-dictionary.

    The function returns a dictionary that maps metadata fields to
    their extracted data. The keys of this dictionary correspond exactly
    to those in metadata_extract_fns. If any fields fail to be extracted,
    their corresponding values are set to None, and a warning is printed.
    """
    if not os.path.isdir(quantized_model_dir):
        raise FileNotFoundError(
            f"The quantized model directory `{quantized_model_dir}` "
            "does not exist.")
    metadata_files = glob.glob(os.path.join(quantized_model_dir, "*.json"))

    result: Dict[str, Any] = {}
    for file in metadata_files:
        with open(file) as f:
            try:
                metadata = json.load(f)
            except json.JSONDecodeError:
                print(f"Could not parse `{file}` as a valid metadata file,"
                      " skipping it.")
                continue
            if not isinstance(metadata, dict):
                print(f"The file `{file}` does not correspond to a "
                      "JSON-serialized dictionary, skipping it.")
                continue
            for metadata_name, extract_fn in metadata_extract_fns.items():
                try:
                    metadata_info = extract_fn(metadata)
                    if metadata_name not in result:
                        result[metadata_name] = metadata_info
                    elif metadata_info != result[metadata_name]:
                        raise RuntimeError(
                            "Metadata mismatch! Originally found "
                            f"{metadata_name} = {result[metadata_name]} but "
                            f"now found {metadata_name} = {metadata_info} in "
                            f"`{file}`")
                except KeyError:
                    # It is possible that a given file does not contain some
                    # of our selected metadata as it could be located in some
                    # other metadata file.
                    # 'EFINAE': extract_fn failure is not an error.
                    pass
                except ValueError:
                    # See above.
                    pass

    # Warn if we cannot find any of the requested metadata
    for metadata_name in metadata_extract_fns:
        if metadata_name not in result:
            print("WARNING: Unable to find requested metadata field "
                  f"`{metadata_name}`, setting it to None.")
            result[metadata_name] = None
    return result
def main(args):
    metadata_extract_fns = {
        "model_type": lambda json_dict: json_dict["layers"][0]["decoder_type"],
        "tp_size": lambda json_dict: int(json_dict["tensor_parallel"]),
        "model_dtype": lambda json_dict: json_dict["dtype"]
    }
    recovered_metadata = _metadata_extractor(args.quantized_model,
                                             metadata_extract_fns)
    if args.tp_size is not None:
        metadata_tp_size = recovered_metadata["tp_size"]
        if metadata_tp_size is not None:
            assert args.tp_size == metadata_tp_size, (
                f"User expected TP world size = {args.tp_size} "
                f"but found TP world size = {metadata_tp_size} from metadata!")
    expected_tp_size = args.tp_size or recovered_metadata["tp_size"]
    rank_keyword = "rank"
    hf_tensor_files, use_safetensors = _prepare_hf_weights(
        args.quantized_model, args.load_format)

    rank_scales_map = _kv_scales_extractor(hf_tensor_files, use_safetensors,
                                           rank_keyword, expected_tp_size)
    if rank_scales_map is None:
        # _kv_scales_extractor already printed a warning; there is nothing
        # to save, so bail out instead of failing in the postprocessing below.
        return
    # Postprocess: formatting to the current schema. Consider pulling it
    # out into a dedicated function should it ever become more complicated.
    rank_scales_map = {
        rank: {k: scale[k]
               for k in sorted(scale.keys())}
        for rank, scale in rank_scales_map.items()
    }
    # TODO: Expand this with activation and weights scaling factors when
    # they are used in the future
    schema = QuantParamSchema(
        model_type=recovered_metadata["model_type"],
        kv_cache={
            "dtype": ("float8_e4m3fn" if len(rank_scales_map) > 0 else
                      recovered_metadata["model_dtype"]),
            "scaling_factor": rank_scales_map
        },
    )
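    # For reference, a sketch of the serialized output (the field layout is
    # defined by QuantParamSchema; the concrete values below are illustrative
    # placeholders only):
    #   {
    #     "model_type": "llama",
    #     "kv_cache": {
    #       "dtype": "float8_e4m3fn",
    #       "scaling_factor": {"0": {"0": 0.0213, "1": 0.0198}}
    #     }
    #   }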
    if args.output_dir is None:
        output_file = os.path.join(args.quantized_model, args.output_name)
    else:
        if not os.path.isdir(args.output_dir):
            os.makedirs(args.output_dir, exist_ok=True)
        output_file = os.path.join(args.output_dir, args.output_name)

    with open(output_file, 'w') as f:
        f.write(schema.model_dump_json(indent=4))
    print(f"Completed! KV cache scaling factors saved to {output_file}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="This simple utility extracts the "
        "KV cache scaling factors from a quantized HF model "
        "and saves them to a JSON file compatible with later "
        "use by Aphrodite (pass this file to the appropriate "
        "runtime, typically using the argument "
        "--quantization-param-path <filename>). This is only used "
        "if the KV cache dtype is FP8 and the target platform is "
        "ROCm (AMD GPU).")
    parser.add_argument(
        "--quantized-model",
        help="Specify the directory containing a single quantized HF model. "
        "It is expected that the quantization format is FP8_E4M3, for use "
        "on ROCm (AMD GPU).",
        required=True)
    parser.add_argument(
        "--load_format",
        help="Optionally specify the format of the model's tensor files "
        "containing the KV cache scaling factors.",
        choices=["auto", "safetensors", "npz", "pt"],
        default="auto")
    parser.add_argument(
        "--output-dir",
        help="Optionally specify the output directory. By default the "
        "KV cache scaling factors will be saved in the model directory, "
        "however you can override this behavior here.",
        default=None)
    parser.add_argument(
        "--output-name",
        help="Optionally specify the output filename.",
        # TODO: Change this once additional scaling factors are enabled
        default="kv_cache_scales.json")
    parser.add_argument(
        "--tp-size",
        help="Optionally specify the tensor-parallel (TP) size that the "
        "quantized model should correspond to. If specified, during KV "
        "cache scaling factor extraction the observed TP size will be "
        "checked against this and an error will be raised if there is "
        "a mismatch. If not specified, the quantized model's expected "
        "TP size is instead inferred from the largest TP rank observed. "
        "The expected TP size is cross-checked against the TP ranks "
        "observed in the quantized model and an error is raised if any "
        "discrepancies are found.",
        default=None,
        type=int)
    args = parser.parse_args()
    main(args)
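
# Example invocation (a sketch only; the model path, load format, TP size and
# output directory below are hypothetical and must be adapted to your setup):
#   python extract_scales.py --quantized-model /path/to/quantized_model \
#       --load_format safetensors --tp-size 2 --output-dir /tmp/scales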