gptq_marlin.py

from typing import Any, Dict, List, Optional

import torch
from loguru import logger
from torch.nn.parameter import Parameter

from aphrodite import _custom_ops as ops
from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
                                              set_weight_attrs)
from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.quantization.utils.marlin_utils import (
    check_marlin_supported, marlin_make_empty_g_idx, marlin_make_workspace,
    marlin_permute_scales, marlin_sort_g_idx, replace_tensor,
    verify_marlin_supported, verify_marlin_supports_shape)


class GPTQMarlinConfig(QuantizationConfig):
    """Config class for GPTQ Marlin"""

    def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
                 is_sym: bool, lm_head_quantized: bool) -> None:
        if desc_act and group_size == -1:
            # In this case, act_order == True is the same as act_order == False
            # (since we have only one group per output channel).
            desc_act = False

        self.weight_bits = weight_bits
        self.pack_factor = 32 // self.weight_bits  # packed into int32
        self.group_size = group_size
        self.desc_act = desc_act
        self.is_sym = is_sym
        self.lm_head_quantized = lm_head_quantized

        # Verify supported on platform.
        verify_marlin_supported(num_bits=self.weight_bits,
                                group_size=self.group_size,
                                is_sym=self.is_sym)

    def __repr__(self) -> str:
        return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act}, "
                f"lm_head_quantized={self.lm_head_quantized})")

    @classmethod
    def get_name(cls) -> str:
        return "gptq_marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        is_sym = cls.get_from_keys(config, ["sym"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                 default=False)
        return cls(weight_bits, group_size, desc_act, is_sym,
                   lm_head_quantized)
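
    # Illustrative only (the concrete values below are hypothetical): an
    # AutoGPTQ ``quantize_config.json`` supplies the keys read above, e.g.
    #
    #   cfg = GPTQMarlinConfig.from_config({
    #       "bits": 4,
    #       "group_size": 128,
    #       "desc_act": False,
    #       "sym": True,
    #   })
    #   # -> weight_bits=4, pack_factor = 32 // 4 = 8 (eight 4-bit values
    #   #    packed per int32); lm_head_quantized defaults to False.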

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        can_convert = cls.is_marlin_compatible(hf_quant_cfg)
        is_valid_user_quant = (user_quant is None or user_quant == "marlin")

        if can_convert and is_valid_user_quant:
            msg = ("The model is convertible to {} during runtime."
                   " Using {} kernel.".format(cls.get_name(), cls.get_name()))
            logger.info(msg)
            return cls.get_name()

        if can_convert and user_quant == "gptq":
            logger.info("Detected that the model can run with gptq_marlin"
                        ", however you specified quantization=gptq explicitly,"
                        " so forcing gptq. Use quantization=gptq_marlin for"
                        " faster inference")
        return None

    def get_quant_method(
            self,
            layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
        if (isinstance(layer, LinearBase) or
                (isinstance(layer, ParallelLMHead)
                 and self.lm_head_quantized)):
            return GPTQMarlinLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []

    @classmethod
    def is_marlin_compatible(cls, quant_config: Dict[str, Any]):
        # Extract data from quant config.
        num_bits = quant_config.get("bits", None)
        group_size = quant_config.get("group_size", None)
        sym = quant_config.get("sym", None)
        desc_act = quant_config.get("desc_act", None)

        # If we cannot find the info needed in the config, cannot convert.
        if (num_bits is None or group_size is None or sym is None
                or desc_act is None):
            return False

        return check_marlin_supported(num_bits=num_bits,
                                      group_size=group_size,
                                      is_sym=sym,
                                      min_capability=cls.get_min_capability())


class GPTQMarlinLinearMethod(LinearMethodBase):
    """Linear method for GPTQ Marlin.

    Args:
        quant_config: The GPTQ Marlin quantization config.
    """

    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        del output_size
        output_size_per_partition = sum(output_partition_sizes)

        # Normalize group_size.
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        verify_marlin_supports_shape(
            output_size_per_partition=output_size_per_partition,
            input_size_per_partition=input_size_per_partition,
            input_size=input_size,
            group_size=group_size)

        # Detect sharding of scales/zp.
        # By default, no sharding over the "input dim".
        scales_and_zp_size = input_size // group_size
        scales_and_zp_input_dim = None

        if self.quant_config.desc_act:
            # Act-order case.
            assert self.quant_config.group_size != -1
            is_k_full = input_size_per_partition == input_size
        else:
            # No act-order case: K is always full due to full alignment with
            # group_size and the shard of scales/zp.
            is_k_full = True

            # If this is a row-parallel case, then shard scales/zp.
            if (input_size != input_size_per_partition
                    and self.quant_config.group_size != -1):
                scales_and_zp_size = input_size_per_partition // group_size
                scales_and_zp_input_dim = 0
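
        # Worked example (hypothetical sizes, for orientation only): with
        # input_size=8192 split across 2 ranks (input_size_per_partition=4096)
        # and group_size=128:
        #   * desc_act=False: scales/zp are sharded along the input dim,
        #     scales_and_zp_size = 4096 // 128 = 32, and is_k_full=True.
        #   * desc_act=True: scales/zp stay replicated in full
        #     (scales_and_zp_size = 8192 // 128 = 64) and is_k_full=False,
        #     since this rank only holds part of K.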

        # Init buffers

        # Quantized weights
        qweight = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight,
            {
                **extra_weight_attrs,
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 0,
                "pack_factor": self.quant_config.pack_factor,
            },
        )

        # Activation order
        g_idx = Parameter(
            torch.empty(
                input_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        # Ignore warning from fused linear layers such as QKVParallelLinear.
        set_weight_attrs(
            g_idx,
            {
                **extra_weight_attrs,
                "input_dim": 0,
                "ignore_warning": True,
            },
        )

        # Scales
        scales = Parameter(
            torch.empty(
                scales_and_zp_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            scales,
            {
                **extra_weight_attrs,
                "input_dim": scales_and_zp_input_dim,
                "output_dim": 1,
            },
        )
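
        # Note (our reading of the design, not stated in the source): the
        # Marlin path only supports symmetric quantization (see
        # verify_marlin_supported above) and gptq_marlin_gemm never consumes
        # zero-points, so qzeros below is allocated on the "meta" device and
        # takes no real memory.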

        # Quantized zero-points
        qzeros = Parameter(
            torch.empty(
                scales_and_zp_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
                device="meta",
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qzeros,
            {
                **extra_weight_attrs,
                "input_dim": scales_and_zp_input_dim,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            },
        )

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("g_idx", g_idx)
        layer.register_parameter("scales", scales)
        layer.register_parameter("qzeros", qzeros)
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.input_size = input_size
        layer.is_k_full = is_k_full
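
        # Shape sketch (hypothetical layer, for orientation only): with 4-bit
        # weights (pack_factor=8), input_size_per_partition=4096,
        # output_size_per_partition=4096, group_size=128, desc_act=False:
        #   qweight: (4096 // 8, 4096) = (512, 4096)   int32
        #   g_idx:   (4096,)                           int32
        #   scales:  (4096 // 128, 4096) = (32, 4096)  params_dtype
        #   qzeros:  (32, 4096 // 8) = (32, 512)       int32 (on "meta")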

    # Checkpoints are serialized in the AutoGPTQ format, which differs from
    # the Marlin format. This function is called after the weights are loaded
    # and handles the repacking, including the activation-reordering case.
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        device = layer.qweight.device

        # Allocate marlin workspace.
        layer.workspace = marlin_make_workspace(
            layer.output_size_per_partition, device)

        # Handle sorting for activation reordering if needed.
        if self.quant_config.desc_act:
            g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.g_idx)
            layer.g_idx_sort_indices = g_idx_sort_indices
            replace_tensor(layer, "g_idx", g_idx)
        else:
            layer.g_idx = marlin_make_empty_g_idx(device)
            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

        # Repack weights from AutoGPTQ format to Marlin format.
        marlin_qweight = ops.gptq_marlin_repack(
            layer.qweight,
            perm=layer.g_idx_sort_indices,
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            num_bits=self.quant_config.weight_bits)
        replace_tensor(layer, "qweight", marlin_qweight)

        # Permute scales from AutoGPTQ format to Marlin format.
        marlin_scales = marlin_permute_scales(
            layer.scales,
            size_k=(layer.input_size if self.quant_config.desc_act else
                    layer.input_size_per_partition),
            size_n=layer.output_size_per_partition,
            group_size=self.quant_config.group_size)
        replace_tensor(layer, "scales", marlin_scales)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        reshaped_x = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (layer.output_size_per_partition, )

        output = ops.gptq_marlin_gemm(reshaped_x,
                                      layer.qweight,
                                      layer.scales,
                                      g_idx=layer.g_idx,
                                      perm=layer.g_idx_sort_indices,
                                      workspace=layer.workspace,
                                      num_bits=self.quant_config.weight_bits,
                                      size_m=reshaped_x.shape[0],
                                      size_n=layer.output_size_per_partition,
                                      size_k=layer.input_size_per_partition,
                                      is_k_full=layer.is_k_full)

        if bias is not None:
            output.add_(bias)  # In-place add

        return output.reshape(out_shape)
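

# Usage sketch (illustrative; ``linear_layer``, ``x`` and
# ``quantize_config_dict`` are hypothetical names). Aphrodite's model loader
# drives this path internally; the outline only shows the order of the calls
# defined above:
#
#   config = GPTQMarlinConfig.from_config(quantize_config_dict)
#   method = config.get_quant_method(linear_layer)  # GPTQMarlinLinearMethod
#   method.create_weights(linear_layer, ...)        # allocate qweight/g_idx/
#                                                   # scales/qzeros
#   # ... the weight loader fills the registered parameters from the
#   # AutoGPTQ checkpoint ...
#   method.process_weights_after_loading(linear_layer)  # repack to Marlin
#   y = method.apply(linear_layer, x, bias=None)         # fused Marlin GEMM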