marlin_utils_test.py

  1. """Utility functions used for tests and benchmarks"""
  2. from typing import List
  3. import numpy as np
  4. import torch
  5. from .marlin_utils import (GPTQ_MARLIN_TILE, marlin_permute_scales,
  6. marlin_zero_points)
  7. from .quant_utils import (get_pack_factor, quantize_weights,
  8. quantize_weights_with_zp, sort_weights)
class MarlinWorkspace:

    def __init__(self, out_features, min_thread_n, max_parallel):
        assert (out_features % min_thread_n == 0), (
            "out_features = {} is not divisible by min_thread_n = {}".format(
                out_features, min_thread_n))

        max_workspace_size = ((out_features // min_thread_n) * max_parallel)

        self.scratch = torch.zeros(max_workspace_size,
                                   dtype=torch.int,
                                   device="cuda")
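
# A minimal usage sketch for MarlinWorkspace (illustrative only; the concrete
# values below are assumptions, not taken from this file):
#
#     workspace = MarlinWorkspace(out_features=4096,
#                                 min_thread_n=64,
#                                 max_parallel=16)
#     # workspace.scratch -> zeroed int32 CUDA buffer of
#     # (4096 // 64) * 16 = 1024 elements
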
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
    assert q_w.shape == (size_k, size_n)
    assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
    assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}"

    # Permute weights to 16x64 marlin tiles
    q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
    q_w = q_w.permute((0, 2, 1, 3))
    q_w = q_w.reshape((size_k // tile, size_n * tile))

    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)

    return q_w
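
# A worked shape trace for the reshapes above (values are assumptions chosen
# for illustration: size_k = 32, size_n = 64, tile = 16):
#
#     (32, 64)        input q_w
#     (2, 16, 4, 16)  split into 16x16 sub-tiles
#     (2, 4, 16, 16)  permute so each 16x16 sub-tile is contiguous
#     (2, 1024)       one row per 16-row band, 16 * 64 values per band
#
# The final gather then applies `perm` (1024 entries) within each row.
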
def marlin_weights(q_w, size_k, size_n, num_bits, perm):
    # Permute
    q_w = marlin_permute_weights(q_w, size_k, size_n, perm)

    # Pack
    pack_factor = get_pack_factor(num_bits)
    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(np.uint32)

    q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),
                        dtype=np.uint32)
    for i in range(pack_factor):
        q_packed |= q_w[:, i::pack_factor] << num_bits * i

    q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device)

    return q_packed
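
# Packing arithmetic, for reference (a worked example assuming num_bits = 4,
# so pack_factor = 32 // 4 = 8): output column j packs input columns
# 8*j .. 8*j + 7, with input column 8*j + i landing in bits [4*i, 4*i + 4)
# of the uint32, e.g. the value from column 8*j + 3 is shifted left by 12.
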
def get_weight_perm(num_bits: int):
    perm_list: List[int] = []
    for i in range(32):
        perm1: List[int] = []
        col = i // 4
        for block in [0, 1]:
            for row in [
                    2 * (i % 4),
                    2 * (i % 4) + 1,
                    2 * (i % 4 + 4),
                    2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col + 8 * block)

        for j in range(4):
            perm_list.extend([p + 256 * j for p in perm1])

    perm = np.array(perm_list)

    if num_bits == 4:
        interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = np.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)
    return perm
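
# Note: perm has 32 iterations x 8 entries x 4 copies = 1024 entries, i.e.
# exactly one 16x64 marlin super-tile, matching the reshape((-1, perm.numel()))
# row width used in marlin_permute_weights.
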
def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
                    act_order: bool):
    size_k, size_n = w.shape

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Quantize (and apply act_order if provided)
    w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size,
                                                       act_order)

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing
    sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Reformat to marlin
    weight_perm = get_weight_perm(num_bits)
    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)

    # Create result
    res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
    for i in range(len(res_list)):
        res_list[i] = res_list[i].to(w.device)

    return res_list
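
# A minimal usage sketch for marlin_quantize (illustrative; the shapes, dtype
# and device below are assumptions):
#
#     w = torch.randn((512, 256), dtype=torch.half, device="cuda")
#     w_ref, q_w, s, g_idx, sort_idx, rand_perm = marlin_quantize(
#         w, num_bits=4, group_size=128, act_order=False)
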
def awq_marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int):
    size_k, size_n = w.shape

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Detect num groups
    assert size_k % group_size == 0
    num_groups = size_k // group_size

    # Quantize with zp
    w_ref, q_w, s, zp = quantize_weights_with_zp(w, num_bits, group_size)

    # Reformat to marlin
    weight_perm = get_weight_perm(num_bits)
    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
    marlin_zp = marlin_zero_points(zp, num_groups, size_n, num_bits)

    # Create result
    res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]
    for i in range(len(res_list)):
        res_list[i] = res_list[i].to(w.device)

    return res_list
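

# A minimal end-to-end sketch (not part of the original tests; it assumes a
# CUDA device is available and that this module is run from within its
# package so the relative imports above resolve):
if __name__ == "__main__":
    w = torch.randn((128, 256), dtype=torch.half, device="cuda")
    w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
        w, num_bits=4, group_size=128)
    # Packed weights: (128 // 16, 256 * 16 // (32 // 4)) = (8, 512)
    print(marlin_q_w.shape, marlin_s.shape, marlin_zp.shape)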