export_kv_params.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from pathlib import Path
  3. from typing import Union
  4. import fire
  5. import numpy as np
  6. import torch
  7. def _export_sym(key_stats: dict,
  8. value_stats: dict,
  9. bits: int,
  10. out_dir: Union[str, Path],
  11. tp: int = 1) -> None:
  12. """Export symmetric quantization parameters to specified directory."""
  13. keys_absmax = key_stats['absmax']
  14. values_absmax = value_stats['absmax']
  15. for layer_idx, name in enumerate(keys_absmax.keys()):
  16. k_absmax = keys_absmax[name]
  17. v_absmax = values_absmax[name]
  18. heads, _ = k_absmax.shape
  19. assert heads % tp == 0
  20. mp_k_absmax = torch.chunk(k_absmax, tp)
  21. mp_v_absmax = torch.chunk(v_absmax, tp)
  22. for i in range(tp):
  23. # quant: q = f / scale
  24. # dequant: f = q * scale
  25. k_s = mp_k_absmax[i].max() / (2**(bits - 1) - 1)
  26. v_s = mp_v_absmax[i].max() / (2**(bits - 1) - 1)
  27. kv_qparams = np.array([k_s, v_s], dtype=np.float32)
  28. out_path = out_dir / f'layers.{layer_idx}.past_kv_scale.{i}.weight' # noqa: E501
  29. kv_qparams.tofile(out_path)
  30. print(f'Layer {layer_idx} MP {i} qparam: {k_s} \t{v_s}')
  31. def _export_asym(key_stats: dict,
  32. value_stats: dict,
  33. bits: int,
  34. out_dir: Union[str, Path],
  35. tp: int = 1) -> None:
  36. """Export asymmetric quantization parameters to specified directory."""
  37. keys_min = key_stats['min']
  38. values_min = value_stats['min']
  39. keys_max = key_stats['max']
  40. values_max = value_stats['max']
  41. for layer_idx, name in enumerate(keys_min.keys()):
  42. k_max = keys_max[name]
  43. v_max = values_max[name]
  44. k_min = keys_min[name]
  45. v_min = values_min[name]
  46. heads, _ = k_min.shape
  47. assert heads % tp == 0
  48. tp_k_min = torch.chunk(k_min, tp)
  49. tp_v_min = torch.chunk(v_min, tp)
  50. tp_k_max = torch.chunk(k_max, tp)
  51. tp_v_max = torch.chunk(v_max, tp)
  52. for i in range(tp):
  53. # zp = (min+max) / 2
  54. # scale = (max-min) / 255
  55. # quant: q = (f-zp) / scale
  56. # dequant: f = q * scale + zp
  57. k_min = tp_k_min[i].min()
  58. v_min = tp_v_min[i].min()
  59. k_max = tp_k_max[i].max()
  60. v_max = tp_v_max[i].max()
  61. k_scale = (k_max - k_min) / (2**bits - 1)
  62. v_scale = (v_max - v_min) / (2**bits - 1)
  63. k_zp = (k_max + k_min) / 2
  64. v_zp = (v_max + v_min) / 2
  65. kv_qparams = np.array([k_scale, k_zp, v_scale, v_zp],
  66. dtype=np.float32)
  67. out_path = out_dir / f'layers.{layer_idx}.past_kv_scale.{i}.weight'
  68. kv_qparams.tofile(out_path)
  69. print(f'Layer {layer_idx} MP {i} qparam: '
  70. f'\t{k_scale} \t{k_zp} \t{v_scale} \t{v_zp}')
  71. def main(work_dir: str,
  72. kv_params_dir: str,
  73. kv_bits: int = 8,
  74. kv_sym: bool = False,
  75. num_tp: int = 1) -> None:
  76. """Main function to export key and value stats.
  77. Args:
  78. work_dir (Union[str, Path]): Directory path where the stats are saved.
  79. kv_params_dir (Union[str, Path]): Directory path where to
  80. save the results.
  81. kv_bits (int, optional): Number of bits for quantization.
  82. Defaults to 8.
  83. kv_sym (bool, optional): Whether to use symmetric quantizaiton.
  84. Defaults to False.
  85. num_tp (int, optional): Number of tensor parallelism. Defaults to 1.
  86. """
  87. work_dir = Path(work_dir)
  88. tm_dir = Path(kv_params_dir)
  89. tm_dir.mkdir(parents=True, exist_ok=True)
  90. key_stats = torch.load(work_dir / 'key_stats.pth')
  91. value_stats = torch.load(work_dir / 'value_stats.pth')
  92. if kv_sym:
  93. _export_sym(key_stats, value_stats, kv_bits, tm_dir, num_tp)
  94. else:
  95. _export_asym(key_stats, value_stats, kv_bits, tm_dir, num_tp)
  96. if __name__ == '__main__':
  97. fire.Fire(main)