gptq_marlin.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. from typing import Any, Dict, List, Optional
  2. import torch
  3. from loguru import logger
  4. from torch.nn.parameter import Parameter
  5. from aphrodite import _custom_ops as ops
  6. from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
  7. set_weight_attrs)
  8. from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
  9. from aphrodite.quantization.base_config import QuantizationConfig
  10. from aphrodite.quantization.utils.marlin_utils import (
  11. apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full,
  12. marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
  13. marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
  14. verify_marlin_supported, verify_marlin_supports_shape)
  15. from aphrodite.scalar_type import scalar_types
  16. class GPTQMarlinConfig(QuantizationConfig):
  17. """Config class for GPTQ Marlin"""
  18. # (num_bits, is_sym) -> quant_type
  19. TYPE_MAP = {
  20. (4, True): scalar_types.uint4b8,
  21. (8, True): scalar_types.uint8b128,
  22. }
  23. def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
  24. is_sym: bool, lm_head_quantized: bool) -> None:
  25. if desc_act and group_size == -1:
  26. # In this case, act_order == True is the same as act_order == False
  27. # (since we have only one group per output channel)
  28. desc_act = False
  29. self.pack_factor = 32 // weight_bits # packed into int32
  30. self.group_size = group_size
  31. self.desc_act = desc_act
  32. self.lm_head_quantized = lm_head_quantized
  33. if (weight_bits, is_sym) not in self.TYPE_MAP:
  34. raise ValueError("Unsupported quantization config: "
  35. f"bits={weight_bits}, sym={is_sym}")
  36. self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
  37. # Verify supported on platform.
  38. verify_marlin_supported(quant_type=self.quant_type,
  39. group_size=self.group_size)
  40. def __repr__(self) -> str:
  41. return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
  42. f"group_size={self.group_size}, "
  43. f"desc_act={self.desc_act}, "
  44. f"lm_head_quantized={self.lm_head_quantized})")
  45. @classmethod
  46. def get_name(cls) -> str:
  47. return "gptq_marlin"
  48. @classmethod
  49. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  50. return [torch.half, torch.bfloat16]
  51. @classmethod
  52. def get_min_capability(cls) -> int:
  53. return 80
  54. @classmethod
  55. def get_config_filenames(cls) -> List[str]:
  56. return ["quantize_config.json"]
  57. @classmethod
  58. def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
  59. weight_bits = cls.get_from_keys(config, ["bits"])
  60. group_size = cls.get_from_keys(config, ["group_size"])
  61. desc_act = cls.get_from_keys(config, ["desc_act"])
  62. is_sym = cls.get_from_keys(config, ["sym"])
  63. lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
  64. default=False)
  65. return cls(weight_bits, group_size, desc_act, is_sym,
  66. lm_head_quantized)
  67. @classmethod
  68. def override_quantization_method(cls, hf_quant_cfg,
  69. user_quant) -> Optional[str]:
  70. can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
  71. is_valid_user_quant = (user_quant is None or user_quant == "marlin"
  72. or user_quant == "gptq_marlin")
  73. if can_convert and is_valid_user_quant:
  74. msg = ("The model is convertible to {} during runtime."
  75. " Using {} kernel.".format(cls.get_name(), cls.get_name()))
  76. logger.info(msg)
  77. return cls.get_name()
  78. if can_convert and user_quant == "gptq":
  79. logger.info("Detected that the model can run with gptq_marlin"
  80. ", however you specified quantization=gptq explicitly,"
  81. " so forcing gptq. Use quantization=gptq_marlin for"
  82. " faster inference")
  83. return None
  84. def get_quant_method(self, layer: torch.nn.Module,
  85. prefix: str) -> Optional["GPTQMarlinLinearMethod"]:
  86. if (isinstance(layer, LinearBase) or
  87. (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
  88. return GPTQMarlinLinearMethod(self)
  89. return None
  90. def get_scaled_act_names(self) -> List[str]:
  91. return []
  92. @classmethod
  93. def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
  94. # Extract data from quant config.
  95. quant_method = quant_config.get("quant_method", "").lower()
  96. num_bits = quant_config.get("bits", None)
  97. group_size = quant_config.get("group_size", None)
  98. sym = quant_config.get("sym", None)
  99. desc_act = quant_config.get("desc_act", None)
  100. if quant_method != "gptq":
  101. return False
  102. # If we cannot find the info needed in the config, cannot convert.
  103. if (num_bits is None or group_size is None or sym is None
  104. or desc_act is None):
  105. return False
  106. if (num_bits, sym) not in cls.TYPE_MAP:
  107. return False
  108. return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
  109. group_size=group_size,
  110. min_capability=cls.get_min_capability())
  111. class GPTQMarlinLinearMethod(LinearMethodBase):
  112. """Linear method for GPTQ Marlin.
  113. Args:
  114. quant_config: The GPTQ Marlin quantization config.
  115. """
  116. def __init__(self, quant_config: GPTQMarlinConfig) -> None:
  117. self.quant_config = quant_config
  118. def create_weights(
  119. self,
  120. layer: torch.nn.Module,
  121. input_size_per_partition: int,
  122. output_partition_sizes: List[int],
  123. input_size: int,
  124. output_size: int,
  125. params_dtype: torch.dtype,
  126. **extra_weight_attrs,
  127. ) -> None:
  128. del output_size
  129. output_size_per_partition = sum(output_partition_sizes)
  130. is_row_parallel = input_size != input_size_per_partition
  131. # Normalize group_size
  132. if self.quant_config.group_size != -1:
  133. group_size = self.quant_config.group_size
  134. else:
  135. group_size = input_size
  136. verify_marlin_supports_shape(
  137. output_size_per_partition=output_size_per_partition,
  138. input_size_per_partition=input_size_per_partition,
  139. input_size=input_size,
  140. group_size=group_size)
  141. # Determine sharding
  142. if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
  143. self.quant_config.group_size,
  144. is_row_parallel):
  145. # By setting scale_dim == None, weight_loader will
  146. # repeat the scales on each GPU in TP>1 case.
  147. scales_and_zp_input_dim = None
  148. scales_and_zp_size = input_size // group_size
  149. else:
  150. # By setting scale_dim == 0, weight_loader will
  151. # shard the scales in TP>1 case.
  152. scales_and_zp_input_dim = 0
  153. scales_and_zp_size = input_size_per_partition // group_size
  154. # Quantized weights
  155. qweight = Parameter(
  156. torch.empty(
  157. input_size_per_partition // self.quant_config.pack_factor,
  158. output_size_per_partition,
  159. dtype=torch.int32,
  160. ),
  161. requires_grad=False,
  162. )
  163. set_weight_attrs(
  164. qweight,
  165. {
  166. **extra_weight_attrs,
  167. "input_dim": 0,
  168. "output_dim": 1,
  169. "packed_dim": 0,
  170. "pack_factor": self.quant_config.pack_factor,
  171. },
  172. )
  173. # Activation order
  174. g_idx = Parameter(
  175. torch.empty(
  176. input_size_per_partition,
  177. dtype=torch.int32,
  178. ),
  179. requires_grad=False,
  180. )
  181. # Ignore warning from fused linear layers such as QKVParallelLinear.
  182. set_weight_attrs(
  183. g_idx,
  184. {
  185. **extra_weight_attrs, "input_dim": 0,
  186. "ignore_warning": True
  187. },
  188. )
  189. # Scales
  190. scales = Parameter(
  191. torch.empty(
  192. scales_and_zp_size,
  193. output_size_per_partition,
  194. dtype=params_dtype,
  195. ),
  196. requires_grad=False,
  197. )
  198. set_weight_attrs(
  199. scales,
  200. {
  201. **extra_weight_attrs,
  202. "input_dim": scales_and_zp_input_dim,
  203. "output_dim": 1,
  204. },
  205. )
  206. # Quantized zero-points
  207. qzeros = Parameter(
  208. torch.empty(
  209. scales_and_zp_size,
  210. output_size_per_partition // self.quant_config.pack_factor,
  211. dtype=torch.int32,
  212. ),
  213. requires_grad=False,
  214. )
  215. set_weight_attrs(
  216. qzeros,
  217. {
  218. **extra_weight_attrs,
  219. "input_dim": scales_and_zp_input_dim,
  220. "output_dim": 1,
  221. "packed_dim": 1,
  222. "pack_factor": self.quant_config.pack_factor,
  223. },
  224. )
  225. layer.register_parameter("qweight", qweight)
  226. layer.register_parameter("g_idx", g_idx)
  227. layer.register_parameter("scales", scales)
  228. layer.register_parameter("qzeros", qzeros)
  229. layer.input_size_per_partition = input_size_per_partition
  230. layer.output_size_per_partition = output_size_per_partition
  231. layer.input_size = input_size
  232. layer.is_k_full = marlin_is_k_full(self.quant_config.desc_act,
  233. is_row_parallel)
  234. # Checkpoints are serialized in AutoGPTQ format, which is different from the
  235. # marlin format. This function is called after the weights are loaded.
  236. # Here, we handle the repacking, including the activation reordering case.
  237. def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
  238. device = layer.qweight.device
  239. # Allocate marlin workspace
  240. layer.workspace = marlin_make_workspace(
  241. layer.output_size_per_partition, device)
  242. # Handle sorting for activation reordering if needed.
  243. if self.quant_config.desc_act:
  244. g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.g_idx)
  245. layer.g_idx_sort_indices = g_idx_sort_indices
  246. replace_tensor(layer, "g_idx", g_idx)
  247. else:
  248. layer.g_idx = marlin_make_empty_g_idx(device)
  249. layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
  250. # No zero-point
  251. layer.zp = marlin_make_empty_g_idx(device)
  252. # Repack weights from autogptq format to marlin format.
  253. marlin_qweight = ops.gptq_marlin_repack(
  254. layer.qweight,
  255. perm=layer.g_idx_sort_indices,
  256. size_k=layer.input_size_per_partition,
  257. size_n=layer.output_size_per_partition,
  258. num_bits=self.quant_config.quant_type.size_bits)
  259. replace_tensor(layer, "qweight", marlin_qweight)
  260. # Permute scales from autogptq format to marlin format.
  261. marlin_scales = marlin_permute_scales(
  262. layer.scales,
  263. size_k=(layer.input_size if self.quant_config.desc_act else
  264. layer.input_size_per_partition),
  265. size_n=layer.output_size_per_partition,
  266. group_size=self.quant_config.group_size)
  267. replace_tensor(layer, "scales", marlin_scales)
  268. def apply(
  269. self,
  270. layer: torch.nn.Module,
  271. x: torch.Tensor,
  272. bias: Optional[torch.Tensor] = None,
  273. ) -> torch.Tensor:
  274. return apply_gptq_marlin_linear(
  275. input=x,
  276. weight=layer.qweight,
  277. weight_scale=layer.scales,
  278. weight_zp=layer.zp,
  279. g_idx=layer.g_idx,
  280. g_idx_sort_indices=layer.g_idx_sort_indices,
  281. workspace=layer.workspace,
  282. wtype=self.quant_config.quant_type,
  283. output_size_per_partition=layer.output_size_per_partition,
  284. input_size_per_partition=layer.input_size_per_partition,
  285. is_k_full=layer.is_k_full,
  286. bias=bias)