marlin_utils_test.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. """Utility functions used for tests and benchmarks"""
  2. from typing import List
  3. import numpy as np
  4. import torch
  5. from aphrodite.scalar_type import ScalarType
  6. from .marlin_utils import (GPTQ_MARLIN_TILE, marlin_permute_scales,
  7. marlin_zero_points)
  8. from .quant_utils import (get_pack_factor, gptq_quantize_weights,
  9. quantize_weights, sort_weights)
  10. class MarlinWorkspace:
  11. def __init__(self, out_features, min_thread_n, max_parallel):
  12. assert (out_features % min_thread_n == 0), (
  13. "out_features = {} is undivisible by min_thread_n = {}".format(
  14. out_features, min_thread_n))
  15. max_workspace_size = ((out_features // min_thread_n) * max_parallel)
  16. self.scratch = torch.zeros(max_workspace_size,
  17. dtype=torch.int,
  18. device="cuda")
  19. def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
  20. assert q_w.shape == (size_k, size_n)
  21. assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
  22. assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}"
  23. # Permute weights to 16x64 marlin tiles
  24. q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
  25. q_w = q_w.permute((0, 2, 1, 3))
  26. q_w = q_w.reshape((size_k // tile, size_n * tile))
  27. q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)
  28. return q_w
  29. def marlin_weights(q_w, size_k, size_n, num_bits, perm):
  30. # Permute
  31. q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
  32. # Pack
  33. pack_factor = get_pack_factor(num_bits)
  34. orig_device = q_w.device
  35. q_w = q_w.cpu().numpy().astype(np.uint32)
  36. q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),
  37. dtype=np.uint32)
  38. for i in range(pack_factor):
  39. q_packed |= q_w[:, i::pack_factor] << num_bits * i
  40. q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device)
  41. return q_packed
  42. def get_weight_perm(num_bits: int):
  43. perm_list: List[int] = []
  44. for i in range(32):
  45. perm1: List[int] = []
  46. col = i // 4
  47. for block in [0, 1]:
  48. for row in [
  49. 2 * (i % 4),
  50. 2 * (i % 4) + 1,
  51. 2 * (i % 4 + 4),
  52. 2 * (i % 4 + 4) + 1,
  53. ]:
  54. perm1.append(16 * row + col + 8 * block)
  55. for j in range(4):
  56. perm_list.extend([p + 256 * j for p in perm1])
  57. perm = np.array(perm_list)
  58. if num_bits == 4:
  59. interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
  60. elif num_bits == 8:
  61. interleave = np.array([0, 2, 1, 3])
  62. else:
  63. raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
  64. perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
  65. perm = torch.from_numpy(perm)
  66. return perm
  67. def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int,
  68. act_order: bool):
  69. size_k, size_n = w.shape
  70. num_bits = quant_type.size_bits
  71. # Normalize group_size
  72. if group_size == -1:
  73. group_size = size_k
  74. assert group_size <= size_k
  75. # Quantize (and apply act_order if provided)
  76. w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
  77. w, quant_type, group_size, act_order)
  78. # For act_order, sort the "weights" and "g_idx" so that group ids are
  79. # increasing
  80. sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
  81. if act_order:
  82. q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
  83. # Reformat to marlin
  84. weight_perm = get_weight_perm(num_bits)
  85. marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
  86. marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
  87. # Create result
  88. res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
  89. for i in range(len(res_list)):
  90. res_list[i] = res_list[i].to(w.device)
  91. return res_list
  92. def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType,
  93. group_size: int):
  94. size_k, size_n = w.shape
  95. # Normalize group_size
  96. if group_size == -1:
  97. group_size = size_k
  98. assert group_size <= size_k
  99. # Detect num groups
  100. assert size_k % group_size == 0
  101. num_groups = size_k // group_size
  102. # Quantize with zp
  103. w_ref, q_w, s, zp = quantize_weights(w,
  104. quant_type,
  105. group_size,
  106. zero_points=True)
  107. # Reformat to marlin
  108. weight_perm = get_weight_perm(quant_type.size_bits)
  109. marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
  110. weight_perm)
  111. marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
  112. marlin_zp = marlin_zero_points(zp, num_groups, size_n,
  113. quant_type.size_bits)
  114. # Create result
  115. res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]
  116. for i in range(len(res_list)):
  117. res_list[i] = res_list[i].to(w.device)
  118. return res_list