gptq_marlin.py

import enum
from enum import Enum
from typing import Any, Dict, List, Optional

import torch
from loguru import logger
from torch.nn.parameter import Parameter

from aphrodite import _custom_ops as ops
from aphrodite.platforms import current_platform
from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
                                              set_weight_attrs)
from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
from aphrodite.quantization.base_config import QuantizationConfig

GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16
GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]


# Permutations for Marlin scale shuffling
def get_scale_perms(num_bits):
    scale_perm = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    scale_perm_single = []
    for i in range(4):
        scale_perm_single.extend(
            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
    return scale_perm, scale_perm_single


def get_pack_factor(num_bits):
    assert (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
            ), f"Unsupported num_bits = {num_bits}"
    return 32 // num_bits
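
# For example, with 4-bit weights each 32-bit word packs 32 // 4 = 8
# quantized values, so a (K, N) weight matrix is stored as a
# (K // 8, N) int32 tensor:
#   >>> get_pack_factor(4)
#   8
#   >>> get_pack_factor(8)
#   4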


def marlin_permute_scales(s, size_k, size_n, group_size, num_bits):
    scale_perm, scale_perm_single = get_scale_perms(num_bits)
    if group_size < size_k and group_size != -1:
        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
    else:
        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
    s = s.reshape((-1, size_n)).contiguous()
    return s
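
# Shape sketch, assuming 4-bit weights with size_k = size_n = 4096 and
# group_size = 128: the incoming scales have shape
# (size_k // group_size, size_n) = (32, 4096); the function reshapes to
# (-1, len(scale_perm)) = (-1, 64), permutes columns into Marlin's layout,
# and returns a contiguous (32, 4096) tensor ready for the kernel.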


class GPTQMarlinConfig(QuantizationConfig):
    """Config class for GPTQ Marlin."""

    def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
                 is_sym: bool, lm_head_quantized: bool) -> None:
        if desc_act and group_size == -1:
            # In this case, act_order == True is the same as act_order == False
            # (since we have only one group per output channel).
            desc_act = False

        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
        self.is_sym = is_sym
        self.lm_head_quantized = lm_head_quantized

        # Verify
        if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
            raise ValueError(
                f"Marlin does not support weight_bits = {self.weight_bits}. "
                f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} "
                "are supported.")
        if self.group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
            raise ValueError(
                f"Marlin does not support group_size = {self.group_size}. "
                f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} "
                "are supported.")
        if self.is_sym not in GPTQ_MARLIN_SUPPORTED_SYM:
            raise ValueError(
                f"Marlin does not support is_sym = {self.is_sym}. "
                f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.")

        # Init
        self.pack_factor = get_pack_factor(weight_bits)
        self.tile_size = GPTQ_MARLIN_TILE
        self.min_thread_n = GPTQ_MARLIN_MIN_THREAD_N
        self.min_thread_k = GPTQ_MARLIN_MIN_THREAD_K
        self.max_parallel = GPTQ_MARLIN_MAX_PARALLEL

    def __repr__(self) -> str:
        return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act}, "
                f"lm_head_quantized={self.lm_head_quantized})")

    @classmethod
    def get_name(cls) -> str:
        return "gptq_marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        is_sym = cls.get_from_keys(config, ["sym"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                 default=False)
        return cls(weight_bits, group_size, desc_act, is_sym,
                   lm_head_quantized)
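
    # Illustrative quantize_config.json consumed by from_config (the
    # "lm_head" key is optional and defaults to False):
    #   {
    #       "bits": 4,
    #       "group_size": 128,
    #       "desc_act": false,
    #       "sym": true,
    #       "lm_head": false
    #   }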

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        can_convert = cls.is_marlin_compatible(hf_quant_cfg)
        is_valid_user_quant = (user_quant is None or user_quant == "marlin")

        if can_convert and is_valid_user_quant:
            msg = ("The model is convertible to {} during runtime."
                   " Using {} kernel.".format(cls.get_name(), cls.get_name()))
            logger.info(msg)
            return cls.get_name()

        if can_convert and user_quant == "gptq":
            logger.info("Detected that the model can run with gptq_marlin"
                        ", however you specified quantization=gptq explicitly,"
                        " so forcing gptq. Use quantization=gptq_marlin for"
                        " faster inference")
        return None

    def get_quant_method(
            self,
            layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
        if (isinstance(layer, LinearBase) or
                (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
            return GPTQMarlinLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []

    @classmethod
    def is_marlin_compatible(cls, quant_config: Dict[str, Any]):
        # Extract data from quant config.
        num_bits = quant_config.get("bits", None)
        group_size = quant_config.get("group_size", None)
        sym = quant_config.get("sym", None)
        desc_act = quant_config.get("desc_act", None)

        # If we cannot find the info needed in the config, cannot convert.
        if (num_bits is None or group_size is None or sym is None
                or desc_act is None):
            return False

        # If the capability of the device is too low, cannot convert.
        major, minor = current_platform.get_device_capability()
        device_capability = major * 10 + minor
        if device_capability < cls.get_min_capability():
            return False

        # Otherwise, can convert if model satisfies marlin constraints.
        return (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
                and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
                and sym in GPTQ_MARLIN_SUPPORTED_SYM)


class GPTQMarlinState(Enum):
    REPACK = enum.auto()
    READY = enum.auto()


class GPTQMarlinLinearMethod(LinearMethodBase):
    """Linear method for GPTQ Marlin.

    Args:
        quant_config: The GPTQ Marlin quantization config.
    """

    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        del output_size  # Unused.

        # Normalize group_size
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        # Validate dtype
        if params_dtype not in [torch.float16, torch.bfloat16]:
            raise ValueError(f"The params dtype must be float16 "
                             f"or bfloat16, but got {params_dtype}")

        # Validate output_size_per_partition
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.min_thread_n != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"min_thread_n = {self.quant_config.min_thread_n}.")

        # Validate input_size_per_partition
        if input_size_per_partition % self.quant_config.min_thread_k != 0:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible "
                f"by min_thread_k = {self.quant_config.min_thread_k}.")

        if (group_size < input_size
                and input_size_per_partition % group_size != 0):
            raise ValueError(
                f"Weight input_size_per_partition = {input_size_per_partition}"
                f" is not divisible by group_size = {group_size}.")

        # Detect sharding of scales/zp
        # By default, no sharding over "input dim"
        scales_and_zp_size = input_size // group_size
        scales_and_zp_input_dim = None

        if self.quant_config.desc_act:
            # Act-order case
            assert self.quant_config.group_size != -1
            is_k_full = input_size_per_partition == input_size
        else:
            # No act-order case
            # K is always full due to full alignment with
            # group-size and shard of scales/zp
            is_k_full = True

            # If this is a row-parallel case, then shard scales/zp
            if (input_size != input_size_per_partition
                    and self.quant_config.group_size != -1):
                scales_and_zp_size = input_size_per_partition // group_size
                scales_and_zp_input_dim = 0

        # Init buffers

        # Quantized weights
        qweight = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight, {
                **extra_weight_attrs,
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 0,
                "pack_factor": self.quant_config.pack_factor,
            })

        # Activation order
        g_idx = Parameter(
            torch.empty(
                input_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        # Ignore warning from fused linear layers such as QKVParallelLinear.
        set_weight_attrs(g_idx, {
            **extra_weight_attrs, "input_dim": 0,
            "ignore_warning": True
        })

        g_idx_sort_indices = torch.empty(
            g_idx.shape,
            dtype=torch.int32,
        )

        # Scales
        scales = Parameter(
            torch.empty(
                scales_and_zp_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            scales, {
                **extra_weight_attrs,
                "input_dim": scales_and_zp_input_dim,
                "output_dim": 1,
            })

        # Quantized zero-points
        qzeros = Parameter(
            torch.empty(scales_and_zp_size,
                        output_size_per_partition //
                        self.quant_config.pack_factor,
                        dtype=torch.int32,
                        device="meta"),
            requires_grad=False,
        )
        set_weight_attrs(
            qzeros, {
                **extra_weight_attrs,
                "input_dim": scales_and_zp_input_dim,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            })

        # Allocate marlin workspace
        max_workspace_size = (
            output_size_per_partition //
            self.quant_config.min_thread_n) * self.quant_config.max_parallel
        workspace = torch.zeros(max_workspace_size,
                                dtype=torch.int,
                                requires_grad=False)

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("g_idx", g_idx)
        layer.register_parameter("scales", scales)
        layer.register_parameter("qzeros", qzeros)
        layer.g_idx_sort_indices = g_idx_sort_indices
        layer.workspace = workspace
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.input_size = input_size
        layer.is_k_full = is_k_full
        layer.marlin_state = GPTQMarlinState.REPACK
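
    # Illustrative buffer shapes, assuming input_size ==
    # input_size_per_partition == 4096, output_size_per_partition == 4096,
    # 4-bit weights (pack_factor 8), group_size 128 and no act-order:
    #   qweight:   (4096 // 8, 4096)   = (512, 4096)  int32
    #   scales:    (4096 // 128, 4096) = (32, 4096)   params_dtype
    #   qzeros:    (32, 4096 // 8)     = (32, 512)    int32 (on the meta
    #              device; not consumed by the Marlin kernel in this file)
    #   g_idx:     (4096,)                            int32
    #   workspace: (4096 // 64) * 16   = 1024         int32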

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        reshaped_x = x.reshape(-1, x.shape[-1])

        size_m = reshaped_x.shape[0]
        part_size_n = layer.output_size_per_partition
        part_size_k = layer.input_size_per_partition
        full_size_k = layer.input_size

        out_shape = x.shape[:-1] + (part_size_n, )

        if layer.marlin_state == GPTQMarlinState.REPACK:
            layer.marlin_state = GPTQMarlinState.READY

            # Newly generated tensors need to replace existing tensors that are
            # already registered as parameters by Aphrodite (and won't be
            # freed).
            def replace_tensor(name, new_t):
                # It is important to use resize_() here since it ensures
                # the same buffer is reused.
                getattr(layer, name).resize_(new_t.shape)
                getattr(layer, name).copy_(new_t)
                del new_t

            cur_device = layer.qweight.device

            # Process act_order
            if self.quant_config.desc_act:
                # Get sorting based on g_idx
                g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int)
                sorted_g_idx = layer.g_idx[g_idx_sort_indices]

                replace_tensor("g_idx", sorted_g_idx)
                replace_tensor("g_idx_sort_indices", g_idx_sort_indices)
            else:
                # Reset g_idx related tensors
                layer.g_idx = Parameter(torch.empty(0,
                                                    dtype=torch.int,
                                                    device=cur_device),
                                        requires_grad=False)
                layer.g_idx_sort_indices = Parameter(torch.empty(
                    0, dtype=torch.int, device=cur_device),
                                                     requires_grad=False)

            # Repack weights
            marlin_qweight = ops.gptq_marlin_repack(
                layer.qweight,
                layer.g_idx_sort_indices,
                part_size_k,
                part_size_n,
                self.quant_config.weight_bits,
            )
            replace_tensor("qweight", marlin_qweight)

            # Permute scales
            scales_size_k = part_size_k
            scales_size_n = part_size_n
            if self.quant_config.desc_act:
                scales_size_k = full_size_k

            marlin_scales = marlin_permute_scales(
                layer.scales,
                scales_size_k,
                scales_size_n,
                self.quant_config.group_size,
                self.quant_config.weight_bits,
            )
            replace_tensor("scales", marlin_scales)

        output = ops.gptq_marlin_gemm(
            reshaped_x,
            layer.qweight,
            layer.scales,
            layer.g_idx,
            layer.g_idx_sort_indices,
            layer.workspace,
            self.quant_config.weight_bits,
            size_m,
            part_size_n,
            part_size_k,
            layer.is_k_full,
        )

        if bias is not None:
            output.add_(bias)  # In-place add

        return output.reshape(out_shape)
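

# Illustrative call flow (a sketch, assuming `layer` was set up by
# create_weights above and its qweight/scales/g_idx were loaded from a GPTQ
# checkpoint): the first apply() call repacks the GPTQ-format qweight into
# Marlin's tiled layout and permutes the scales in place (REPACK -> READY);
# every later call goes straight to ops.gptq_marlin_gemm.
#
#   method = GPTQMarlinLinearMethod(
#       GPTQMarlinConfig(weight_bits=4, group_size=128, desc_act=False,
#                        is_sym=True, lm_head_quantized=False))
#   y = method.apply(layer, x)        # first call: repack + GEMM
#   y = method.apply(layer, x, bias)  # later calls: GEMM only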