calibration.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. # Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
from pathlib import Path
from typing import Union

import torch
from torch import nn

import transformers
from transformers import PreTrainedTokenizer
from pkg_resources import parse_version

from aphrodite.kv_quant.observer import ActivationObserver, KVCacheObserver
from aphrodite.kv_quant.utils import (bimap_name_mod, collect_target_modules,
                                      concat_decoder_layer_outputs,
                                      split_decoder_layer_inputs)
class CalibrationContext():
    """Calibration context manager for model quantization.

    Parameters:
    - model: The target model to be calibrated and quantized
    - tokenizer: The tokenizer used in the model training
    - layer_type: Layer type to be targeted for calibration
    - norm_type: Normalization type used for calibration
    - device: Device on which model is to be calibrated ('cpu' or 'cuda')
    """

    # Group names under which observers are registered globally (see the
    # `global_available` / `find` / `find_group` calls on the observers).
    inp_obs_group = 'inputs'
    out_obs_group = 'outputs'
    key_obs_group = 'keys'
    value_obs_group = 'values'
    def __init__(self,
                 model: nn.Module,
                 tokenizer: PreTrainedTokenizer,
                 layer_type: Union[str, type],
                 norm_type: Union[str, type],
                 device: str = 'cuda') -> None:
        """Initiate calibration context.

        Args:
            model (nn.Module): Model to be calibrated.
            tokenizer (PreTrainedTokenizer): Tokenizer of the given model.
            layer_type (Union[str, type]): Type of the layers to be observed.
            norm_type (Union[str, type]): Norm type used in the model.
            device (str, optional): Device where the model should run.
                Defaults to 'cuda'.
        """
        self.layer_type = layer_type
        self.norm_type = norm_type

        num_kv_heads, num_attn_heads = self._guess_num_heads(model)
        self.num_kv_heads = num_kv_heads
        self.head_dim = model.config.hidden_size // num_attn_heads
        self.model = model
        # The output head is not needed for calibration; dropping it frees
        # memory. NOTE(review): assumes the model exposes `lm_head` — confirm
        # for architectures that name the output projection differently.
        del self.model.lm_head
        self.tokenizer = tokenizer

        # Collect modules to observe: decoder layers, every nn.Linear inside
        # them (prefixed with the layer name), and the norm modules.
        self.name2layer = collect_target_modules(self.model, layer_type)
        self.name2fc = {}
        for l_name, layer in self.name2layer.items():
            name2fc = collect_target_modules(layer, nn.Linear, prefix=l_name)
            self.name2fc.update(name2fc)
        self.name2norm = collect_target_modules(self.model, norm_type)

        # Bidirectional name <-> module maps used by the forward hooks.
        maps = bimap_name_mod([self.name2layer, self.name2fc, self.name2norm])
        self.name2mod, self.mod2name = maps

        # Initialize observers: inputs of linears, outputs of norms and
        # linears, and the KV caches of the decoder layers.
        self._init_input_observers(self.name2fc)
        self._init_output_observers(self.name2norm)
        self._init_output_observers(self.name2fc)
        self._init_kv_observers(self.name2layer)

        self.device = device
  64. def _guess_num_heads(self, model):
  65. if hasattr(model.config, 'num_key_value_heads'):
  66. num_kv_heads = model.config.num_key_value_heads
  67. else:
  68. num_kv_heads = model.config.num_attention_heads
  69. num_attn_heads = model.config.num_attention_heads
  70. return num_kv_heads, num_attn_heads
  71. def _init_input_observers(self, name2mod):
  72. """Initialize input observers for given modules."""
  73. for name, mod in name2mod.items():
  74. obs = ActivationObserver(mod.weight.size(-1))
  75. obs.global_available(name, group=self.inp_obs_group)
  76. def _init_output_observers(self, name2mod):
  77. """Initialize output observers for given modules."""
  78. for name, mod in name2mod.items():
  79. obs = ActivationObserver(mod.weight.size(0))
  80. obs.global_available(name, group=self.out_obs_group)
  81. def _init_kv_observers(self, name2mod):
  82. """Initialize KV observers for given modules."""
  83. for name in name2mod:
  84. k_obs = KVCacheObserver(self.num_kv_heads, self.head_dim)
  85. v_obs = KVCacheObserver(self.num_kv_heads, self.head_dim)
  86. k_obs.global_available(name, group=self.key_obs_group)
  87. v_obs.global_available(name, group=self.value_obs_group)
  88. def _insert_input_observers(self):
  89. """Insert input observers into the target modules.
  90. This function registers a forward pre-hook on each target module to
  91. observe the inputs.
  92. """
  93. def _input_hook(mod: nn.Module, inp: torch.Tensor):
  94. m_name = self.mod2name[mod]
  95. obs = ActivationObserver.find(m_name, group=self.inp_obs_group)
  96. obs.observe(inp[0])
  97. group = ActivationObserver.find_group(self.inp_obs_group)
  98. for name in group:
  99. mod = self.name2mod[name]
  100. hook_fn = mod.register_forward_pre_hook(_input_hook)
  101. self._hooks.append(hook_fn)
  102. def _insert_output_observers(self):
  103. """Insert output observers into the target modules.
  104. This function registers a forward hook on each target module to observe
  105. the outputs.
  106. """
  107. def _output_hook(mod: nn.Module, inp: torch.Tensor, out: torch.Tensor):
  108. m_name = self.mod2name[mod]
  109. obs = ActivationObserver.find(m_name, group=self.out_obs_group)
  110. obs.observe(out)
  111. group = ActivationObserver.find_group(self.out_obs_group)
  112. for name in group:
  113. mod = self.name2mod[name]
  114. hook_fn = mod.register_forward_hook(_output_hook)
  115. self._hooks.append(hook_fn)
    def _wrap_decoder_layers(self):
        """Method to wrap the decoder layers' forward functions for observing
        their key/value cache during batched forward passes."""

        def _forward(mod, *args, **kwargs):
            # Keep only the current layer on the target device; it is moved
            # back to CPU at the end so GPU memory stays bounded.
            mod.to(self.device)
            # Split the batched call into per-sample calls so each sample's
            # KV cache can be observed individually.
            batch_args, batch_kwargs = split_decoder_layer_inputs(
                *args, **kwargs)
            batch_outputs = []
            samples = len(batch_args)

            m_name = self.mod2name[mod]
            k_obs = KVCacheObserver.find(m_name, group=self.key_obs_group)
            v_obs = KVCacheObserver.find(m_name, group=self.value_obs_group)
            for i in range(len(batch_args)):
                if k_obs and v_obs:
                    batch_kwargs[i]['use_cache'] = True
                    version = parse_version(transformers.__version__)
                    use_new_cache = type(mod).__name__ == 'LlamaDecoderLayer'
                    # Newer transformers return a DynamicCache for Llama
                    # layers instead of a (key, value) tuple in the output.
                    # NOTE(review): the strict `>` excludes 4.36.0 itself —
                    # confirm which release introduced the new cache API.
                    if version > parse_version('4.36.0') and use_new_cache:
                        from transformers.cache_utils import DynamicCache
                        batch_kwargs[i]['past_key_value'] = DynamicCache()

                        # Temporarily force layer_idx to 0 so this layer's
                        # KV lands at index 0 of the fresh DynamicCache.
                        ori_idx = mod.self_attn.layer_idx
                        mod.self_attn.layer_idx = 0

                        out = self._ori_forwards[mod](*batch_args[i],
                                                      **batch_kwargs[i])
                        mod.self_attn.layer_idx = ori_idx

                        # The cache is the last element of the output tuple;
                        # pop the just-written key/value tensors out of it.
                        out = list(out)
                        cache = out.pop(-1)

                        key = cache.key_cache.pop(-1)
                        value = cache.value_cache.pop(-1)

                        k_obs.observe(key)
                        v_obs.observe(value)

                    else:
                        # Legacy API: the last output element is the
                        # (key, value) pair itself.
                        out = self._ori_forwards[mod](*batch_args[i],
                                                      **batch_kwargs[i])
                        out = list(out)
                        key, value = out.pop(-1)

                        k_obs.observe(key)
                        v_obs.observe(value)

                    # Drop references before emptying the CUDA cache so the
                    # KV tensors can actually be freed.
                    del key, value
                    torch.cuda.empty_cache()
                    batch_outputs.append(tuple(out))
                else:
                    # No observers registered for this layer: plain passthrough.
                    batch_outputs.append(self._ori_forwards[mod](
                        *batch_args[i], **batch_kwargs[i]))
            # Re-assemble per-sample outputs into one batched output.
            outputs = concat_decoder_layer_outputs(batch_outputs)

            del batch_outputs, batch_args, batch_kwargs, args
            mod.to('cpu')
            torch.cuda.empty_cache()
            max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
            print(f'{m_name}, samples: {samples}, '
                  f'max gpu memory: {max_memory:.2f} GB')
            return outputs

        # Swap every decoder layer's forward for the observing wrapper,
        # remembering the originals for restoration in __exit__.
        for layer in self.name2layer.values():
            self._ori_forwards[layer] = layer.forward
            layer.forward = partial(_forward, layer)
  171. def collect_inputs_stats(self):
  172. """Collect statistics (min, max, absmax values) of the observed inputs.
  173. Returns a dictionary with these collected stats.
  174. """
  175. inputs_stats = {
  176. 'max': {},
  177. 'min': {},
  178. 'mean': {},
  179. 'absmax': {},
  180. 'absmean': {}
  181. }
  182. obs_group = ActivationObserver.find_group(self.inp_obs_group)
  183. for name, obs in obs_group.items():
  184. inputs_stats['max'][name] = obs.max_val
  185. inputs_stats['min'][name] = obs.min_val
  186. inputs_stats['mean'][name] = obs.mean_val
  187. inputs_stats['absmax'][name] = obs.absmax_val
  188. inputs_stats['absmean'][name] = obs.absmean_val
  189. return inputs_stats
  190. def collect_outputs_stats(self):
  191. """Collect statistics (min, max, absmax values) of the observed
  192. outputs.
  193. Returns a dictionary with these collected stats.
  194. """
  195. outputs_stats = {
  196. 'max': {},
  197. 'min': {},
  198. 'mean': {},
  199. 'absmax': {},
  200. 'absmean': {}
  201. }
  202. obs_group = ActivationObserver.find_group(self.out_obs_group)
  203. for name, obs in obs_group.items():
  204. outputs_stats['max'][name] = obs.max_val
  205. outputs_stats['min'][name] = obs.min_val
  206. outputs_stats['mean'][name] = obs.mean_val
  207. outputs_stats['absmax'][name] = obs.absmax_val
  208. outputs_stats['absmean'][name] = obs.absmean_val
  209. return outputs_stats
  210. def collect_kv_stats(self):
  211. """Collect statistics (min, max, absmax values) of the observed keys
  212. and values.
  213. Returns a tuple of two dictionaries with these collected stats.
  214. """
  215. key_stats = {'max': {}, 'min': {}, 'absmax': {}}
  216. obs_group = KVCacheObserver.find_group(self.key_obs_group)
  217. for name, obs in obs_group.items():
  218. key_stats['max'][name] = obs.max_val
  219. key_stats['min'][name] = obs.min_val
  220. key_stats['absmax'][name] = obs.absmax_val
  221. value_stats = {'max': {}, 'min': {}, 'absmax': {}}
  222. obs_group = KVCacheObserver.find_group(self.value_obs_group)
  223. for name, obs in obs_group.items():
  224. value_stats['max'][name] = obs.max_val
  225. value_stats['min'][name] = obs.min_val
  226. value_stats['absmax'][name] = obs.absmax_val
  227. return key_stats, value_stats
  228. def export(self, out_dir):
  229. """Export the calibration statistics (inputs, outputs, keys and values)
  230. to specified directory.
  231. Args:
  232. out_dir (Union[str, Path]): The directory path where the stats
  233. will be saved.
  234. """
  235. inp_stats = self.collect_inputs_stats()
  236. torch.save(inp_stats, out_dir / 'inputs_stats.pth')
  237. out_stats = self.collect_outputs_stats()
  238. torch.save(out_stats, out_dir / 'outputs_stats.pth')
  239. key_stats, value_stats = self.collect_kv_stats()
  240. torch.save(key_stats, out_dir / 'key_stats.pth')
  241. torch.save(value_stats, out_dir / 'value_stats.pth')
  242. def calibrate(self, data):
  243. """Forward pass through the model in inference mode with given data."""
  244. if type(self.model).__name__ == 'QWenLMHeadModel':
  245. model = self.model.transformer
  246. else:
  247. model = self.model.model
  248. with torch.inference_mode():
  249. _ = model(data.to(self.device))
  250. def __enter__(self):
  251. """Prepares the Calibration object for a 'with' statement by
  252. registering hooks and wrapping layer forward methods."""
  253. self._hooks = list()
  254. self._ori_forwards = {}
  255. for layer in self.name2layer.values():
  256. self._ori_forwards[layer] = layer.forward
  257. self._insert_input_observers()
  258. self._insert_output_observers()
  259. self._wrap_decoder_layers()
    def __exit__(self, exc_type, exc_value, traceback):
        """Clean up after a 'with' statement by removing registered hooks,
        restoring original forward methods, and if no exception occurred,
        collecting all gathered statistics and saving them."""
        # NOTE(review): the visible body only removes hooks and restores the
        # wrapped forwards; the stat collection/saving mentioned above is not
        # shown here — confirm against the remainder of the file.
        for h in self._hooks:
            h.remove()
        for layer in self.name2layer.values():
            layer.forward = self._ori_forwards[layer]