  1. """This file is used for /tests and /benchmarks"""
  2. from typing import List
  3. import numpy
  4. import torch
  5. from aphrodite.quantization.qqq import MARLIN_QQQ_SUPPORTED_NUM_BITS
  6. from aphrodite.scalar_type import ScalarType, scalar_types
  7. SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
  8. SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
  9. # NOTE: this is a hack. We should update each model to register the
  10. # stacked params and get it from there instead in a future PR.
  11. # fused_name: List[shard_name]
  12. FUSED_LAYER_NAME_MAPPING = {
  13. "qkv_proj": ["q_proj", "k_proj", "v_proj"],
  14. "gate_up_proj": ["gate_proj", "up_proj"]
  15. }


def pack_weights_into_int32(w_q: torch.Tensor,
                            wtype: ScalarType,
                            packed_dim: int = 0):
    # move dim to pack to the end
    perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
    inv_perm = tuple(perm.index(i) for i in range(len(perm)))
    w_q_perm = w_q.permute(perm)

    pack_factor = 32 // wtype.size_bits
    mask = (1 << wtype.size_bits) - 1

    new_shape_perm = list(w_q_perm.shape)
    assert w_q_perm.shape[-1] % pack_factor == 0
    new_shape_perm[-1] //= pack_factor

    res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
    for i in range(pack_factor):
        res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i

    return res.permute(inv_perm)


def unpack_weights_into_int32(w_q: torch.Tensor,
                              wtype: ScalarType,
                              packed_dim: int = 0):
    # move the packed dim to the end
    perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
    inv_perm = tuple(perm.index(i) for i in range(len(perm)))
    w_q_perm = w_q.permute(perm)

    pack_factor = 32 // wtype.size_bits
    mask = (1 << wtype.size_bits) - 1

    new_shape_perm = list(w_q_perm.shape)
    new_shape_perm[-1] *= pack_factor

    res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
    for i in range(pack_factor):
        res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask

    return res.permute(inv_perm)
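

# Illustrative sketch: a round-trip check showing how the two helpers above
# relate. The _example_* helper is not part of the module's API, and the
# tensor shape, packed_dim and use of scalar_types.uint4b8 are arbitrary
# example choices.
def _example_pack_unpack_roundtrip():
    w_q = torch.randint(0, 16, (128, 256), dtype=torch.int32)
    packed = pack_weights_into_int32(w_q, scalar_types.uint4b8, packed_dim=0)
    assert packed.shape == (128 // 8, 256)  # 8 nibbles per int32
    unpacked = unpack_weights_into_int32(packed, scalar_types.uint4b8,
                                         packed_dim=0)
    assert torch.equal(unpacked, w_q)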


def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
    # prefix: model.layers.0.self_attn.q_proj
    # proj_name: q_proj
    proj_name = prefix.split(".")[-1]
    if proj_name in FUSED_LAYER_NAME_MAPPING:
        shard_prefixes = [
            prefix.replace(proj_name, shard_proj_name)
            for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
        ]

        is_skipped = None
        for shard_prefix in shard_prefixes:
            is_shard_skipped = shard_prefix in ignored_layers

            if is_skipped is None:
                is_skipped = is_shard_skipped
            elif is_shard_skipped != is_skipped:
                raise ValueError(
                    f"Detected some but not all shards of {prefix} "
                    "are quantized. All shards of fused layers "
                    "must have the same precision.")
    else:
        is_skipped = prefix in ignored_layers

    assert is_skipped is not None
    return is_skipped
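

# Illustrative sketch of the fused-layer handling above: a fused "qkv_proj"
# prefix counts as skipped only if all of its q/k/v shards appear in
# ignored_layers. The _example_* helper and the layer names are made-up
# examples, not part of the module's API.
def _example_is_layer_skipped():
    ignored = [
        "model.layers.0.self_attn.q_proj",
        "model.layers.0.self_attn.k_proj",
        "model.layers.0.self_attn.v_proj",
    ]
    assert is_layer_skipped("model.layers.0.self_attn.qkv_proj", ignored)
    assert not is_layer_skipped("model.layers.1.self_attn.qkv_proj", ignored)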


def get_pack_factor(num_bits):
    assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
    return 32 // num_bits


def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
    assert q_w.shape == w_ref.shape

    orig_device = q_w.device
    k_size, _ = q_w.shape

    g_idx = torch.zeros((k_size, ), dtype=torch.int32)
    for i in range(k_size):
        g_idx[i] = i // group_size

    # Simulate act_order by doing a random permutation on K
    rand_perm = torch.randperm(k_size)

    g_idx = g_idx[rand_perm].contiguous()
    q_w = q_w[rand_perm, :].contiguous()
    w_ref = w_ref[rand_perm, :].contiguous()

    return (
        w_ref.to(device=orig_device),
        q_w.to(device=orig_device),
        g_idx.to(device=orig_device),
        rand_perm.to(device=orig_device),
    )


def quantize_weights(w: torch.Tensor,
                     quant_type: ScalarType,
                     group_size: int,
                     zero_points: bool = False,
                     ref_zero_points_after_scales: bool = False):
    assert quant_type.is_integer(), \
        "Floating point quantization may work but has not been tested"

    orig_device = w.device
    orig_type = w.dtype
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"

    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Reshape to [group_size, -1]
    if group_size < size_k:
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

    # Compute scale for each group
    max_val = torch.max(w, 0, keepdim=True).values
    min_val = torch.min(w, 0, keepdim=True).values

    max_q_val = quant_type.max()
    min_q_val = quant_type.min()

    if zero_points:
        assert not quant_type.is_signed() and quant_type.max() > 0
        w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
        maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \
            .clamp(min_q_val, max_q_val).int()
    else:
        # If the bias is such that there are no possible negative/positive
        # values, set the max value to inf to avoid divide by 0
        w_s = torch.max(
            abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
            abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)))
        maybe_w_zp = None

    # Quantize
    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
    w_q = torch.clamp(w_q, min_q_val, max_q_val)

    # For some kernels (namely Machete) the zero-points are applied after the
    # scales. Computing the reference the same way lets us use tighter error
    # tolerances in our unit tests.
    if ref_zero_points_after_scales and zero_points:
        w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
    else:
        w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s

    if quant_type.has_bias():
        w_q += quant_type.bias

    # Restore original shapes
    if group_size < size_k:

        def reshape_w(w):
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        w_q = reshape_w(w_q)
        w_ref = reshape_w(w_ref)

    w_s = w_s.reshape((-1, size_n)).contiguous()

    if zero_points:
        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
        maybe_w_zp = maybe_w_zp.to(device=orig_device)

    return (
        w_ref.to(device=orig_device),
        w_q.to(device=orig_device),
        w_s.to(device=orig_device),
        maybe_w_zp,
    )
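

# Illustrative sketch: symmetric 4-bit group quantization with the helper
# above. With zero_points=False the returned zero-point slot is None and the
# scales have one row per group. The _example_* helper, shapes and group size
# are arbitrary example choices, not part of the module's API.
def _example_quantize_weights():
    w = torch.randn((256, 64), dtype=torch.half)
    w_ref, w_q, w_s, w_zp = quantize_weights(w,
                                             scalar_types.uint4b8,
                                             group_size=128)
    assert w_zp is None
    assert w_ref.shape == w.shape and w_q.shape == w.shape
    assert w_s.shape == (256 // 128, 64)  # one scale row per group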


def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType,
                          group_size: int, act_order: bool):
    size_k, _ = w.shape

    assert w.is_floating_point(), "w must be float"
    assert quant_type in SUPPORTED_GPTQ_QUANT_TYPES, \
        f"Unsupported gptq type = {quant_type}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)

    # Apply act_order
    g_idx = torch.empty(0, dtype=torch.int, device=w.device)
    rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
    if act_order:
        assert (
            group_size < size_k
        ), "For act_order, groupsize = {} must be less than size_k = {}".format(
            group_size, size_k)

        w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size)

    return w_ref, w_q, w_s, g_idx, rand_perm
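

# Illustrative sketch: with act_order=True the helper above also returns a
# non-empty g_idx (group index per row of K) and the random row permutation
# that was applied. The _example_* helper and shapes are arbitrary example
# choices, not part of the module's API.
def _example_gptq_quantize_weights():
    w = torch.randn((256, 64), dtype=torch.half)
    w_ref, w_q, w_s, g_idx, rand_perm = gptq_quantize_weights(
        w, scalar_types.uint4b8, group_size=128, act_order=True)
    assert g_idx.shape == (256, ) and rand_perm.shape == (256, )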


# QQQ employs different quant schemes for per-group and
# per-channel quantization.
def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int):
    orig_device = w.device
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"
    assert num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS, \
        f"Unsupported num_bits = {num_bits}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    if group_size < size_k:
        # Reshape to [group_size, -1]
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

        max_q_val = 2**num_bits - 1
        half_q_val = (max_q_val + 1) // 2

        # Compute scale for each group
        s_group = torch.max(torch.abs(w), 0, keepdim=True)[0]
        s_group *= 2 / max_q_val  # 2 => symmetric

        # Quantize
        q_w = torch.round(w / s_group).int()
        q_w += half_q_val
        q_w = torch.clamp(q_w, 0, max_q_val)
        # Compute ref (dequantized)
        w_ref = (q_w - half_q_val).half() * s_group

        # Restore original shapes
        def reshape_w(w):
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        q_w = reshape_w(q_w)
        w_ref = reshape_w(w_ref)

        # Compute int8 quantization scale for each channel
        s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0]
        s_channel /= 127.0
        t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8)
        w_ref = t_int8.half() * s_channel
        s_channel = s_channel.reshape(1, -1).to(dtype=torch.float)

        # Fuse scales
        s_group = (s_group.reshape(-1, size_n).contiguous() /
                   s_channel).to(dtype=torch.half)
    else:
        max_q_val = 2**(num_bits - 1) - 1

        # Compute scale for each channel
        s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0]
        s_channel /= max_q_val

        # Quantize
        q_w = torch.round(w / s_channel).int()
        q_w = torch.clamp(q_w, -max_q_val, max_q_val)
        # Compute ref (dequantized)
        w_ref = q_w.half() * s_channel

        s_group = torch.tensor([], dtype=torch.half)
        # Divide by 2 ** (8 - num_bits) to offset the right shift in unpacking
        s_channel /= (2**(8 - num_bits))
        s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float)

    return (
        w_ref.to(device=orig_device),
        q_w.to(device=orig_device),
        s_group.to(device=orig_device),
        s_channel.to(device=orig_device),
    )
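

# Illustrative sketch: in the per-channel case (group_size == -1) the helper
# above returns an empty s_group and per-channel float scales. The _example_*
# helper, num_bits=4 and the shapes are arbitrary example choices, not part of
# the module's API.
def _example_qqq_quantize_weights():
    w = torch.randn((256, 64), dtype=torch.half)
    w_ref, q_w, s_group, s_channel = qqq_quantize_weights(w,
                                                          num_bits=4,
                                                          group_size=-1)
    assert s_group.numel() == 0
    assert s_channel.shape == (1, 64) and s_channel.dtype == torch.float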


def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
    orig_device = q_w.device

    sort_indices = torch.argsort(g_idx).to(
        dtype=torch.int32)  # Sort based on g_idx

    g_idx = g_idx[sort_indices].contiguous()
    q_w = q_w[sort_indices, :].contiguous()

    return (
        q_w.to(device=orig_device),
        g_idx.to(device=orig_device),
        sort_indices.to(device=orig_device),
    )


def pack_rows(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_k % pack_factor == 0

    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)

    for i in range(pack_factor):
        q_res |= q_w[i::pack_factor, :] << num_bits * i

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    return q_res


def pack_cols(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0

    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)

    for i in range(pack_factor):
        q_res |= q_w[:, i::pack_factor] << num_bits * i

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res


def unpack_cols(
    packed_q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0
    assert packed_q_w.shape == (
        size_k, size_n // pack_factor
    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_factor = {}".format(
        packed_q_w.shape, size_k, size_n, pack_factor)

    orig_device = packed_q_w.device

    packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
    q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)

    mask = (1 << num_bits) - 1
    for i in range(pack_factor):
        vals = packed_q_w_cpu & mask
        packed_q_w_cpu >>= num_bits
        q_res[:, i::pack_factor] = vals

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res
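

# Illustrative sketch: pack_cols and unpack_cols above are inverses for values
# that fit in num_bits. The _example_* helper, the 4-bit width and the shapes
# are arbitrary example choices, not part of the module's API.
def _example_pack_unpack_cols_roundtrip():
    q_w = torch.randint(0, 16, (64, 128), dtype=torch.int32)
    packed = pack_cols(q_w, 4, 64, 128)
    assert packed.shape == (64, 128 // 8)  # 8 values per int32 column
    assert torch.equal(unpack_cols(packed, 4, 64, 128), q_w)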


def gptq_pack(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    return pack_rows(q_w, num_bits, size_k, size_n)


def awq_pack(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    assert q_w.shape == (size_k, size_n)

    # Interleave column dim (for the dequantize code) and pack it to int32
    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel()
    q_w = q_w.reshape((-1, size_n)).contiguous()

    return pack_cols(q_w, num_bits, size_k, size_n)
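

# Illustrative sketch: awq_pack interleaves columns and then packs them, so the
# result has size_n // pack_factor int32 columns. The _example_* helper, the
# 4-bit width and the shapes are arbitrary example choices, not part of the
# module's API.
def _example_awq_pack():
    q_w = torch.randint(0, 16, (64, 128), dtype=torch.int32)
    packed = awq_pack(q_w, 4, 64, 128)
    assert packed.shape == (64, 128 // 8)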