1
0

marlin_utils_test_qqq.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. from typing import List
  2. import numpy
  3. import torch
  4. from .marlin_utils_test import marlin_permute_weights
  5. from .quant_utils import get_pack_factor, qqq_quantize_weights
  6. def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size):
  7. # Permute
  8. q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
  9. # Pack
  10. pack_factor = get_pack_factor(num_bits)
  11. orig_device = q_w.device
  12. q_w = q_w.cpu().numpy().astype(numpy.uint32)
  13. q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),
  14. dtype=numpy.uint32)
  15. if group_size == size_k:
  16. for i in range(pack_factor):
  17. q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i
  18. else:
  19. for i in range(pack_factor):
  20. q_packed |= q_w[:, i::pack_factor] << num_bits * i
  21. q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device)
  22. return q_packed
  23. def get_qqq_scale_perms():
  24. scale_perm: List[int] = []
  25. for i in range(8):
  26. scale_perm.extend([i + 8 * j for j in range(8)])
  27. scale_perm_single: List[int] = []
  28. for i in range(4):
  29. scale_perm_single.extend(
  30. [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
  31. return scale_perm, scale_perm_single
  32. # NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. # noqa: E501
  33. def get_qqq_weight_perm(num_bits: int, quant_type: str):
  34. perm_list: List[int] = []
  35. for i in range(32):
  36. perm1: List[int] = []
  37. col = i // 4
  38. for block in [0, 1]:
  39. for row in [
  40. 4 * (i % 4),
  41. 4 * (i % 4) + 1,
  42. 4 * (i % 4) + 2,
  43. 4 * (i % 4) + 3,
  44. ]:
  45. perm1.append(16 * row + col + 8 * block)
  46. for j in range(4):
  47. perm_list.extend([p + 256 * j for p in perm1])
  48. perm = numpy.array(perm_list)
  49. assert quant_type in ["per-channel",
  50. "per-group"], "not supported quantization type"
  51. if num_bits == 4:
  52. if quant_type == "per-channel":
  53. interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3])
  54. else:
  55. interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
  56. else:
  57. raise Exception("num_bits must be 4, got {}".format(num_bits))
  58. perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
  59. perm = torch.from_numpy(perm)
  60. return perm
  61. def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size):
  62. scale_perm, scale_perm_single = get_qqq_scale_perms()
  63. if group_size < size_k and group_size != -1:
  64. s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm]
  65. s_channel = s_channel.reshape(
  66. (-1, len(scale_perm_single)))[:, scale_perm_single]
  67. s_group = s_group.reshape((-1, size_n)).contiguous()
  68. else:
  69. s_channel = s_channel.reshape(
  70. (-1, len(scale_perm_single)))[:, scale_perm_single]
  71. s_channel = s_channel.reshape((-1, size_n)).contiguous()
  72. return s_group, s_channel
  73. def marlin_qqq_quantize(
  74. w: torch.Tensor,
  75. num_bits: int,
  76. group_size: int,
  77. ):
  78. size_k, size_n = w.shape
  79. # Normalize group_size
  80. if group_size == -1:
  81. group_size = size_k
  82. assert group_size <= size_k
  83. quant_type = "per-channel" if group_size == size_k else "per-group"
  84. # Quantize
  85. w_ref, q_w, s_group, s_channel = qqq_quantize_weights(
  86. w, num_bits, group_size)
  87. # Reformat to marlin_qqq
  88. weight_perm = get_qqq_weight_perm(num_bits, quant_type)
  89. marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits,
  90. weight_perm, group_size)
  91. marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales(
  92. s_group, s_channel, size_k, size_n, group_size)
  93. # Create result
  94. res_list = [
  95. w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel
  96. ]
  97. for i in range(len(res_list)):
  98. res_list[i] = res_list[i].to(w.device)
  99. return res_list