123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367 |
- import argparse
- import glob
- import json
- import os
- from typing import Any, Callable, Dict, List, Optional, Tuple
- import numpy as np
- import torch
- from safetensors.torch import safe_open
- from aphrodite.quantization.schema import QuantParamSchema
def _prepare_hf_weights(
    quantized_model_dir: str,
    load_format: str = "auto",
    fall_back_to_pt: bool = True,
) -> Tuple[List[str], bool]:
    """Locate the tensor files of a quantized model directory.

    Candidate glob patterns are tried in priority order and the first pattern
    that matches at least one file wins; files of later patterns are ignored.

    Args:
        quantized_model_dir: Directory containing the model's tensor files.
        load_format: One of "auto" (``*.safetensors`` then ``*.bin``),
            "safetensors", "pt" or "npz".
        fall_back_to_pt: If True, ``*.pt`` is tried after the patterns implied
            by ``load_format`` (at most once).

    Returns:
        ``(files, use_safetensors)`` where ``files`` is the sorted list of
        matching paths and ``use_safetensors`` tells whether they are
        safetensors files.

    Raises:
        FileNotFoundError: If ``quantized_model_dir`` is not a directory.
        ValueError: If ``load_format`` is not recognized.
        RuntimeError: If no weight files are found.
    """
    if not os.path.isdir(quantized_model_dir):
        raise FileNotFoundError(
            f"The quantized model directory `{quantized_model_dir}` "
            "does not exist.")

    use_safetensors = False
    if load_format == "auto":
        allow_patterns = ["*.safetensors", "*.bin"]
    elif load_format == "safetensors":
        use_safetensors = True
        allow_patterns = ["*.safetensors"]
    elif load_format == "pt":
        allow_patterns = ["*.pt"]
    elif load_format == "npz":
        allow_patterns = ["*.npz"]
    else:
        raise ValueError(f"Unknown load_format: {load_format}")
    if fall_back_to_pt and "*.pt" not in allow_patterns:
        allow_patterns.append("*.pt")

    # Escape the directory so characters such as "[" in its name are matched
    # literally instead of being interpreted as glob syntax.
    escaped_dir = glob.escape(quantized_model_dir)
    hf_weights_files: List[str] = []
    for pattern in allow_patterns:
        # Sorted so that downstream processing (and output) order does not
        # depend on the filesystem's directory listing order.
        hf_weights_files = sorted(
            glob.glob(os.path.join(escaped_dir, pattern)))
        if hf_weights_files:
            if pattern == "*.safetensors":
                use_safetensors = True
            break

    if not use_safetensors:
        # Training artifacts that share the .bin/.pt extensions but are not
        # model weights.
        blacklist = (
            "training_args.bin",
            "optimizer.bin",
            "optimizer.pt",
            "scheduler.pt",
            "scaler.pt",
        )
        hf_weights_files = [
            f for f in hf_weights_files if not f.endswith(blacklist)
        ]

    if not hf_weights_files:
        raise RuntimeError(
            f"Cannot find any model weights with `{quantized_model_dir}`")

    return hf_weights_files, use_safetensors
def _hf_tensorfile_iterator(filename: str, load_format: str,
                            use_safetensors: bool):
    """Yield ``(tensor_name, tensor)`` pairs read from one weights file.

    The reader is selected as follows:
      * ``load_format == "npz"``: the file is opened with ``np.load`` and each
        stored array is converted with ``torch.from_numpy``;
      * ``use_safetensors``: tensors are read one at a time via ``safe_open``;
      * otherwise: the whole file is loaded with ``torch.load`` onto the CPU,
        then released and ``torch.cuda.empty_cache()`` is called once all
        entries have been yielded.

    Args:
        filename: Path of the weights file.
        load_format: Load format string; only "npz" is treated specially.
        use_safetensors: Whether ``filename`` is a safetensors file.
    """
    if load_format == "npz":
        # npz archives and safetensors files are mutually exclusive.
        assert not use_safetensors
        with np.load(filename) as archive:
            for key in archive.files:
                yield key, torch.from_numpy(archive[key])
        return

    if use_safetensors:
        with safe_open(filename, framework="pt") as reader:
            for key in reader.keys():
                yield key, reader.get_tensor(key)
        return

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code; only run this tool on checkpoints from trusted sources.
    checkpoint = torch.load(filename, map_location="cpu")
    yield from checkpoint.items()
    del checkpoint
    torch.cuda.empty_cache()
def _kv_scales_extractor(
    hf_tensor_files: List[str],
    use_safetensors: bool,
    rank_keyword: str = "rank",
    expected_tp_size: Optional[int] = None,
) -> Optional[Dict[int, Dict[int, float]]]:
    """
    Given a list of files containing tensor data, attempt to extract KV cache
    scales from these files. Intended as a helper function taking in the output
    from _prepare_hf_weights.

    Args:
    rank_keyword        Matches the number immediately after this keyword in
                        the tensor file's base name (the directory part of the
                        path is ignored) to determine the TP rank corresponding
                        to said tensor file. If exactly one file is given and
                        its name lacks the keyword, it is assigned rank 0.
    expected_tp_size    If specified, the TP size of the tensor files is checked
                        against this and an error is raised if they don't match.

    The reader for each file is chosen from its extension: ``.npz`` files are
    read as NumPy archives whose tensor names use ``:`` as module delimiter;
    all other files are read with safetensors or torch.load (depending on
    ``use_safetensors``) and use ``.`` as module delimiter.

    Returns a dictionary mapping TP ranks to their relevant KV cache scales.
    The per-rank scales are themselves represented as a dictionary of layer
    indices to the respective per-layer scale. Returns None (after printing a
    warning) when no scaling factor is found in any file.

    Raises:
        RuntimeError: if a file's TP rank cannot be determined, two files share
            a TP rank, or a scaling factor is not a scalar.
        AssertionError: if rank_keyword contains a digit, a scaling factor's
            layer index is ambiguous or duplicated, or the TP world size is
            inconsistent.
    """
    import re

    for char in rank_keyword:
        assert not char.isdecimal(
        ), f"Rank keyword {rank_keyword} contains a numeric character!"
    # First occurrence of the keyword that is immediately followed by digits.
    rank_pattern = re.compile(re.escape(rank_keyword) + r"(\d+)")

    rank_scales_map: Dict[int, Dict[int, float]] = {}
    for tensor_file in hf_tensor_files:
        # Only the file name encodes the rank; searching the whole path would
        # misfire on directories such as "/home/frank/...".
        file_name = os.path.basename(tensor_file)
        try:
            if rank_keyword in file_name:
                match = rank_pattern.search(file_name)
                if match is None:
                    raise RuntimeError("Did not find rank # in filename.")
                rank = int(match.group(1))
            elif len(hf_tensor_files) == 1:
                rank = 0
            else:
                raise RuntimeError(
                    f"Filename does not contain '{rank_keyword}'.")
        except RuntimeError:
            print("Unable to determine TP rank "
                  f"corresponding to file '{tensor_file}'")
            raise

        if rank in rank_scales_map:
            raise RuntimeError(
                f"Tensor file '{tensor_file}' shares TP rank {rank} "
                "with another tensor file.")
        layer_scales_map: Dict[int, float] = {}
        rank_scales_map[rank] = layer_scales_map

        # Decide the format per file rather than reading a global CLI `args`
        # (which only exists when this module runs as a script). This also
        # handles `.pt` files picked up by the fall-back pattern when npz
        # files were requested.
        is_npz = tensor_file.lower().endswith(".npz")
        file_format = "npz" if is_npz else "auto"
        module_delimiter = ":" if is_npz else "."
        for name, param in _hf_tensorfile_iterator(tensor_file, file_format,
                                                   use_safetensors):
            if "kv_cache_scaling_factor" not in name:
                continue
            # The layer index is the only purely numeric component of the
            # tensor name.
            nums = [
                int(s) for s in name.split(module_delimiter) if s.isdecimal()
            ]
            assert len(nums) == 1, f"Could not determine layer idx for {name}"
            layer_idx = nums[0]
            assert layer_idx not in layer_scales_map, f"Duplicate scaling"\
                f" factor corresponding to layer {layer_idx}"
            try:
                layer_scales_map[layer_idx] = param.item()
            except RuntimeError:
                print("This utility supports only per-tensor scalar scales "
                      f"for now. The tensor\n {name} = {param} \nis an "
                      "invalid scale factor.")
                raise

    if all(len(scales) == 0 for scales in rank_scales_map.values()):
        print("WARNING: No KV cache scale factors found. No output saved.")
        return None

    empirical_tp_world_size = max(rank_scales_map.keys()) + 1
    if expected_tp_size is not None:
        assert expected_tp_size == empirical_tp_world_size, \
            f"User expected TP world size = {expected_tp_size} " \
            "from model but tool is expecting TP world size = " \
            f"{empirical_tp_world_size} from model instead."
    # Every rank in [0, world_size) must have contributed a tensor file.
    for i in range(empirical_tp_world_size):
        assert i in rank_scales_map, "Expected TP world size = "\
            f"{empirical_tp_world_size} but did not find KV " \
            f"cache scaling factors for TP rank {i}"

    print(f"Found TP world size = {empirical_tp_world_size} "
          "when extracting KV cache scales!")
    return rank_scales_map
def _metadata_extractor(
    quantized_model_dir: str,
    metadata_extract_fns: Dict[str, Callable[[Dict[str, Any]], Any]],
) -> Dict[str, Any]:
    """
    Given a directory containing quantized model files, this function
    aims to extract metadata from the JSON files within this directory.
    Each JSON file is expected to represent a dictionary in JSON
    format (referred to as a "JSON-dictionary"). Metadata extraction is
    defined by a dictionary called metadata_extract_fns, where each
    metadata field name is mapped to an extraction function.

    These extraction functions are designed to take a JSON-dictionary
    as their only argument and return the corresponding metadata.
    While extraction functions are permitted to raise exceptions, they
    should only raise a KeyError or ValueError if the metadata field
    cannot be extracted from the current JSON-dictionary, yet there's
    a possibility of finding it in another JSON-dictionary.

    Files that are not valid UTF-8 JSON, or whose top level is not an
    object, are skipped with a message. Files are visited in sorted order.

    The function returns a dictionary that maps metadata fields to
    their extracted data. The keys of this dictionary correspond exactly
    to those in metadata_extract_fns. If any fields fail to be extracted,
    their corresponding values are set to None, and a warning is printed.

    Raises:
        FileNotFoundError: if quantized_model_dir is not a directory.
        RuntimeError: if two files yield different values for the same field.
    """
    if not os.path.isdir(quantized_model_dir):
        raise FileNotFoundError(
            f"The quantized model directory `{quantized_model_dir}` "
            "does not exist.")

    # Escape the directory so glob metacharacters in its name are literal, and
    # sort so the visiting order does not depend on the filesystem.
    metadata_files = sorted(
        glob.glob(os.path.join(glob.escape(quantized_model_dir), "*.json")))
    result: Dict[str, Any] = {}
    for file in metadata_files:
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
        with open(file, encoding="utf-8") as f:
            try:
                metadata = json.load(f)
            except (json.JSONDecodeError, UnicodeDecodeError):
                print(f"Could not parse `{file}` as a valid metadata file,"
                      " skipping it.")
                continue
        if not isinstance(metadata, dict):
            print(f"The file `{file}` does not correspond to a "
                  "JSON-serialized dictionary, skipping it.")
            continue
        for metadata_name, extract_fn in metadata_extract_fns.items():
            try:
                metadata_info = extract_fn(metadata)
            except (KeyError, ValueError):
                # Field absent from this file; another file may provide it.
                continue
            if metadata_name not in result:
                result[metadata_name] = metadata_info
            elif metadata_info != result[metadata_name]:
                raise RuntimeError(
                    "Metadata mismatch! Originally found "
                    f"{metadata_name} = {result[metadata_name]} but "
                    f"now found {metadata_name} = {metadata_info} in "
                    f"`{file}`")

    for metadata_name in metadata_extract_fns:
        if metadata_name not in result:
            print("WARNING: Unable to find requested metadata field "
                  f"`{metadata_name}`, setting it to None.")
            result[metadata_name] = None
    return result
def main(args):
    """Extract KV cache scaling factors from a quantized model and save them.

    Reads metadata (model type, TP size, dtype) from the JSON files in
    ``args.quantized_model``, extracts the per-rank, per-layer KV cache
    scaling factors from the model's tensor files, and writes them as a
    ``QuantParamSchema`` JSON document named ``args.output_name`` inside
    ``args.output_dir`` (or the model directory when ``output_dir`` is None).
    Nothing is written when no scaling factors are found.

    Args:
        args: argparse namespace with ``quantized_model``, ``load_format``,
            ``output_dir``, ``output_name`` and ``tp_size`` attributes.
    """
    metadata_extract_fns = {
        "model_type": lambda json_dict: json_dict["layers"][0]["decoder_type"],
        "tp_size": lambda json_dict: int(json_dict["tensor_parallel"]),
        "model_dtype": lambda json_dict: json_dict["dtype"]
    }
    recovered_metadata = _metadata_extractor(args.quantized_model,
                                             metadata_extract_fns)
    metadata_tp_size = recovered_metadata["tp_size"]
    if args.tp_size is not None and metadata_tp_size is not None:
        assert args.tp_size == metadata_tp_size, \
            f"User expected TP world size = {args.tp_size} " \
            f"but found TP world size = {metadata_tp_size} from metadata!"
    # Explicit None check: a user-supplied value always takes precedence.
    expected_tp_size = (args.tp_size
                        if args.tp_size is not None else metadata_tp_size)

    rank_keyword = "rank"
    hf_tensor_files, use_safetensors = _prepare_hf_weights(
        args.quantized_model, args.load_format)
    rank_scales_map = _kv_scales_extractor(hf_tensor_files, use_safetensors,
                                           rank_keyword, expected_tp_size)
    if rank_scales_map is None:
        # The extractor has already warned that no output will be saved.
        return

    # Sort ranks, and layers within each rank, so the output is deterministic.
    rank_scales_map = {
        rank: dict(sorted(layer_scales.items()))
        for rank, layer_scales in sorted(rank_scales_map.items())
    }
    schema = QuantParamSchema(
        model_type=recovered_metadata["model_type"],
        kv_cache={
            "dtype": ("float8_e4m3fn" if len(rank_scales_map) > 0 else
                      recovered_metadata["model_dtype"]),
            "scaling_factor":
            rank_scales_map
        },
    )

    if args.output_dir is None:
        output_file = os.path.join(args.quantized_model, args.output_name)
    else:
        os.makedirs(args.output_dir, exist_ok=True)
        output_file = os.path.join(args.output_dir, args.output_name)
    with open(output_file, 'w', encoding="utf-8") as f:
        f.write(schema.model_dump_json(indent=4))
    print(f"Completed! KV cache scaling factors saved to {output_file}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="This simple utility extracts the "
        "KV cache scaling factors from a quantized HF model "
        "and saves them to a JSON file compatible with later "
        "use by Aphrodite (pass this file to the appropriate "
        "runtime typically using the argument "
        "--quantization-param-path <filename>). This is only used "
        "if the KV cache dtype is FP8 and on ROCm (AMD GPU).")
    parser.add_argument(
        "--quantized-model",
        help="Specify the directory containing a single quantized HF model. "
        "It is expected that the quantization format is FP8_E4M3, for use "
        "on ROCm (AMD GPU).",
        required=True)
    # "--load-format" matches the hyphenated style of the other flags; the
    # original "--load_format" spelling is kept as an alias. argparse derives
    # dest="load_format" from the first long option either way.
    parser.add_argument(
        "--load-format",
        "--load_format",
        help="Optionally specify the format of the model's tensor files "
        "containing the KV cache scaling factors.",
        choices=["auto", "safetensors", "npz", "pt"],
        default="auto")
    parser.add_argument(
        "--output-dir",
        help="Optionally specify the output directory. By default the "
        "KV cache scaling factors will be saved in the model directory, "
        "however you can override this behavior here.",
        default=None)
    parser.add_argument(
        "--output-name",
        help="Optionally specify the output filename.",
        default="kv_cache_scales.json")
    parser.add_argument(
        "--tp-size",
        help="Optionally specify the tensor-parallel (TP) size that the "
        "quantized model should correspond to. If specified, during KV "
        "cache scaling factor extraction the observed TP size will be "
        "checked against this and an error will be raised if there is "
        "a mismatch. If not specified, the quantized model's expected "
        "TP size is instead inferred from the largest TP rank observed. "
        "The expected TP size is cross-checked against the TP ranks "
        "observed in the quantized model and an error is raised if any "
        "discrepancies are found.",
        default=None,
        type=int)
    # Kept as a module-level name: other code in this file may read `args`.
    args = parser.parse_args()
    main(args)