hqq_marlin.py

from typing import Any, Dict, List, Optional, Tuple

import torch

from aphrodite import _custom_ops as ops
from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.modeling.parameter import (BaseAphroditeParameter,
                                          HQQQweightParameter,
                                          HQQZeroScaleParameter)
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.quantization.utils.marlin_utils import (
    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
    marlin_make_empty_g_idx, marlin_permute_scales)
from aphrodite.quantization.utils.marlin_utils_test import MarlinWorkspace
from aphrodite.quantization.utils.quant_utils import gptq_pack
from aphrodite.scalar_type import scalar_types


class HQQMarlinConfig(QuantizationConfig):
    """Config class for HQQ Marlin"""

    # num_bits -> quant_type
    TYPE_MAP = {
        4: scalar_types.uint4,
        8: scalar_types.uint8,
    }

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
    ) -> None:
        self.pack_factor = 8 // weight_bits  # packed into uint8
        self.group_size = group_size
        self.quant_type = self.TYPE_MAP[weight_bits]
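        # With 4-bit weights, pack_factor is 2: two quantized values share
        # one uint8, so the packed weight has half as many rows along the
        # packed dimension. With 8-bit weights, pack_factor is 1.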

    def __repr__(self) -> str:
        return (f"HQQMarlinConfig(quant_type={self.quant_type}, "
                f"group_size={self.group_size})")

    @classmethod
    def get_name(cls) -> str:
        return "hqq"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "HQQMarlinConfig":
        wq_params = config["quant_config"]["weight_quant_params"]
        weight_bits = cls.get_from_keys(wq_params, ["nbits"])
        group_size = cls.get_from_keys(wq_params, ["group_size"])
        return cls(weight_bits, group_size)
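
    # from_config reads the nested HQQ config dict; an illustrative fragment
    # of the expected shape (values here are examples, not defaults):
    #   {"quant_config": {"weight_quant_params": {"nbits": 4,
    #                                             "group_size": 64}}}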

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        # TODO
        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["HQQMarlinMethod"]:
        if isinstance(layer, LinearBase):
            return HQQMarlinMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


# Empty HQQ parameter, will be ignored during loading
class HQQEmptyParameter(BaseAphroditeParameter):

    def load_merged_column_weight(self, loaded_weight: torch.Tensor,
                                  **kwargs):
        pass

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        pass

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        pass


def error_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    raise ValueError("No loader provided for HQQ parameter!")


class HQQMarlinMethod(LinearMethodBase):
    """Linear method for HQQ Marlin."""

    def __init__(
        self,
        quant_config: HQQMarlinConfig,
    ):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        self.output_size_per_partition = sum(output_partition_sizes)
        self.input_size_per_partition = input_size_per_partition

        weight_loader = extra_weight_attrs.get("weight_loader", error_loader)

        self.scales_and_zp_size = (input_size_per_partition //
                                   self.quant_config.group_size)

        # Quantized weights
        qweight = HQQQweightParameter(
            data=torch.empty(
                self.output_size_per_partition //
                self.quant_config.pack_factor,
                input_size_per_partition,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            packed_dim=0,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader)
        set_weight_attrs(qweight, {
            "is_hqq_weight": True,
            "shard_offsets": [],
        })
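
        # W_q is packed along the output dimension (packed_dim=0), so each
        # shard occupies output_size // pack_factor rows. "shard_offsets" is
        # expected to be filled in during loading with per-shard
        # (offset, size) pairs so the tensor can be unpacked shard by shard
        # in process_weights_after_loading.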

        zeros = HQQZeroScaleParameter(data=torch.empty(
            self.output_size_per_partition,
            self.scales_and_zp_size,
            dtype=params_dtype,
        ),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)

        scales = HQQZeroScaleParameter(data=torch.empty(
            self.output_size_per_partition,
            self.scales_and_zp_size,
            dtype=params_dtype,
        ),
                                       input_dim=1,
                                       output_dim=0,
                                       weight_loader=weight_loader)

        layer.register_parameter("W_q", qweight)
        layer.register_parameter("zero", zeros)
        layer.register_parameter("scale", scales)
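
        # The parameter names ("W_q", "zero", "scale") are presumably chosen
        # to match the key names used in HQQ checkpoints, so the serialized
        # tensors load into them directly.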

        # Ignore extra parameters in the HQQ model.
        # To be added as needed.
        ignore_parameters = ("axis", "channel_wise", "compute_dtype",
                             "encoded_state_dict", "group_size", "nbits",
                             "offload_meta", "optimize", "packing",
                             "quant_scale", "quant_zero", "round_zero",
                             "shape", "stores_quant_config",
                             "unpack_view_dtype", "view_as_float")
        for name in ignore_parameters:
            layer.register_parameter(
                name,
                HQQEmptyParameter(data=torch.empty(0),
                                  weight_loader=weight_loader))

    # Unpack weights from the HQQ format and repack them to GPTQ -> Marlin
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        dev = layer.W_q.device
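
        # The conversion runs in three steps:
        #   1. unpack W_q from two 4-bit values per uint8 to one value per
        #      uint8, shard by shard,
        #   2. repack into the GPTQ int32 layout with gptq_pack,
        #   3. repack into the Marlin tile layout with gptq_marlin_repack,
        #      permuting scales and zero points to match.
        # g_idx is left empty since HQQ applies no act_order-style input
        # channel reordering.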

        # unpack function from https://github.com/mobiusml/hqq
        def unpack_4bit_u8(
            W_q: torch.Tensor,
            shard_offsets: List[Tuple[int, int]],
        ) -> torch.Tensor:  # uint8/2 -> uint8
            dtype = torch.uint8
            tmp = torch.empty([2 * W_q.shape[0], W_q.shape[1]],
                              dtype=dtype,
                              device=W_q.device)
            for (offset, size) in shard_offsets:
                tmp_offset = 2 * offset
                tmp[tmp_offset:tmp_offset +
                    size] = (W_q[offset:offset + size] & 0b11110000) >> 4
                tmp[tmp_offset + size:tmp_offset +
                    2 * size] = (W_q[offset:offset + size] & 0b00001111)
            return tmp
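
        # For each shard, the high nibbles of its packed rows land in the
        # first half of the unpacked rows and the low nibbles in the second
        # half, e.g. a packed byte 0b1011_0010 yields 11 (high) and 2 (low).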

        # Unpack from 4-bit to 8-bit
        shard_offsets = getattr(layer.W_q, "shard_offsets", [])
        qweight_t = unpack_4bit_u8(layer.W_q, shard_offsets).transpose(1, 0)

        # Repack to GPTQ
        gptq_w_q = gptq_pack(qweight_t, 4, self.input_size_per_partition,
                             self.output_size_per_partition)

        # Repack to Marlin
        sort_indices = torch.empty(0, dtype=torch.int, device=gptq_w_q.device)
        marlin_w_q = ops.gptq_marlin_repack(
            gptq_w_q,
            sort_indices,
            self.input_size_per_partition,
            self.output_size_per_partition,
            4,
        ).to(dev)
        marlin_s = marlin_permute_scales(layer.scale.transpose(1, 0),
                                         self.input_size_per_partition,
                                         self.output_size_per_partition,
                                         self.quant_config.group_size).to(dev)
        marlin_zp = marlin_permute_scales(layer.zero.transpose(1, 0),
                                          self.input_size_per_partition,
                                          self.output_size_per_partition,
                                          self.quant_config.group_size).to(dev)

        layer.g_idx = marlin_make_empty_g_idx(dev)
        layer.g_idx_sort_indices = marlin_make_empty_g_idx(dev)

        layer.marlin_qweight = marlin_w_q
        layer.marlin_zeros = marlin_zp
        layer.marlin_scales = marlin_s

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        workspace = MarlinWorkspace(self.output_size_per_partition,
                                    GPTQ_MARLIN_MIN_THREAD_N,
                                    GPTQ_MARLIN_MAX_PARALLEL)

        scales = layer.marlin_scales
        zeros = layer.marlin_zeros
        orig_type = x.dtype
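
        # This path runs the Marlin kernel in float16: non-fp16 activations,
        # scales and zero points are cast down before the GEMM and the
        # output is cast back to the original dtype at the end.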
        if orig_type != torch.float16:
            x = x.to(torch.float16)
            scales = scales.to(torch.float16)
            zeros = zeros.to(torch.float16)

        marlin_out = ops.gptq_marlin_gemm(
            x,
            layer.marlin_qweight,
            scales,
            zeros,
            layer.g_idx,
            layer.g_idx_sort_indices,
            workspace.scratch,
            scalar_types.uint4,
            x.shape[0],
            self.output_size_per_partition,
            self.input_size_per_partition,
            True,  # is_k_full
            True,  # has_zp
            False,  # use 32-bit reduce
            True,  # use float zp
        )

        if bias is not None:
            marlin_out.add_(bias)

        if orig_type != torch.float16:
            return marlin_out.to(orig_type)
        else:
            return marlin_out