gptq_marlin.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. from typing import Any, Dict, List, Optional
  2. import torch
  3. from loguru import logger
  4. from torch.nn import Parameter
  5. from aphrodite import _custom_ops as ops
  6. from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
  7. from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
  8. from aphrodite.modeling.parameter import (ChannelQuantScaleParameter,
  9. GroupQuantScaleParameter,
  10. PackedAphroditeParameter,
  11. PackedColumnParameter,
  12. RowAphroditeParameter)
  13. from aphrodite.quantization.base_config import QuantizationConfig
  14. from aphrodite.quantization.utils.marlin_utils import (
  15. apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full,
  16. marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
  17. marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
  18. verify_marlin_supported, verify_marlin_supports_shape)
  19. from aphrodite.scalar_type import scalar_types
  20. class GPTQMarlinConfig(QuantizationConfig):
  21. """Config class for GPTQ Marlin"""
  22. # (num_bits, is_sym) -> quant_type
  23. TYPE_MAP = {
  24. (4, True): scalar_types.uint4b8,
  25. (8, True): scalar_types.uint8b128,
  26. }
  27. def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
  28. is_sym: bool, lm_head_quantized: bool) -> None:
  29. if desc_act and group_size == -1:
  30. # In this case, act_order == True is the same as act_order == False
  31. # (since we have only one group per output channel)
  32. desc_act = False
  33. self.pack_factor = 32 // weight_bits # packed into int32
  34. self.group_size = group_size
  35. self.desc_act = desc_act
  36. self.lm_head_quantized = lm_head_quantized
  37. if (weight_bits, is_sym) not in self.TYPE_MAP:
  38. raise ValueError("Unsupported quantization config: "
  39. f"bits={weight_bits}, sym={is_sym}")
  40. self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
  41. # Verify supported on platform.
  42. verify_marlin_supported(quant_type=self.quant_type,
  43. group_size=self.group_size)
  44. def __repr__(self) -> str:
  45. return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
  46. f"group_size={self.group_size}, "
  47. f"desc_act={self.desc_act}, "
  48. f"lm_head_quantized={self.lm_head_quantized})")
  49. @classmethod
  50. def get_name(cls) -> str:
  51. return "gptq_marlin"
  52. @classmethod
  53. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  54. return [torch.half, torch.bfloat16]
  55. @classmethod
  56. def get_min_capability(cls) -> int:
  57. return 80
  58. @classmethod
  59. def get_config_filenames(cls) -> List[str]:
  60. return ["quantize_config.json"]
  61. @classmethod
  62. def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
  63. weight_bits = cls.get_from_keys(config, ["bits"])
  64. group_size = cls.get_from_keys(config, ["group_size"])
  65. desc_act = cls.get_from_keys(config, ["desc_act"])
  66. is_sym = cls.get_from_keys(config, ["sym"])
  67. lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
  68. default=False)
  69. return cls(weight_bits, group_size, desc_act, is_sym,
  70. lm_head_quantized)
  71. @classmethod
  72. def override_quantization_method(cls, hf_quant_cfg,
  73. user_quant) -> Optional[str]:
  74. can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
  75. is_valid_user_quant = (user_quant is None or user_quant == "marlin"
  76. or user_quant == "gptq_marlin")
  77. if can_convert and is_valid_user_quant:
  78. msg = ("The model is convertible to {} during runtime."
  79. " Using {} kernel.".format(cls.get_name(), cls.get_name()))
  80. logger.info(msg)
  81. return cls.get_name()
  82. if can_convert and user_quant == "gptq":
  83. logger.info("Detected that the model can run with gptq_marlin"
  84. ", however you specified quantization=gptq explicitly,"
  85. " so forcing gptq. Use quantization=gptq_marlin for"
  86. " faster inference")
  87. return None
  88. def get_quant_method(self, layer: torch.nn.Module,
  89. prefix: str) -> Optional["GPTQMarlinLinearMethod"]:
  90. if (isinstance(layer, LinearBase) or
  91. (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
  92. return GPTQMarlinLinearMethod(self)
  93. return None
  94. def get_scaled_act_names(self) -> List[str]:
  95. return []
  96. @classmethod
  97. def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
  98. # Extract data from quant config.
  99. quant_method = quant_config.get("quant_method", "").lower()
  100. num_bits = quant_config.get("bits", None)
  101. group_size = quant_config.get("group_size", None)
  102. sym = quant_config.get("sym", None)
  103. desc_act = quant_config.get("desc_act", None)
  104. if quant_method != "gptq":
  105. return False
  106. # If we cannot find the info needed in the config, cannot convert.
  107. if (num_bits is None or group_size is None or sym is None
  108. or desc_act is None):
  109. return False
  110. if (num_bits, sym) not in cls.TYPE_MAP:
  111. return False
  112. return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
  113. group_size=group_size)
  114. class GPTQMarlinLinearMethod(LinearMethodBase):
  115. """Linear method for GPTQ Marlin.
  116. Args:
  117. quant_config: The GPTQ Marlin quantization config.
  118. """
  119. def __init__(self, quant_config: GPTQMarlinConfig) -> None:
  120. self.quant_config = quant_config
  121. def create_weights(
  122. self,
  123. layer: torch.nn.Module,
  124. input_size_per_partition: int,
  125. output_partition_sizes: List[int],
  126. input_size: int,
  127. output_size: int,
  128. params_dtype: torch.dtype,
  129. **extra_weight_attrs,
  130. ) -> None:
  131. del output_size
  132. output_size_per_partition = sum(output_partition_sizes)
  133. is_row_parallel = input_size != input_size_per_partition
  134. weight_loader = extra_weight_attrs.get("weight_loader")
  135. # Normalize group_size
  136. if self.quant_config.group_size != -1:
  137. group_size = self.quant_config.group_size
  138. else:
  139. group_size = input_size
  140. verify_marlin_supports_shape(
  141. output_size_per_partition=output_size_per_partition,
  142. input_size_per_partition=input_size_per_partition,
  143. input_size=input_size,
  144. group_size=group_size)
  145. # Determine sharding
  146. if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
  147. self.quant_config.group_size,
  148. is_row_parallel):
  149. # By setting scale_dim == None, weight_loader will
  150. # repeat the scales on each GPU in TP>1 case.
  151. scales_and_zp_input_dim = None
  152. scales_and_zp_size = input_size // group_size
  153. else:
  154. # By setting scale_dim == 0, weight_loader will
  155. # shard the scales in TP>1 case.
  156. scales_and_zp_input_dim = 0
  157. scales_and_zp_size = input_size_per_partition // group_size
  158. # Quantized weights
  159. qweight = PackedAphroditeParameter(
  160. data=torch.empty(
  161. input_size_per_partition // self.quant_config.pack_factor,
  162. output_size_per_partition,
  163. dtype=torch.int32,
  164. ),
  165. input_dim=0,
  166. output_dim=1,
  167. packed_dim=0,
  168. packed_factor=self.quant_config.pack_factor,
  169. weight_loader=weight_loader)
  170. # Activation order
  171. g_idx = RowAphroditeParameter(data=torch.empty(
  172. input_size_per_partition,
  173. dtype=torch.int32,
  174. ),
  175. input_dim=0,
  176. weight_loader=weight_loader)
  177. qzeros_args = {
  178. "data":
  179. torch.empty(
  180. scales_and_zp_size,
  181. output_size_per_partition // self.quant_config.pack_factor,
  182. dtype=torch.int32,
  183. ),
  184. "weight_loader":
  185. weight_loader
  186. }
  187. weight_scale_args = {
  188. "data":
  189. torch.empty(
  190. scales_and_zp_size,
  191. output_size_per_partition,
  192. dtype=params_dtype,
  193. ),
  194. "weight_loader":
  195. weight_loader
  196. }
  197. if scales_and_zp_input_dim is None:
  198. scales = ChannelQuantScaleParameter(output_dim=1,
  199. **weight_scale_args)
  200. qzeros = PackedColumnParameter(
  201. output_dim=1,
  202. packed_dim=1,
  203. packed_factor=self.quant_config.pack_factor,
  204. **qzeros_args)
  205. else:
  206. scales = GroupQuantScaleParameter(output_dim=1,
  207. input_dim=0,
  208. **weight_scale_args)
  209. qzeros = PackedAphroditeParameter(
  210. input_dim=0,
  211. output_dim=1,
  212. packed_dim=1,
  213. packed_factor=self.quant_config.pack_factor,
  214. **qzeros_args)
  215. layer.register_parameter("qweight", qweight)
  216. layer.register_parameter("g_idx", g_idx)
  217. layer.register_parameter("scales", scales)
  218. layer.register_parameter("qzeros", qzeros)
  219. layer.input_size_per_partition = input_size_per_partition
  220. layer.output_size_per_partition = output_size_per_partition
  221. layer.input_size = input_size
  222. layer.is_k_full = marlin_is_k_full(self.quant_config.desc_act,
  223. is_row_parallel)
  224. # Checkpoints are serialized in AutoGPTQ format, which is different from the
  225. # marlin format. This function is called after the weights are loaded.
  226. # Here, we handle the repacking, including the activation reordering case.
  227. def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
  228. device = layer.qweight.device
  229. # required by torch.compile
  230. layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
  231. layer.scales = Parameter(layer.scales.data, requires_grad=False)
  232. # Allocate marlin workspace
  233. layer.workspace = marlin_make_workspace(
  234. layer.output_size_per_partition, device)
  235. # Handle sorting for activation reordering if needed.
  236. if self.quant_config.desc_act:
  237. g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.g_idx)
  238. layer.g_idx_sort_indices = g_idx_sort_indices
  239. replace_tensor(layer, "g_idx", g_idx)
  240. else:
  241. layer.g_idx = marlin_make_empty_g_idx(device)
  242. layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
  243. # No zero-point
  244. layer.zp = marlin_make_empty_g_idx(device)
  245. # Repack weights from autogptq format to marlin format.
  246. marlin_qweight = ops.gptq_marlin_repack(
  247. layer.qweight,
  248. perm=layer.g_idx_sort_indices,
  249. size_k=layer.input_size_per_partition,
  250. size_n=layer.output_size_per_partition,
  251. num_bits=self.quant_config.quant_type.size_bits)
  252. replace_tensor(layer, "qweight", marlin_qweight)
  253. # Permute scales from autogptq format to marlin format.
  254. marlin_scales = marlin_permute_scales(
  255. layer.scales,
  256. size_k=(layer.input_size if self.quant_config.desc_act else
  257. layer.input_size_per_partition),
  258. size_n=layer.output_size_per_partition,
  259. group_size=self.quant_config.group_size)
  260. replace_tensor(layer, "scales", marlin_scales)
  261. def apply(
  262. self,
  263. layer: torch.nn.Module,
  264. x: torch.Tensor,
  265. bias: Optional[torch.Tensor] = None,
  266. ) -> torch.Tensor:
  267. return apply_gptq_marlin_linear(
  268. input=x,
  269. weight=layer.qweight,
  270. weight_scale=layer.scales,
  271. weight_zp=layer.zp,
  272. g_idx=layer.g_idx,
  273. g_idx_sort_indices=layer.g_idx_sort_indices,
  274. workspace=layer.workspace,
  275. wtype=self.quant_config.quant_type,
  276. output_size_per_partition=layer.output_size_per_partition,
  277. input_size_per_partition=layer.input_size_per_partition,
  278. is_k_full=layer.is_k_full,
  279. bias=bias)