gptq_marlin.py

import enum
from contextlib import suppress
from enum import Enum
from typing import Any, Dict, List, Optional

import torch
from loguru import logger
from torch.nn.parameter import Parameter

from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
                                              set_weight_attrs)
from aphrodite.quantization.base_config import QuantizationConfig

HAS_QUANTS = False
with suppress(ImportError):
    from aphrodite._quant_C import quant_ops as ops
    HAS_QUANTS = True

GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16

GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]
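# Note added for clarity (not an upstream comment): these constants mirror the
# Marlin kernel's tiling constraints. Weights are packed into 16x16 tiles, and
# the kernel's smallest thread-block tile spans 64 columns along N and 128 rows
# along K, which is why create_weights() below rejects partition sizes that are
# not multiples of min_thread_n / min_thread_k. GPTQ_MARLIN_MAX_PARALLEL caps
# how many partial results may be reduced through the workspace buffer that
# create_weights() allocates.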
# Permutations for Marlin scale shuffling
def get_scale_perms(num_bits):
    scale_perm = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    scale_perm_single = []
    for i in range(4):
        scale_perm_single.extend(
            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
    return scale_perm, scale_perm_single


def get_pack_factor(num_bits):
    assert (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
            ), f"Unsupported num_bits = {num_bits}"
    return 32 // num_bits


def marlin_permute_scales(s, size_k, size_n, group_size, num_bits):
    scale_perm, scale_perm_single = get_scale_perms(num_bits)
    if group_size < size_k and group_size != -1:
        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
    else:
        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
    s = s.reshape((-1, size_n)).contiguous()
    return s
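# Illustrative example for the helpers above (added for clarity, values are
# hypothetical): get_pack_factor(4) == 8, i.e. eight 4-bit weights per int32
# word, while get_pack_factor(8) == 4. marlin_permute_scales() only reorders
# scale columns within each group row, so a (size_k // group_size, size_n)
# scale tensor keeps that shape; the reordering matches the layout the Marlin
# GEMM kernel expects when it loads the scales.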
class GPTQMarlinConfig(QuantizationConfig):
    """Config class for GPTQ Marlin"""

    def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
                 is_sym: bool) -> None:
        if desc_act and group_size == -1:
            # In this case, act_order == True is the same as act_order == False
            # (since we have only one group per output channel)
            desc_act = False

        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
        self.is_sym = is_sym

        # Verify
        if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
            raise ValueError(
                f"Marlin does not support weight_bits = {self.weight_bits}. "
                f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} "
                "are supported.")
        if self.group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
            raise ValueError(
                f"Marlin does not support group_size = {self.group_size}. "
                f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} "
                "are supported.")
        if self.is_sym not in GPTQ_MARLIN_SUPPORTED_SYM:
            raise ValueError(
                f"Marlin does not support is_sym = {self.is_sym}. "
                f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.")

        # Init
        self.pack_factor = get_pack_factor(weight_bits)
        self.tile_size = GPTQ_MARLIN_TILE
        self.min_thread_n = GPTQ_MARLIN_MIN_THREAD_N
        self.min_thread_k = GPTQ_MARLIN_MIN_THREAD_K
        self.max_parallel = GPTQ_MARLIN_MAX_PARALLEL

    def __repr__(self) -> str:
        return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act})")

    @classmethod
    def get_name(cls) -> str:
        return "gptq_marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        is_sym = cls.get_from_keys(config, ["sym"])
        return cls(weight_bits, group_size, desc_act, is_sym)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        can_convert = cls.is_marlin_compatible(hf_quant_cfg)

        is_valid_user_quant = (user_quant is None or user_quant == "marlin")

        if can_convert and is_valid_user_quant:
            msg = ("The model is convertible to {} during runtime."
                   " Using {} kernel.".format(cls.get_name(), cls.get_name()))
            logger.info(msg)
            return cls.get_name()

        if can_convert and user_quant == "gptq":
            logger.info("Detected that the model can run with gptq_marlin"
                        ", however you specified quantization=gptq explicitly,"
                        " so forcing gptq. Use quantization=gptq_marlin for"
                        " faster inference")
        return None

    def get_quant_method(
            self,
            layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
        if isinstance(layer, LinearBase):
            return GPTQMarlinLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []

    @classmethod
    def is_marlin_compatible(cls, quant_config: Dict[str, Any]):
        # Extract data from quant config.
        num_bits = quant_config.get("bits", None)
        group_size = quant_config.get("group_size", None)
        sym = quant_config.get("sym", None)
        desc_act = quant_config.get("desc_act", None)

        # If we cannot find the info needed in the config, cannot convert.
        if (num_bits is None or group_size is None or sym is None
                or desc_act is None):
            return False

        # If the capability of the device is too low, cannot convert.
        major, minor = torch.cuda.get_device_capability()
        device_capability = major * 10 + minor
        if device_capability < cls.get_min_capability():
            return False

        # Otherwise, can convert if model satisfies marlin constraints.
        return (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
                and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
                and sym in GPTQ_MARLIN_SUPPORTED_SYM)
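# Usage sketch for GPTQMarlinConfig (added for illustration, values are
# hypothetical): an AutoGPTQ-style quantize_config.json such as
#     {"bits": 4, "group_size": 128, "desc_act": false, "sym": true}
# is mapped by GPTQMarlinConfig.from_config() to
#     GPTQMarlinConfig(weight_bits=4, group_size=128, desc_act=False,
#                      is_sym=True)
# which sets pack_factor == 8. is_marlin_compatible() accepts the same dict
# only on devices with compute capability >= 8.0 (get_min_capability() == 80).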
class GPTQMarlinState(Enum):
    REPACK = enum.auto()
    READY = enum.auto()
class GPTQMarlinLinearMethod(LinearMethodBase):
    """Linear method for GPTQ Marlin.

    Args:
        quant_config: The GPTQ Marlin quantization config.
    """

    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
        if not HAS_QUANTS:
            raise ImportError("Could not find the quantization kernels.")
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        del output_size

        # Normalize group_size
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        # Validate dtype
        if params_dtype not in [torch.float16, torch.bfloat16]:
            raise ValueError(f"The params dtype must be float16 "
                             f"or bfloat16, but got {params_dtype}")

        # Validate output_size_per_partition
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.min_thread_n != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"min_thread_n = {self.quant_config.min_thread_n}.")

        # Validate input_size_per_partition
        if input_size_per_partition % self.quant_config.min_thread_k != 0:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible "
                f"by min_thread_k = {self.quant_config.min_thread_k}.")

        if (group_size < input_size
                and input_size_per_partition % group_size != 0):
            raise ValueError(
                f"Weight input_size_per_partition = {input_size_per_partition}"
                f" is not divisible by group_size = {group_size}.")

        # Detect sharding of scales/zp

        # By default, no sharding over "input dim"
        scales_and_zp_size = input_size // group_size
        scales_and_zp_input_dim = None

        if self.quant_config.desc_act:
            # Act-order case
            assert self.quant_config.group_size != -1

            is_k_full = input_size_per_partition == input_size
        else:
            # No act-order case

            # K is always full due to full alignment with
            # group-size and shard of scales/zp
            is_k_full = True

            # If this is a row-parallel case, then shard scales/zp
            if (input_size != input_size_per_partition
                    and self.quant_config.group_size != -1):
                scales_and_zp_size = input_size_per_partition // group_size
                scales_and_zp_input_dim = 0

        # Init buffers

        # Quantized weights
        qweight = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight, {
                **extra_weight_attrs,
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 0,
                "pack_factor": self.quant_config.pack_factor,
            })

        # Activation order
        g_idx = Parameter(
            torch.empty(
                input_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        # Ignore warning from fused linear layers such as QKVParallelLinear.
        set_weight_attrs(g_idx, {
            **extra_weight_attrs, "input_dim": 0,
            "ignore_warning": True
        })

        g_idx_sort_indices = Parameter(
            torch.empty(
                g_idx.shape,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(g_idx_sort_indices, extra_weight_attrs)

        # Scales
        scales = Parameter(
            torch.empty(
                scales_and_zp_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            scales, {
                **extra_weight_attrs,
                "input_dim": scales_and_zp_input_dim,
                "output_dim": 1,
            })

        # Quantized zero-points
        qzeros = Parameter(
            torch.empty(scales_and_zp_size,
                        output_size_per_partition //
                        self.quant_config.pack_factor,
                        dtype=torch.int32,
                        device="meta"),
            requires_grad=False,
        )
        set_weight_attrs(
            qzeros, {
                **extra_weight_attrs,
                "input_dim": scales_and_zp_input_dim,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            })

        # Allocate marlin workspace
        max_workspace_size = (
            output_size_per_partition //
            self.quant_config.min_thread_n) * self.quant_config.max_parallel
        workspace = torch.zeros(max_workspace_size,
                                dtype=torch.int,
                                requires_grad=False)

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("g_idx", g_idx)
        layer.register_parameter("g_idx_sort_indices", g_idx_sort_indices)
        layer.register_parameter("scales", scales)
        layer.register_parameter("qzeros", qzeros)
        layer.workspace = workspace
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.input_size = input_size
        layer.is_k_full = is_k_full
        layer.marlin_state = GPTQMarlinState.REPACK
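    # Shape summary for create_weights() (added for clarity; the numbers are a
    # hypothetical example): with 4-bit weights, input_size_per_partition
    # K = 4096, output_size_per_partition N = 4096 and group_size = 128, the
    # pack factor is 8, so qweight is (K // 8, N) = (512, 4096) int32, scales
    # is (K // 128, N) = (32, 4096) in params_dtype, qzeros is (32, N // 8) =
    # (32, 512) int32 kept on the "meta" device (symmetric quantization needs
    # no materialized zero points), and workspace holds
    # (N // min_thread_n) * max_parallel = (4096 // 64) * 16 = 1024 zeros.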
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        reshaped_x = x.reshape(-1, x.shape[-1])

        size_m = reshaped_x.shape[0]
        part_size_n = layer.output_size_per_partition
        part_size_k = layer.input_size_per_partition
        full_size_k = layer.input_size

        out_shape = x.shape[:-1] + (part_size_n, )

        if layer.marlin_state == GPTQMarlinState.REPACK:
            layer.marlin_state = GPTQMarlinState.READY

            # Newly generated tensors need to replace existing tensors that are
            # already registered as parameters by Aphrodite (and won't be
            # freed)
            def replace_tensor(name, new_t):
                # It is important to use resize_() here since it ensures
                # the same buffer is reused
                getattr(layer, name).resize_(new_t.shape)
                getattr(layer, name).copy_(new_t)
                del new_t

            cur_device = layer.qweight.device

            # Process act_order
            if self.quant_config.desc_act:
                # Get sorting based on g_idx
                g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int)
                sorted_g_idx = layer.g_idx[g_idx_sort_indices]

                replace_tensor("g_idx", sorted_g_idx)
                replace_tensor("g_idx_sort_indices", g_idx_sort_indices)
            else:
                # Reset g_idx related tensors
                layer.g_idx = Parameter(torch.empty(0,
                                                    dtype=torch.int,
                                                    device=cur_device),
                                        requires_grad=False)
                layer.g_idx_sort_indices = Parameter(torch.empty(
                    0, dtype=torch.int, device=cur_device),
                                                     requires_grad=False)

            # Repack weights
            marlin_qweight = ops.gptq_marlin_repack(
                layer.qweight,
                layer.g_idx_sort_indices,
                part_size_k,
                part_size_n,
                self.quant_config.weight_bits,
            )
            replace_tensor("qweight", marlin_qweight)

            # Permute scales
            scales_size_k = part_size_k
            scales_size_n = part_size_n
            if self.quant_config.desc_act:
                scales_size_k = full_size_k

            marlin_scales = marlin_permute_scales(
                layer.scales,
                scales_size_k,
                scales_size_n,
                self.quant_config.group_size,
                self.quant_config.weight_bits,
            )
            replace_tensor("scales", marlin_scales)

        output = ops.gptq_marlin_gemm(
            reshaped_x,
            layer.qweight,
            layer.scales,
            layer.g_idx,
            layer.g_idx_sort_indices,
            layer.workspace,
            self.quant_config.weight_bits,
            size_m,
            part_size_n,
            part_size_k,
            layer.is_k_full,
        )

        if bias is not None:
            output.add_(bias)  # In-place add

        return output.reshape(out_shape)
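# Flow note (added for clarity): the first call to apply() finds marlin_state
# == REPACK and performs a one-time conversion: g_idx is sorted (act-order) or
# replaced by empty tensors, ops.gptq_marlin_repack() rewrites the GPTQ qweight
# into the Marlin tile layout through replace_tensor() so the registered
# parameter buffers are reused, and the scales are shuffled with
# marlin_permute_scales(). The state then flips to READY and every subsequent
# call goes straight to ops.gptq_marlin_gemm(), so the repacking cost is paid
# once per layer.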