quantize.py

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # noqa: E501
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Adapted from examples/quantization/hf_ptq.py
"""
import argparse
import copy
import json
import random
import time

import ammo.torch.quantization as atq
import numpy as np
import torch
from ammo.torch.export import export_model_config
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
RAND_SEED = 1234
MAX_SEQ_LEN = 2048

# Disables every quantizer; used for the weight-only and full-precision
# formats where atq should leave weights and activations untouched.
EMPTY_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "enable": False,
        },
        "*input_quantizer": {
            "enable": False
        },
        "*lm_head*": {
            "enable": False
        },
        "*output_layer*": {
            "enable": False
        },
        "default": {
            "enable": False
        },
    },
    "algorithm": "max",
}

# 8-bit per-tensor quantizers (axis=None) on the outputs of the QKV/KV
# projection layers, covering the naming conventions of the supported models.
KV_CACHE_CFG = {
    "*.query_key_value.output_quantizer": {
        "num_bits": 8,
        "axis": None,
        "enable": True
    },
    "*.Wqkv.output_quantizer": {
        "num_bits": 8,
        "axis": None,
        "enable": True
    },
    "*.W_pack.output_quantizer": {
        "num_bits": 8,
        "axis": None,
        "enable": True
    },
    "*.c_attn.output_quantizer": {
        "num_bits": 8,
        "axis": None,
        "enable": True
    },
    "*.k_proj.output_quantizer": {
        "num_bits": 8,
        "axis": None,
        "enable": True
    },
    "*.v_proj.output_quantizer": {
        "num_bits": 8,
        "axis": None,
        "enable": True
    },
}

QUANT_CFG_CHOICES = {
    "int8_sq": atq.INT8_SMOOTHQUANT_CFG,
    "fp8": atq.FP8_DEFAULT_CFG,
    "int4_awq": atq.INT4_AWQ_CFG,
    "w4a8_awq": atq.W4A8_AWQ_BETA_CFG,
    "int8_wo": EMPTY_CFG,
    "int4_wo": EMPTY_CFG,
    "full_prec": EMPTY_CFG,
}

# Maps a substring of the HuggingFace model class name to the model type
# string expected by the export path.
MODEL_NAME_PATTERN_MAP = {
    "GPT2": "gpt2",
    "Xverse": "llama",
    "Llama": "llama",
    "Mistral": "llama",
    "GPTJ": "gptj",
    "FalconForCausalLM": "falcon",
    "RWForCausalLM": "falcon",
    "baichuan": "baichuan",
    "MPT": "mpt",
    "Bloom": "bloom",
    "ChatGLM": "chatglm",
    "QWen": "qwen",
}
def get_tokenizer(ckpt_path, max_seq_len=MAX_SEQ_LEN, model_type=None):
    """Loads the tokenizer from a checkpoint and ensures a pad token is set."""
    print(f"Initializing tokenizer from {ckpt_path}")
    tokenizer = AutoTokenizer.from_pretrained(
        ckpt_path,
        model_max_length=max_seq_len,
        padding_side="left",
        trust_remote_code=True,
    )
    if model_type and model_type == "qwen":
        # Qwen uses token id 151643 as both the pad and eos token.
        tokenizer.pad_token = tokenizer.convert_ids_to_tokens(151643)
        tokenizer.eos_token = tokenizer.convert_ids_to_tokens(151643)

    # can't set attribute 'pad_token' for "<unk>"
    if tokenizer.pad_token != "<unk>":
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    assert tokenizer.pad_token is not None, f"Pad token for {model_type} cannot be set!"

    return tokenizer
def get_model(ckpt_path, dtype="fp16", device="cuda"):
    """Loads the HuggingFace causal LM from ckpt_path and puts it in eval mode."""
    print(f"Initializing model from {ckpt_path}")
    if dtype == "bf16" or dtype == "bfloat16":
        dtype = torch.bfloat16
    elif dtype == "fp16" or dtype == "float16":
        dtype = torch.float16
    elif dtype == "fp32" or dtype == "float32":
        dtype = torch.float32
    else:
        raise NotImplementedError(f"Unknown dtype {dtype}")

    # model_kwargs = {"torch_dtype": dtype}
    model_kwargs = {"torch_dtype": "auto"}

    model = AutoModelForCausalLM.from_pretrained(ckpt_path,
                                                 device_map="auto",
                                                 **model_kwargs,
                                                 trust_remote_code=True)
    model.eval()

    model_dtype = next(model.parameters()).dtype
    if dtype != model_dtype:
        print("[TensorRT-LLM][WARNING] The manually set model data type is "
              f"{dtype}, but the data type of the HuggingFace model is "
              f"{model_dtype}.")

    return model
def get_model_type(model):
    """Maps the model class name to a known model type, or None if unknown."""
    for k, v in MODEL_NAME_PATTERN_MAP.items():
        if k.lower() in type(model).__name__.lower():
            return v
    return None
def get_calib_dataloader(data="cnn_dailymail",
                         tokenizer=None,
                         batch_size=1,
                         calib_size=512,
                         block_size=512,
                         device=None):
    """Builds a DataLoader of tokenized calibration samples."""
    print("Loading calibration dataset")
    if data == "pileval":
        dataset = load_dataset(
            "json",
            data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst",
            split="train")
        dataset = dataset["text"][:calib_size]
    elif data == "cnn_dailymail":
        dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train")
        dataset = dataset["article"][:calib_size]
    else:
        raise NotImplementedError

    batch_encoded = tokenizer.batch_encode_plus(dataset,
                                                return_tensors="pt",
                                                padding="max_length",
                                                truncation=True,
                                                max_length=block_size)
    if device:
        batch_encoded = batch_encoded.to(device)
    batch_encoded = batch_encoded["input_ids"]

    calib_dataloader = DataLoader(batch_encoded,
                                  batch_size=batch_size,
                                  shuffle=False)

    return calib_dataloader
def quantize_model(model, quant_cfg, calib_dataloader=None):

    def calibrate_loop():
        """Adjusts weights and scaling factors based on selected algorithms."""
        if calib_dataloader is None:
            return
        for idx, data in enumerate(calib_dataloader):
            print(f"Calibrating batch {idx}")
            model(data)

    print("Starting quantization...")
    start_time = time.time()
    atq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
    end_time = time.time()
    print("Quantization done. Total time used: {:.2f} s.".format(end_time -
                                                                 start_time))

    return model
def main(args):
    if not torch.cuda.is_available():
        raise EnvironmentError("GPU is required for inference.")

    random.seed(RAND_SEED)
    np.random.seed(RAND_SEED)

    model = get_model(args.model_dir, args.dtype, args.device)
    model_type = get_model_type(model)
    tokenizer = get_tokenizer(args.model_dir, model_type=model_type)

    if args.qformat in ["full_prec", "int8_wo", "int4_wo"
                        ] and args.kv_cache_dtype is None:
        print(f"No quantization applied, export {args.dtype} model")
    else:
        if "awq" in args.qformat:
            if args.calib_size > 32:
                print("AWQ calibration could take longer with calib_size = "
                      f"{args.calib_size}. Using calib_size=32 instead.")
                args.calib_size = 32
            print("\nAWQ calibration could take longer than other calibration "
                  "methods. Please increase the batch size to speed up the "
                  "calibration process. Batch size can be set by adding the "
                  "argument --batch-size <batch_size> to the command line.\n")

        calib_dataloader = get_calib_dataloader(
            tokenizer=tokenizer,
            batch_size=args.batch_size,
            calib_size=args.calib_size,
            device=args.device,
        )

        if args.qformat in QUANT_CFG_CHOICES:
            quant_cfg = QUANT_CFG_CHOICES[args.qformat]
        else:
            raise ValueError(
                f"Unsupported quantization format: {args.qformat}")

        if "awq" in args.qformat:
            quant_cfg = copy.deepcopy(QUANT_CFG_CHOICES[args.qformat])
            weight_quantizer = quant_cfg["quant_cfg"][
                "*weight_quantizer"]  # type: ignore
            if isinstance(weight_quantizer, list):
                weight_quantizer = weight_quantizer[0]
            weight_quantizer["block_sizes"][-1] = args.awq_block_size

        if args.kv_cache_dtype is not None:
            if args.kv_cache_dtype == "fp8":
                for value in KV_CACHE_CFG.values():
                    value.update({"num_bits": (4, 3)})  # type: ignore
            quant_cfg["quant_cfg"].update(KV_CACHE_CFG)  # type: ignore

        print(quant_cfg)

        model = quantize_model(model, quant_cfg, calib_dataloader)

    with torch.inference_mode():
        if model_type is None:
            print(f"Unknown model type {type(model).__name__}. Continue "
                  "exporting...")
            model_type = f"unknown:{type(model).__name__}"

        export_path = args.output_dir
        start_time = time.time()

        if args.qformat == "int4_awq" and model_type == "qwen":
            torch.save(model.state_dict(), export_path)
        else:
            export_npz = (model_type not in [
                'gptj', 'falcon', 'chatglm', 'mpt', 'llama', 'baichuan'
            ])

            # export safetensors
            export_model_config(
                model,
                model_type,
                getattr(torch, args.dtype),
                export_dir=export_path,
                inference_tensor_parallel=args.tp_size,
                inference_pipeline_parallel=args.pp_size,
                # export_tensorrt_llm_config=(not export_npz),
                export_tensorrt_llm_config=False,
                export_npz=export_npz)

            # Workaround for weight-only quantization: patch the quant_algo
            # field in the exported config.json by hand.
            if args.qformat in ["int8_wo", "int4_wo", "full_prec"]:
                with open(f"{export_path}/config.json", "r") as f:
                    tensorrt_llm_config = json.load(f)
                if args.qformat == "int8_wo":
                    tensorrt_llm_config["quantization"]["quant_algo"] = "W8A16"
                elif args.qformat == "int4_wo":
                    tensorrt_llm_config["quantization"]["quant_algo"] = "W4A16"
                else:
                    tensorrt_llm_config["quantization"]["quant_algo"] = None
                with open(f"{export_path}/config.json", "w") as f:
                    json.dump(tensorrt_llm_config, f, indent=4)

        end_time = time.time()
        print("Quantized model exported to {} \nTotal time used {:.2f} s.".
              format(export_path, end_time - start_time))
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--model-dir",
                        help="Specify where the HuggingFace model is",
                        required=True)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--dtype",
                        help="Model data type (e.g. float16, bfloat16, float32).",
                        default="float16")
    parser.add_argument(
        "--qformat",
        help="Quantization format.",
        default="full_prec",
        choices=[
            "fp8", "int8_sq", "int4_awq", "w4a8_awq", "int8_wo", "int4_wo",
            "full_prec"
        ],
    )
    parser.add_argument("--batch-size",
                        help="Batch size for calibration.",
                        type=int,
                        default=1)
    parser.add_argument("--calib-size",
                        help="Number of samples for calibration.",
                        type=int,
                        default=512)
    parser.add_argument("--output-dir", default="exported_model")
    parser.add_argument("--tp-size", type=int, default=1)
    parser.add_argument("--pp-size", type=int, default=1)
    parser.add_argument("--awq-block-size", type=int, default=128)
    parser.add_argument("--kv-cache-dtype",
                        help="KV Cache dtype.",
                        default=None,
                        choices=["int8", "fp8", None])
    args = parser.parse_args()
    main(args)
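
# Example invocation (illustrative only): the model path and output directory
# below are placeholders; the flags are exactly the ones defined above.
#
#   python quantize.py --model-dir ./llama-2-7b-hf \
#       --dtype float16 \
#       --qformat fp8 \
#       --kv-cache-dtype fp8 \
#       --calib-size 512 \
#       --output-dir ./llama-2-7b-fp8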