gptq_marlin.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. from typing import Any, Dict, List, Optional
  2. import torch
  3. from loguru import logger
  4. from torch.nn import Parameter
  5. from aphrodite import _custom_ops as ops
  6. from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
  7. from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
  8. from aphrodite.modeling.parameter import (ChannelQuantScaleParameter,
  9. GroupQuantScaleParameter,
  10. PackedAphroditeParameter,
  11. PackedColumnParameter,
  12. RowAphroditeParameter)
  13. from aphrodite.quantization.base_config import QuantizationConfig
  14. from aphrodite.quantization.utils.marlin_utils import (
  15. apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full,
  16. marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
  17. marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
  18. verify_marlin_supported, verify_marlin_supports_shape)
  19. from aphrodite.scalar_type import scalar_types
  20. from aphrodite.common.utils import is_hip
  class GPTQMarlinConfig(QuantizationConfig):
      """Config class for GPTQ Marlin.

      Holds the GPTQ checkpoint parameters (bit width, group size, act-order,
      symmetry, lm_head quantization) and hands out ``GPTQMarlinLinearMethod``
      for the layers it quantizes.
      """

      # (num_bits, is_sym) -> quant_type
      TYPE_MAP = {
          (4, True): scalar_types.uint4b8,
          (8, True): scalar_types.uint8b128,
      }

      def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
                   is_sym: bool, lm_head_quantized: bool) -> None:
          """Validate and store the quantization parameters.

          Args:
              weight_bits: Bits per quantized weight (only the widths listed
                  in TYPE_MAP are accepted).
              group_size: Quantization group size; -1 means a single group
                  per output channel.
              desc_act: Whether activation reordering (act_order) is used.
              is_sym: Whether the quantization is symmetric.
              lm_head_quantized: Whether ParallelLMHead layers are quantized.

          Raises:
              ValueError: If ``(weight_bits, is_sym)`` is not in TYPE_MAP.
                  ``verify_marlin_supported`` may additionally reject the
                  config for the current platform.
          """
          # Validate before doing any arithmetic with weight_bits: an
          # unsupported width such as 0 would otherwise surface as a
          # ZeroDivisionError from the pack_factor computation below.
          if (weight_bits, is_sym) not in self.TYPE_MAP:
              raise ValueError("Unsupported quantization config: "
                               f"bits={weight_bits}, sym={is_sym}")

          if desc_act and group_size == -1:
              # In this case, act_order == True is the same as act_order == False
              # (since we have only one group per output channel)
              desc_act = False

          self.pack_factor = 32 // weight_bits  # packed into int32
          self.group_size = group_size
          self.desc_act = desc_act
          self.lm_head_quantized = lm_head_quantized
          self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]

          # Verify supported on platform.
          verify_marlin_supported(quant_type=self.quant_type,
                                  group_size=self.group_size)

      def __repr__(self) -> str:
          return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
                  f"group_size={self.group_size}, "
                  f"desc_act={self.desc_act}, "
                  f"lm_head_quantized={self.lm_head_quantized})")

      @classmethod
      def get_name(cls) -> str:
          """Return the registry name of this quantization method."""
          return "gptq_marlin"

      @classmethod
      def get_supported_act_dtypes(cls) -> List[torch.dtype]:
          """Return the activation dtypes this method accepts."""
          return [torch.half, torch.bfloat16]

      @classmethod
      def get_min_capability(cls) -> int:
          """Return the minimum device capability (80) this method requires."""
          return 80

      @classmethod
      def get_config_filenames(cls) -> List[str]:
          """Return the checkpoint files the config is read from."""
          return ["quantize_config.json"]

      @classmethod
      def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
          """Build a config from a parsed ``quantize_config.json`` dict.

          Reads ``bits``, ``group_size``, ``desc_act`` and ``sym`` via
          ``get_from_keys``; ``lm_head`` is optional and defaults to False.
          """
          weight_bits = cls.get_from_keys(config, ["bits"])
          group_size = cls.get_from_keys(config, ["group_size"])
          desc_act = cls.get_from_keys(config, ["desc_act"])
          is_sym = cls.get_from_keys(config, ["sym"])
          lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                   default=False)
          return cls(weight_bits, group_size, desc_act, is_sym,
                     lm_head_quantized)

      @classmethod
      def override_quantization_method(cls, hf_quant_cfg,
                                       user_quant) -> Optional[str]:
          """Decide whether a GPTQ checkpoint should run with this method.

          Returns ``"gptq_marlin"`` when the HF quant config is convertible
          and the user asked for nothing, ``"marlin"`` or ``"gptq_marlin"``.
          Returns None otherwise, and always on HIP builds.
          """
          # The override is disabled on HIP; return before inspecting the
          # checkpoint config, whose result would be discarded anyway.
          if is_hip():
              return None

          can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
          is_valid_user_quant = (user_quant is None or user_quant == "marlin"
                                 or user_quant == "gptq_marlin")

          if can_convert and is_valid_user_quant:
              msg = ("The model is convertible to {} during runtime."
                     " Using {} kernel.".format(cls.get_name(), cls.get_name()))
              logger.info(msg)
              return cls.get_name()

          if can_convert and user_quant == "gptq":
              logger.info("Detected that the model can run with gptq_marlin"
                          ", however you specified quantization=gptq explicitly,"
                          " so forcing gptq. Use quantization=gptq_marlin for"
                          " faster inference")
          return None

      def get_quant_method(self, layer: torch.nn.Module,
                           prefix: str) -> Optional["GPTQMarlinLinearMethod"]:
          """Return a GPTQMarlinLinearMethod for quantizable layers.

          Linear layers are always handled; ParallelLMHead only when
          ``lm_head_quantized`` is set. Anything else gets None.
          """
          if (isinstance(layer, LinearBase) or
                  (isinstance(layer, ParallelLMHead)
                   and self.lm_head_quantized)):
              return GPTQMarlinLinearMethod(self)
          return None

      def get_scaled_act_names(self) -> List[str]:
          """Return activation names needing scaling (none for this method)."""
          return []

      @classmethod
      def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
          """Return whether an HF quant config can run on GPTQ Marlin.

          The config must declare ``quant_method == "gptq"`` (any case) and
          provide ``bits``, ``group_size``, ``sym`` and ``desc_act``. The
          ``(bits, sym)`` pair must be in TYPE_MAP, and
          ``check_marlin_supported`` must accept it.
          """
          # Extract data from quant config. A missing, None or non-string
          # quant_method is treated as "not GPTQ" rather than crashing on
          # .lower().
          quant_method = quant_config.get("quant_method")
          if not isinstance(quant_method, str) or quant_method.lower() != "gptq":
              return False

          num_bits = quant_config.get("bits", None)
          group_size = quant_config.get("group_size", None)
          sym = quant_config.get("sym", None)
          desc_act = quant_config.get("desc_act", None)

          # If we cannot find the info needed in the config, cannot convert.
          if (num_bits is None or group_size is None or sym is None
                  or desc_act is None):
              return False

          if (num_bits, sym) not in cls.TYPE_MAP:
              return False

          return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
                                        group_size=group_size)
  class GPTQMarlinLinearMethod(LinearMethodBase):
      """Linear method that executes GPTQ-quantized layers with Marlin kernels.

      Args:
          quant_config: The GPTQ Marlin quantization config.
      """

      def __init__(self, quant_config: GPTQMarlinConfig) -> None:
          self.quant_config = quant_config

      def create_weights(
          self,
          layer: torch.nn.Module,
          input_size_per_partition: int,
          output_partition_sizes: List[int],
          input_size: int,
          output_size: int,
          params_dtype: torch.dtype,
          **extra_weight_attrs,
      ) -> None:
          """Allocate and register the GPTQ tensors on ``layer``.

          Registers ``qweight``, ``g_idx``, ``scales`` and ``qzeros`` (in that
          order). Also records the partition sizes and ``is_k_full`` on the
          layer for later use.

          Args:
              layer: Module that receives the parameters.
              input_size_per_partition: Input features held by this rank.
              output_partition_sizes: Output sizes of the fused partitions;
                  their sum is this rank's output width.
              input_size: Total input features across ranks.
              output_size: Unused.
              params_dtype: Dtype of the ``scales`` tensor.
              **extra_weight_attrs: Only ``weight_loader`` is read.
          """
          del output_size  # per-partition sizes are used instead
          cfg = self.quant_config
          pack = cfg.pack_factor
          loader = extra_weight_attrs.get("weight_loader")
          n_out = sum(output_partition_sizes)
          row_parallel = input_size_per_partition != input_size

          # group_size == -1 means one group spanning the whole input.
          effective_group = (input_size
                             if cfg.group_size == -1 else cfg.group_size)

          verify_marlin_supports_shape(
              output_size_per_partition=n_out,
              input_size_per_partition=input_size_per_partition,
              input_size=input_size,
              group_size=effective_group)

          # When scales are repeated on every rank they cover the full input
          # dimension and the loader does not shard them; otherwise each rank
          # holds only its own slice along dim 0.
          repeat_scales = marlin_repeat_scales_on_all_ranks(
              cfg.desc_act, cfg.group_size, row_parallel)
          rows_source = input_size if repeat_scales else input_size_per_partition
          scale_rows = rows_source // effective_group

          # Quantized weights: several values packed per int32 along dim 0.
          qweight = PackedAphroditeParameter(
              data=torch.empty(input_size_per_partition // pack,
                               n_out,
                               dtype=torch.int32),
              input_dim=0,
              output_dim=1,
              packed_dim=0,
              packed_factor=pack,
              weight_loader=loader)

          # Activation order (group index per input row).
          g_idx = RowAphroditeParameter(
              data=torch.empty(input_size_per_partition, dtype=torch.int32),
              input_dim=0,
              weight_loader=loader)

          zeros_buf = torch.empty(scale_rows, n_out // pack, dtype=torch.int32)
          scales_buf = torch.empty(scale_rows, n_out, dtype=params_dtype)

          if repeat_scales:
              scales = ChannelQuantScaleParameter(data=scales_buf,
                                                  output_dim=1,
                                                  weight_loader=loader)
              qzeros = PackedColumnParameter(data=zeros_buf,
                                             output_dim=1,
                                             packed_dim=1,
                                             packed_factor=pack,
                                             weight_loader=loader)
          else:
              scales = GroupQuantScaleParameter(data=scales_buf,
                                                output_dim=1,
                                                input_dim=0,
                                                weight_loader=loader)
              qzeros = PackedAphroditeParameter(data=zeros_buf,
                                                input_dim=0,
                                                output_dim=1,
                                                packed_dim=1,
                                                packed_factor=pack,
                                                weight_loader=loader)

          for param_name, param in (("qweight", qweight), ("g_idx", g_idx),
                                    ("scales", scales), ("qzeros", qzeros)):
              layer.register_parameter(param_name, param)

          layer.input_size_per_partition = input_size_per_partition
          layer.output_size_per_partition = n_out
          layer.input_size = input_size
          layer.is_k_full = marlin_is_k_full(cfg.desc_act, row_parallel)

      def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
          """Convert loaded AutoGPTQ-format tensors into Marlin format.

          Checkpoints are serialized in AutoGPTQ format, which differs from
          the Marlin layout. This runs after the weights are loaded. It:
            1. allocates the Marlin workspace;
            2. builds g_idx / sort indices for act-order;
            3. repacks ``qweight``;
            4. permutes ``scales``.
          """
          cfg = self.quant_config
          dev = layer.qweight.device

          # Re-wrap as plain non-trainable Parameters (required by
          # torch.compile).
          for attr in ("qweight", "scales"):
              setattr(layer, attr,
                      Parameter(getattr(layer, attr).data, requires_grad=False))

          layer.workspace = marlin_make_workspace(
              layer.output_size_per_partition, dev)

          if not cfg.desc_act:
              # No activation reordering: empty index tensors.
              layer.g_idx = marlin_make_empty_g_idx(dev)
              layer.g_idx_sort_indices = marlin_make_empty_g_idx(dev)
          else:
              sorted_g_idx, sort_perm = marlin_sort_g_idx(layer.g_idx)
              layer.g_idx_sort_indices = sort_perm
              replace_tensor(layer, "g_idx", sorted_g_idx)

          # No zero-point.
          layer.zp = marlin_make_empty_g_idx(dev)

          # The repack uses the sort permutation set just above, so this must
          # come after the act-order handling.
          repacked = ops.gptq_marlin_repack(
              layer.qweight,
              perm=layer.g_idx_sort_indices,
              size_k=layer.input_size_per_partition,
              size_n=layer.output_size_per_partition,
              num_bits=cfg.quant_type.size_bits)
          replace_tensor(layer, "qweight", repacked)

          # With act-order the scales span the full input, so use input_size.
          scale_k = (layer.input_size
                     if cfg.desc_act else layer.input_size_per_partition)
          permuted = marlin_permute_scales(
              layer.scales,
              size_k=scale_k,
              size_n=layer.output_size_per_partition,
              group_size=cfg.group_size)
          replace_tensor(layer, "scales", permuted)

      def apply(
          self,
          layer: torch.nn.Module,
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None,
      ) -> torch.Tensor:
          """Run ``apply_gptq_marlin_linear`` on ``x`` with the layer's tensors."""
          return apply_gptq_marlin_linear(
              input=x,
              bias=bias,
              weight=layer.qweight,
              weight_scale=layer.scales,
              weight_zp=layer.zp,
              wtype=self.quant_config.quant_type,
              g_idx=layer.g_idx,
              g_idx_sort_indices=layer.g_idx_sort_indices,
              workspace=layer.workspace,
              input_size_per_partition=layer.input_size_per_partition,
              output_size_per_partition=layer.output_size_per_partition,
              is_k_full=layer.is_k_full)