quant_utils.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. """This file is used for /tests and /benchmarks"""
  2. import numpy
  3. import torch
  4. SUPPORTED_NUM_BITS = [4, 8]
  5. SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
  6. def get_pack_factor(num_bits):
  7. assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
  8. return 32 // num_bits
  9. def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
  10. assert q_w.shape == w_ref.shape
  11. orig_device = q_w.device
  12. k_size, _ = q_w.shape
  13. g_idx = torch.zeros((k_size, ), dtype=torch.int32)
  14. for i in range(k_size):
  15. g_idx[i] = i // group_size
  16. # Simulate act_order by doing a random permutation on K
  17. rand_perm = torch.randperm(k_size)
  18. g_idx = g_idx[rand_perm].contiguous()
  19. q_w = q_w[rand_perm, :].contiguous()
  20. w_ref = w_ref[rand_perm, :].contiguous()
  21. return (
  22. w_ref.to(device=orig_device),
  23. q_w.to(device=orig_device),
  24. g_idx.to(device=orig_device),
  25. rand_perm.to(device=orig_device),
  26. )
  27. def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
  28. act_order: bool):
  29. orig_device = w.device
  30. size_k, size_n = w.shape
  31. assert w.is_floating_point(), "w must be float"
  32. assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
  33. assert group_size in SUPPORTED_GROUP_SIZES + [
  34. size_k
  35. ], f"Unsupported groupsize = {group_size}"
  36. if group_size == -1:
  37. group_size = size_k
  38. assert group_size <= size_k
  39. max_q_val = 2**num_bits - 1
  40. half_q_val = (max_q_val + 1) // 2
  41. # Reshape to [groupsize, -1]
  42. if group_size < size_k:
  43. w = w.reshape((-1, group_size, size_n))
  44. w = w.permute(1, 0, 2)
  45. w = w.reshape((group_size, -1))
  46. # Compute scale for each group
  47. s = torch.max(torch.abs(w), 0, keepdim=True)[0]
  48. s *= 2 / max_q_val # 2 => symmetric
  49. # Quantize
  50. q_w = torch.round(w / s).int()
  51. q_w += half_q_val
  52. q_w = torch.clamp(q_w, 0, max_q_val)
  53. # Compute ref (dequantized)
  54. w_ref = (q_w - half_q_val).half() * s
  55. # Restore original shapes
  56. if group_size < size_k:
  57. def reshape_w(w):
  58. w = w.reshape((group_size, -1, size_n))
  59. w = w.permute(1, 0, 2)
  60. w = w.reshape((size_k, size_n)).contiguous()
  61. return w
  62. q_w = reshape_w(q_w)
  63. w_ref = reshape_w(w_ref)
  64. s = s.reshape((-1, size_n)).contiguous()
  65. # Apply act_order
  66. g_idx = torch.empty(0, dtype=torch.int, device=w.device)
  67. rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
  68. if act_order:
  69. assert (
  70. group_size < size_k
  71. ), "For act_order, groupsize = {} must be less than size_k = {}".format(
  72. group_size, size_k)
  73. w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size)
  74. return (
  75. w_ref.to(device=orig_device),
  76. q_w.to(device=orig_device),
  77. s.to(device=orig_device),
  78. g_idx.to(device=orig_device),
  79. rand_perm.to(device=orig_device),
  80. )
  81. def quantize_weights_with_zp(w: torch.Tensor, num_bits: int, group_size: int):
  82. orig_device = w.device
  83. size_k, size_n = w.shape
  84. assert w.is_floating_point(), "w must be float"
  85. assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
  86. assert group_size in SUPPORTED_GROUP_SIZES + [
  87. size_k
  88. ], f"Unsupported groupsize = {group_size}"
  89. if group_size == -1:
  90. group_size = size_k
  91. assert group_size <= size_k
  92. max_q_val = 2**num_bits - 1
  93. min_q_val = 0
  94. # Reshape to [groupsize, -1]
  95. if group_size < size_k:
  96. w = w.reshape((-1, group_size, size_n))
  97. w = w.permute(1, 0, 2)
  98. w = w.reshape((group_size, -1))
  99. # Compute scale for each group
  100. max = torch.max(w, 0, keepdim=True)[0]
  101. min = torch.min(w, 0, keepdim=True)[0]
  102. s = (max - min).clamp(min=1e-5) / max_q_val
  103. # Compute zero-point for each group
  104. zp = (-torch.round(min / s)).clamp(min_q_val, max_q_val).int()
  105. # Quantize
  106. q_w = torch.round(w / s).int() + zp
  107. q_w = torch.clamp(q_w, min_q_val, max_q_val)
  108. # Compute ref (dequantized)
  109. w_ref = (q_w - zp).half() * s
  110. # Restore original shapes
  111. if group_size < size_k:
  112. def reshape_w(w):
  113. w = w.reshape((group_size, -1, size_n))
  114. w = w.permute(1, 0, 2)
  115. w = w.reshape((size_k, size_n)).contiguous()
  116. return w
  117. q_w = reshape_w(q_w)
  118. w_ref = reshape_w(w_ref)
  119. s = s.reshape((-1, size_n)).contiguous()
  120. zp = zp.reshape((-1, size_n)).contiguous()
  121. return (
  122. w_ref.to(device=orig_device),
  123. q_w.to(device=orig_device),
  124. s.to(device=orig_device),
  125. zp.to(device=orig_device),
  126. )
  127. def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
  128. orig_device = q_w.device
  129. sort_indices = torch.argsort(g_idx).to(
  130. dtype=torch.int32) # Sort based on g_idx
  131. g_idx = g_idx[sort_indices].contiguous()
  132. q_w = q_w[sort_indices, :].contiguous()
  133. return (
  134. q_w.to(device=orig_device),
  135. g_idx.to(device=orig_device),
  136. sort_indices.to(device=orig_device),
  137. )
  138. def pack_rows(
  139. q_w: torch.Tensor,
  140. num_bits: int,
  141. size_k: int,
  142. size_n: int,
  143. ):
  144. assert q_w.shape == (size_k, size_n)
  145. pack_factor = get_pack_factor(num_bits)
  146. assert size_k % pack_factor == 0
  147. orig_device = q_w.device
  148. q_w = q_w.cpu().numpy().astype(numpy.uint32)
  149. q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)
  150. for i in range(pack_factor):
  151. q_res |= q_w[i::pack_factor, :] << num_bits * i
  152. q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
  153. return q_res
  154. def pack_cols(
  155. q_w: torch.Tensor,
  156. num_bits: int,
  157. size_k: int,
  158. size_n: int,
  159. ):
  160. assert q_w.shape == (size_k, size_n)
  161. pack_factor = get_pack_factor(num_bits)
  162. assert size_n % pack_factor == 0
  163. orig_device = q_w.device
  164. q_w = q_w.cpu().numpy().astype(numpy.uint32)
  165. q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
  166. for i in range(pack_factor):
  167. q_res |= q_w[:, i::pack_factor] << num_bits * i
  168. q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
  169. q_res = q_res.contiguous()
  170. return q_res
  171. def unpack_cols(
  172. packed_q_w: torch.Tensor,
  173. num_bits: int,
  174. size_k: int,
  175. size_n: int,
  176. ):
  177. pack_factor = get_pack_factor(num_bits)
  178. assert size_n % pack_factor == 0
  179. assert packed_q_w.shape == (
  180. size_k, size_n // pack_factor
  181. ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
  182. packed_q_w.shape, size_k, size_n, pack_factor)
  183. orig_device = packed_q_w.device
  184. packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
  185. q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
  186. mask = (1 << num_bits) - 1
  187. for i in range(pack_factor):
  188. vals = packed_q_w_cpu & mask
  189. packed_q_w_cpu >>= num_bits
  190. q_res[:, i::pack_factor] = vals
  191. q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
  192. q_res = q_res.contiguous()
  193. return q_res
  194. def gptq_pack(
  195. q_w: torch.Tensor,
  196. num_bits: int,
  197. size_k: int,
  198. size_n: int,
  199. ):
  200. return pack_rows(q_w, num_bits, size_k, size_n)
  201. def awq_pack(
  202. q_w: torch.Tensor,
  203. num_bits: int,
  204. size_k: int,
  205. size_n: int,
  206. ):
  207. assert q_w.shape == (size_k, size_n)
  208. # Interleave column dim (for the dequantize code) and pack it to int32
  209. if num_bits == 4:
  210. interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
  211. elif num_bits == 8:
  212. interleave = numpy.array([0, 2, 1, 3])
  213. else:
  214. raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
  215. q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel()
  216. q_w = q_w.reshape((-1, size_n)).contiguous()
  217. return pack_cols(q_w, num_bits, size_k, size_n)