quant_utils.py

import numpy
import torch

SUPPORTED_NUM_BITS = [4, 8]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]


def get_pack_factor(num_bits):
    """Number of quantized values packed into one 32-bit word
    (e.g. 32 // 4 = 8 values for 4-bit, 4 values for 8-bit)."""
    assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
    return 32 // num_bits


def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
    """Randomly permute the K dimension of q_w and w_ref and return the
    matching group index (g_idx) and permutation, simulating GPTQ act_order."""
    assert q_w.shape == w_ref.shape

    orig_device = q_w.device
    k_size, _ = q_w.shape

    g_idx = torch.zeros((k_size, ), dtype=torch.int32)
    for i in range(k_size):
        g_idx[i] = i // group_size

    # Simulate act_order by doing a random permutation on K
    rand_perm = torch.randperm(k_size)

    g_idx = g_idx[rand_perm].contiguous()
    q_w = q_w[rand_perm, :].contiguous()
    w_ref = w_ref[rand_perm, :].contiguous()

    return (
        w_ref.to(device=orig_device),
        q_w.to(device=orig_device),
        g_idx.to(device=orig_device),
        rand_perm.to(device=orig_device),
    )
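
# Example (illustrative): with k_size=8 and group_size=4, the unpermuted g_idx
# is [0, 0, 0, 0, 1, 1, 1, 1]; rand_perm then shuffles it together with the
# rows of q_w and w_ref.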


def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
                     act_order: bool):
    """Symmetric, per-group quantization of w to num_bits, optionally followed
    by a random act_order permutation of the K dimension."""
    orig_device = w.device
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"
    assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    max_q_val = 2**num_bits - 1
    half_q_val = (max_q_val + 1) // 2

    # Reshape to [groupsize, -1]
    if group_size < size_k:
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

    # Compute scale for each group
    s = torch.max(torch.abs(w), 0, keepdim=True)[0]
    s *= 2 / max_q_val  # 2 => symmetric

    # Quantize
    q_w = torch.round(w / s).int()
    q_w += half_q_val
    q_w = torch.clamp(q_w, 0, max_q_val)

    # Compute ref (dequantized)
    w_ref = (q_w - half_q_val).half() * s

    # Restore original shapes
    if group_size < size_k:

        def reshape_w(w):
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        q_w = reshape_w(q_w)
        w_ref = reshape_w(w_ref)

    s = s.reshape((-1, size_n)).contiguous()

    # Apply act_order
    g_idx = torch.empty(0, dtype=torch.int, device=w.device)
    rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
    if act_order:
        assert (
            group_size < size_k
        ), "For act_order, groupsize = {} must be less than size_k = {}".format(
            group_size, size_k)

        w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size)

    return (
        w_ref.to(device=orig_device),
        q_w.to(device=orig_device),
        s.to(device=orig_device),
        g_idx.to(device=orig_device),
        rand_perm.to(device=orig_device),
    )
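
# Shape example (illustrative): for w of shape [size_k=128, size_n=256] with
# num_bits=4, group_size=32 and act_order=False, quantize_weights returns
# w_ref and q_w of shape [128, 256], s of shape [4, 256] (one row of scales
# per group of 32 K-rows), and empty g_idx / rand_perm tensors, since no
# activation-order permutation is applied.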


def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
    """Reorder the rows of q_w so that rows sharing a group index are
    contiguous; also returns the reordered g_idx and the sort indices."""
    orig_device = q_w.device

    sort_indices = torch.argsort(g_idx).to(
        dtype=torch.int32)  # Sort based on g_idx

    g_idx = g_idx[sort_indices].contiguous()
    q_w = q_w[sort_indices, :].contiguous()

    return (
        q_w.to(device=orig_device),
        g_idx.to(device=orig_device),
        sort_indices.to(device=orig_device),
    )


def gptq_pack(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Pack the integer quantized weights along K into int32 words,
    32 // num_bits values per word (GPTQ-style qweight layout)."""
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_k % pack_factor == 0

    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)

    for i in range(pack_factor):
        q_res |= q_w[i::pack_factor, :] << num_bits * i

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)

    return q_res
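

# Minimal usage sketch (illustrative; the sizes below are arbitrary example
# values): quantize a random float weight matrix, sort its rows by group
# index, and pack the result with gptq_pack.
if __name__ == "__main__":
    size_k, size_n, num_bits, group_size = 128, 256, 4, 32

    w = torch.randn((size_k, size_n))
    w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w,
                                                       num_bits,
                                                       group_size,
                                                       act_order=True)

    # Group the rows of q_w so that each scale group is contiguous along K.
    q_w_sorted, g_idx_sorted, sort_indices = sort_weights(q_w, g_idx)

    # Pack 32 // num_bits = 8 quantized values per int32 word along K.
    q_packed = gptq_pack(q_w_sorted, num_bits, size_k, size_n)

    print(q_w.shape, s.shape, q_packed.shape)
    # -> torch.Size([128, 256]) torch.Size([4, 256]) torch.Size([16, 256])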