calibrate.py

# Copyright (c) OpenMMLab. All rights reserved.
from pathlib import Path

import fire
import torch
from accelerate import (infer_auto_device_map, init_empty_weights,
                        load_checkpoint_in_model)
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from aphrodite.kv_quant.calib_dataloader import get_calib_loaders
from aphrodite.kv_quant.calibration import CalibrationContext
from aphrodite.kv_quant.utils import collect_target_modules

LAYER_TYPE_MAP = {
    'InternLMForCausalLM': 'InternLMDecoderLayer',
    'QWenLMHeadModel': 'QWenBlock',
    'BaiChuanForCausalLM': 'DecoderLayer',
    'LlamaForCausalLM': 'LlamaDecoderLayer',
}
NORM_TYPE_MAP = {
    'InternLMForCausalLM': 'InternLMRMSNorm',
    'QWenLMHeadModel': 'RMSNorm',
    'BaiChuanForCausalLM': 'RMSNorm',
    'LlamaForCausalLM': 'LlamaRMSNorm',
}
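
# Note (a sketch, not part of the original maps): supporting another
# decoder-only architecture may only require registering its decoder-layer
# and norm class names, e.g. for Mistral (class names assumed from
# `transformers`):
#   LAYER_TYPE_MAP['MistralForCausalLM'] = 'MistralDecoderLayer'
#   NORM_TYPE_MAP['MistralForCausalLM'] = 'MistralRMSNorm'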


def calibrate(model: str,
              calib_dataset: str = 'c4',
              calib_samples: int = 128,
              calib_seqlen: int = 2048,
              work_dir: str = './work_dir',
              device: str = 'cuda',
              dataset_path: str = None) -> None:
    """The main function for loading a model and performing calibration on a
    given dataset.

    Args:
        model (str): The name or path of the model to be loaded.
        calib_dataset (str, optional): The calibration dataset name.
            Defaults to 'c4'.
        calib_samples (int, optional): The number of samples for calibration.
            Defaults to 128.
        calib_seqlen (int, optional): The sequence length for calibration.
            Defaults to 2048.
        work_dir (str, optional): The working directory for outputs.
            Defaults to './work_dir'.
        device (str, optional): The device to be used for calculation.
            Defaults to 'cuda'.
        dataset_path (str, optional): Local path to the calibration dataset
            files. Defaults to None.
    """

    assert calib_dataset in ['c4', 'ptb', 'wikitext2', 'pileval'], \
        'Only `c4`, `ptb`, `wikitext2` and `pileval` are supported.'

    # Load tokenizer and configuration
    tokenizer = AutoTokenizer.from_pretrained(model,
                                              use_fast=False,
                                              trust_remote_code=True)
    hf_config = AutoConfig.from_pretrained(model, trust_remote_code=True)
    checkpoint = hf_config._name_or_path

    # Build the model skeleton on the meta device (no real weight allocation)
    with init_empty_weights():
        # Load model
        model = AutoModelForCausalLM.from_pretrained(model,
                                                     torch_dtype=torch.float16,
                                                     trust_remote_code=True)
        model.config.use_cache = False

    layer_type = LAYER_TYPE_MAP[type(model).__name__]
    norm_type = NORM_TYPE_MAP[type(model).__name__]
    decoder_layers = collect_target_modules(model, layer_type)

    # Infer device map: keep decoder layers and lm_head on CPU,
    # place everything else on GPU 0
    device_map = infer_auto_device_map(model,
                                       no_split_module_classes=[layer_type])
    for name in device_map:
        if name in decoder_layers or 'lm_head' in name:
            device_map[name] = 'cpu'
        else:
            device_map[name] = 0
    load_checkpoint_in_model(model, checkpoint, device_map)

    print('Loading calibration dataset ...')
    calib_loader, _ = get_calib_loaders(calib_dataset,
                                        tokenizer,
                                        nsamples=calib_samples,
                                        seqlen=calib_seqlen,
                                        path=dataset_path)

    # Initialize calibration context
    calib_ctx = CalibrationContext(model,
                                   tokenizer,
                                   layer_type=layer_type,
                                   norm_type=norm_type,
                                   device=device)

    # Concatenate the calibration samples into one batch and collect
    # statistics inside the calibration context
    with calib_ctx:
        all_data = torch.cat([
            data if isinstance(data, torch.Tensor) else data[0]
            for data in calib_loader
        ]).to(device)
        calib_ctx.calibrate(all_data)

    # Create the work directory if it does not exist, then export the
    # collected statistics
    work_dir = Path(work_dir)
    work_dir.mkdir(parents=True, exist_ok=True)
    calib_ctx.export(work_dir)


if __name__ == '__main__':
    fire.Fire(calibrate)
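
# Example command-line usage via the `fire` entry point above (a sketch: the
# model name is an assumption, and the remaining flags simply restate the
# defaults of `calibrate`):
#
#   python calibrate.py --model meta-llama/Llama-2-7b-hf \
#       --calib_dataset c4 --calib_samples 128 --calib_seqlen 2048 \
#       --work_dir ./work_dir --device cuda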