utils.py

from typing import Optional

import torch

from aphrodite import quantization_ops
from aphrodite.modeling.layers.quantized_linear.gptq import (
    GPTQColumnParallelLinear, GPTQRowParallelLinear, GPTQLinear)


def quant_post_init(model, max_input_length: Optional[int] = None):
    """Allocate the persistent exllama scratch buffers for a GPTQ model.

    Scans the model for GPTQ linear layers that use the exllama kernel,
    computes the largest dequantization and activation buffers needed on
    each device, allocates them once, and then calls post_init() on every
    GPTQ layer so the kernels can bind to those buffers.
    """
    device_to_buffers_size = {}

    model_uses_exllama = False
    use_act_order = False
    for _, submodule in model.named_modules():
        if isinstance(submodule,
                      (GPTQColumnParallelLinear, GPTQRowParallelLinear,
                       GPTQLinear)) and submodule.use_exllama:
            model_uses_exllama = True
            device = submodule.qweight.device
            if device not in device_to_buffers_size:
                device_to_buffers_size[device] = {
                    "max_dq_buffer_size": 1,
                    "max_inner_outer_dim": 1,
                }
            # Each int32 element of qweight packs eight 4-bit weights, so the
            # dequantization buffer must hold numel() * 8 fp16 values.
            device_to_buffers_size[device]["max_dq_buffer_size"] = max(
                device_to_buffers_size[device]["max_dq_buffer_size"],
                submodule.qweight.numel() * 8)
            in_features = submodule.input_size_per_partition if isinstance(
                submodule, GPTQRowParallelLinear) else submodule.input_size
            out_features = submodule.output_size_per_partition if isinstance(
                submodule, GPTQColumnParallelLinear) else submodule.output_size
            if submodule.quant_config.desc_act:
                use_act_order = True
            device_to_buffers_size[device]["max_inner_outer_dim"] = max(
                device_to_buffers_size[device]["max_inner_outer_dim"],
                in_features, out_features)

    if model_uses_exllama:
        device_to_buffers = {}

        # With act-order (desc_act) the temp_state buffer must hold a full
        # input sequence; otherwise a single row is enough.
        max_input_len = max_input_length if use_act_order else 1

        for device, buffers_size in device_to_buffers_size.items():
            device_to_buffers[device] = {
                "temp_state":
                torch.zeros(
                    (max_input_len, buffers_size["max_inner_outer_dim"]),
                    dtype=torch.float16,
                    device=device),
                "temp_dq":
                torch.zeros((1, buffers_size["max_dq_buffer_size"]),
                            dtype=torch.float16,
                            device=device),
                "max_dq_buffer_size":
                buffers_size["max_dq_buffer_size"],
                "max_inner_outer_dim":
                buffers_size["max_inner_outer_dim"],
            }

        # The buffers need to be persistent to avoid any bugs.
        model.device_to_buffers = device_to_buffers

        for device, buffers in model.device_to_buffers.items():
            quantization_ops.gptq_prepare_buffers(device,
                                                  buffers["temp_state"],
                                                  buffers["temp_dq"])

        # Default exllama matmul tuning parameters.
        matmul_recons_thd = 8
        matmul_fused_remap = False
        matmul_no_half2 = False
        quantization_ops.gptq_set_tuning_params(matmul_recons_thd,
                                                matmul_fused_remap,
                                                matmul_no_half2)

        # The buffers need to have been initialized before calling make_q4.
        for _, submodule in model.named_modules():
            if isinstance(
                    submodule,
                (GPTQColumnParallelLinear, GPTQRowParallelLinear, GPTQLinear)):
                submodule.post_init()

    torch.cuda.empty_cache()

    return model
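
# Usage sketch (illustrative only; the loader name and config attribute below
# are assumptions, not part of this module): the caller builds the quantized
# model first, then runs the post-init pass so the exllama buffers exist
# before the first forward call.
#
#     model = _hypothetical_load_gptq_model(model_config)  # hypothetical loader
#     model = quant_post_init(model,
#                             max_input_length=model_config.max_model_len)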