gptq_marlin.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332
  1. from typing import Any, Dict, List, Optional
  2. import torch
  3. from loguru import logger
  4. from torch.nn.parameter import Parameter
  5. from aphrodite import _custom_ops as ops
  6. from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
  7. set_weight_attrs)
  8. from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
  9. from aphrodite.quantization.base_config import QuantizationConfig
  10. from aphrodite.quantization.utils.marlin_utils import (
  11. apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full,
  12. marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
  13. marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
  14. verify_marlin_supported, verify_marlin_supports_shape)
  15. from aphrodite.scalar_type import scalar_types
  16. class GPTQMarlinConfig(QuantizationConfig):
  17. """Config class for GPTQ Marlin"""
  18. # (num_bits, is_sym) -> quant_type
  19. TYPE_MAP = {
  20. (4, True): scalar_types.uint4b8,
  21. (8, True): scalar_types.uint8b128,
  22. }
  23. def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
  24. is_sym: bool, lm_head_quantized: bool) -> None:
  25. if desc_act and group_size == -1:
  26. # In this case, act_order == True is the same as act_order == False
  27. # (since we have only one group per output channel)
  28. desc_act = False
  29. self.pack_factor = 32 // weight_bits # packed into int32
  30. self.group_size = group_size
  31. self.desc_act = desc_act
  32. self.lm_head_quantized = lm_head_quantized
  33. if (weight_bits, is_sym) not in self.TYPE_MAP:
  34. raise ValueError("Unsupported quantization config: "
  35. f"bits={weight_bits}, sym={is_sym}")
  36. self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
  37. # Verify supported on platform.
  38. verify_marlin_supported(quant_type=self.quant_type,
  39. group_size=self.group_size)
  40. def __repr__(self) -> str:
  41. return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
  42. f"group_size={self.group_size}, "
  43. f"desc_act={self.desc_act}, "
  44. f"lm_head_quantized={self.lm_head_quantized})")
  45. @classmethod
  46. def get_name(cls) -> str:
  47. return "gptq_marlin"
  48. @classmethod
  49. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  50. return [torch.half, torch.bfloat16]
  51. @classmethod
  52. def get_min_capability(cls) -> int:
  53. return 80
  54. @classmethod
  55. def get_config_filenames(cls) -> List[str]:
  56. return ["quantize_config.json"]
  57. @classmethod
  58. def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
  59. weight_bits = cls.get_from_keys(config, ["bits"])
  60. group_size = cls.get_from_keys(config, ["group_size"])
  61. desc_act = cls.get_from_keys(config, ["desc_act"])
  62. is_sym = cls.get_from_keys(config, ["sym"])
  63. lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
  64. default=False)
  65. return cls(weight_bits, group_size, desc_act, is_sym,
  66. lm_head_quantized)
  67. @classmethod
  68. def override_quantization_method(cls, hf_quant_cfg,
  69. user_quant) -> Optional[str]:
  70. can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
  71. is_valid_user_quant = (user_quant is None or user_quant == "marlin"
  72. or user_quant == "gptq_marlin")
  73. if can_convert and is_valid_user_quant:
  74. msg = ("The model is convertible to {} during runtime."
  75. " Using {} kernel.".format(cls.get_name(), cls.get_name()))
  76. logger.info(msg)
  77. return cls.get_name()
  78. if can_convert and user_quant == "gptq":
  79. logger.info("Detected that the model can run with gptq_marlin"
  80. ", however you specified quantization=gptq explicitly,"
  81. " so forcing gptq. Use quantization=gptq_marlin for"
  82. " faster inference")
  83. return None
  84. def get_quant_method(self, layer: torch.nn.Module,
  85. prefix: str) -> Optional["GPTQMarlinLinearMethod"]:
  86. if (isinstance(layer, LinearBase) or
  87. (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
  88. return GPTQMarlinLinearMethod(self)
  89. return None
  90. def get_scaled_act_names(self) -> List[str]:
  91. return []
  92. @classmethod
  93. def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
  94. # Extract data from quant config.
  95. quant_method = quant_config.get("quant_method", "").lower()
  96. num_bits = quant_config.get("bits", None)
  97. group_size = quant_config.get("group_size", None)
  98. sym = quant_config.get("sym", None)
  99. desc_act = quant_config.get("desc_act", None)
  100. if quant_method != "gptq":
  101. return False
  102. # If we cannot find the info needed in the config, cannot convert.
  103. if (num_bits is None or group_size is None or sym is None
  104. or desc_act is None):
  105. return False
  106. if (num_bits, sym) not in cls.TYPE_MAP:
  107. return False
  108. return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
  109. group_size=group_size)
  110. class GPTQMarlinLinearMethod(LinearMethodBase):
  111. """Linear method for GPTQ Marlin.
  112. Args:
  113. quant_config: The GPTQ Marlin quantization config.
  114. """
  115. def __init__(self, quant_config: GPTQMarlinConfig) -> None:
  116. self.quant_config = quant_config
  117. def create_weights(
  118. self,
  119. layer: torch.nn.Module,
  120. input_size_per_partition: int,
  121. output_partition_sizes: List[int],
  122. input_size: int,
  123. output_size: int,
  124. params_dtype: torch.dtype,
  125. **extra_weight_attrs,
  126. ) -> None:
  127. del output_size
  128. output_size_per_partition = sum(output_partition_sizes)
  129. is_row_parallel = input_size != input_size_per_partition
  130. # Normalize group_size
  131. if self.quant_config.group_size != -1:
  132. group_size = self.quant_config.group_size
  133. else:
  134. group_size = input_size
  135. verify_marlin_supports_shape(
  136. output_size_per_partition=output_size_per_partition,
  137. input_size_per_partition=input_size_per_partition,
  138. input_size=input_size,
  139. group_size=group_size)
  140. # Determine sharding
  141. if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
  142. self.quant_config.group_size,
  143. is_row_parallel):
  144. # By setting scale_dim == None, weight_loader will
  145. # repeat the scales on each GPU in TP>1 case.
  146. scales_and_zp_input_dim = None
  147. scales_and_zp_size = input_size // group_size
  148. else:
  149. # By setting scale_dim == 0, weight_loader will
  150. # shard the scales in TP>1 case.
  151. scales_and_zp_input_dim = 0
  152. scales_and_zp_size = input_size_per_partition // group_size
  153. # Quantized weights
  154. qweight = Parameter(
  155. torch.empty(
  156. input_size_per_partition // self.quant_config.pack_factor,
  157. output_size_per_partition,
  158. dtype=torch.int32,
  159. ),
  160. requires_grad=False,
  161. )
  162. set_weight_attrs(
  163. qweight,
  164. {
  165. **extra_weight_attrs,
  166. "input_dim": 0,
  167. "output_dim": 1,
  168. "packed_dim": 0,
  169. "pack_factor": self.quant_config.pack_factor,
  170. },
  171. )
  172. # Activation order
  173. g_idx = Parameter(
  174. torch.empty(
  175. input_size_per_partition,
  176. dtype=torch.int32,
  177. ),
  178. requires_grad=False,
  179. )
  180. # Ignore warning from fused linear layers such as QKVParallelLinear.
  181. set_weight_attrs(
  182. g_idx,
  183. {
  184. **extra_weight_attrs, "input_dim": 0,
  185. "ignore_warning": True
  186. },
  187. )
  188. # Scales
  189. scales = Parameter(
  190. torch.empty(
  191. scales_and_zp_size,
  192. output_size_per_partition,
  193. dtype=params_dtype,
  194. ),
  195. requires_grad=False,
  196. )
  197. set_weight_attrs(
  198. scales,
  199. {
  200. **extra_weight_attrs,
  201. "input_dim": scales_and_zp_input_dim,
  202. "output_dim": 1,
  203. },
  204. )
  205. # Quantized zero-points
  206. qzeros = Parameter(
  207. torch.empty(
  208. scales_and_zp_size,
  209. output_size_per_partition // self.quant_config.pack_factor,
  210. dtype=torch.int32,
  211. ),
  212. requires_grad=False,
  213. )
  214. set_weight_attrs(
  215. qzeros,
  216. {
  217. **extra_weight_attrs,
  218. "input_dim": scales_and_zp_input_dim,
  219. "output_dim": 1,
  220. "packed_dim": 1,
  221. "pack_factor": self.quant_config.pack_factor,
  222. },
  223. )
  224. layer.register_parameter("qweight", qweight)
  225. layer.register_parameter("g_idx", g_idx)
  226. layer.register_parameter("scales", scales)
  227. layer.register_parameter("qzeros", qzeros)
  228. layer.input_size_per_partition = input_size_per_partition
  229. layer.output_size_per_partition = output_size_per_partition
  230. layer.input_size = input_size
  231. layer.is_k_full = marlin_is_k_full(self.quant_config.desc_act,
  232. is_row_parallel)
  233. # Checkpoints are serialized in AutoGPTQ format, which is different from the
  234. # marlin format. This function is called after the weights are loaded.
  235. # Here, we handle the repacking, including the activation reordering case.
  236. def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
  237. device = layer.qweight.device
  238. # Allocate marlin workspace
  239. layer.workspace = marlin_make_workspace(
  240. layer.output_size_per_partition, device)
  241. # Handle sorting for activation reordering if needed.
  242. if self.quant_config.desc_act:
  243. g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.g_idx)
  244. layer.g_idx_sort_indices = g_idx_sort_indices
  245. replace_tensor(layer, "g_idx", g_idx)
  246. else:
  247. layer.g_idx = marlin_make_empty_g_idx(device)
  248. layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
  249. # No zero-point
  250. layer.zp = marlin_make_empty_g_idx(device)
  251. # Repack weights from autogptq format to marlin format.
  252. marlin_qweight = ops.gptq_marlin_repack(
  253. layer.qweight,
  254. perm=layer.g_idx_sort_indices,
  255. size_k=layer.input_size_per_partition,
  256. size_n=layer.output_size_per_partition,
  257. num_bits=self.quant_config.quant_type.size_bits)
  258. replace_tensor(layer, "qweight", marlin_qweight)
  259. # Permute scales from autogptq format to marlin format.
  260. marlin_scales = marlin_permute_scales(
  261. layer.scales,
  262. size_k=(layer.input_size if self.quant_config.desc_act else
  263. layer.input_size_per_partition),
  264. size_n=layer.output_size_per_partition,
  265. group_size=self.quant_config.group_size)
  266. replace_tensor(layer, "scales", marlin_scales)
  267. def apply(
  268. self,
  269. layer: torch.nn.Module,
  270. x: torch.Tensor,
  271. bias: Optional[torch.Tensor] = None,
  272. ) -> torch.Tensor:
  273. return apply_gptq_marlin_linear(
  274. input=x,
  275. weight=layer.qweight,
  276. weight_scale=layer.scales,
  277. weight_zp=layer.zp,
  278. g_idx=layer.g_idx,
  279. g_idx_sort_indices=layer.g_idx_sort_indices,
  280. workspace=layer.workspace,
  281. wtype=self.quant_config.quant_type,
  282. output_size_per_partition=layer.output_size_per_partition,
  283. input_size_per_partition=layer.input_size_per_partition,
  284. is_k_full=layer.is_k_full,
  285. bias=bias)