gptq_marlin.py

import enum
from enum import Enum
from typing import Any, Dict, List, Optional

import torch
from loguru import logger
from torch.nn.parameter import Parameter

from aphrodite import _custom_ops as ops
from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
                                              set_weight_attrs)
from aphrodite.quantization.base_config import QuantizationConfig

GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16

GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]
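
# GPTQ_MARLIN_TILE: Marlin lays quantized weights out in 16x16 tiles.
# GPTQ_MARLIN_MIN_THREAD_N / GPTQ_MARLIN_MIN_THREAD_K: the output (N) and
# input (K) shard sizes of a layer must be divisible by these values
# (enforced in GPTQMarlinLinearMethod.create_weights below).
# A group_size of -1 denotes a single quantization group spanning the full
# input dimension, i.e. one scale per output channel.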


# Permutations for Marlin scale shuffling
def get_scale_perms(num_bits):
    scale_perm = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    scale_perm_single = []
    for i in range(4):
        scale_perm_single.extend(
            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
    return scale_perm, scale_perm_single


def get_pack_factor(num_bits):
    assert (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
            ), f"Unsupported num_bits = {num_bits}"
    return 32 // num_bits
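
# For example, with num_bits=4 the pack factor is 32 // 4 == 8 (eight 4-bit
# values per int32 element of qweight); with num_bits=8 it is 4.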


def marlin_permute_scales(s, size_k, size_n, group_size, num_bits):
    scale_perm, scale_perm_single = get_scale_perms(num_bits)
    if group_size < size_k and group_size != -1:
        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
    else:
        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
    s = s.reshape((-1, size_n)).contiguous()
    return s
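
# Shape sketch: with group quantization (group_size < size_k), `s` arrives as
# (size_k // group_size, size_n); with per-channel scales (group_size == -1 or
# group_size == size_k) it is a single row. The column permutation above
# interleaves the scales into the order expected by the Marlin kernel before
# reshaping back to (-1, size_n).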


class GPTQMarlinConfig(QuantizationConfig):
    """Config class for GPTQ Marlin"""

    def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
                 is_sym: bool) -> None:
        if desc_act and group_size == -1:
            # In this case, act_order == True is the same as act_order == False
            # (since we have only one group per output channel)
            desc_act = False

        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
        self.is_sym = is_sym

        # Verify
        if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
            raise ValueError(
                f"Marlin does not support weight_bits = {self.weight_bits}. "
                f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} "
                "are supported.")
        if self.group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
            raise ValueError(
                f"Marlin does not support group_size = {self.group_size}. "
                f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} "
                "are supported.")
        if self.is_sym not in GPTQ_MARLIN_SUPPORTED_SYM:
            raise ValueError(
                f"Marlin does not support is_sym = {self.is_sym}. "
                f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.")

        # Init
        self.pack_factor = get_pack_factor(weight_bits)
        self.tile_size = GPTQ_MARLIN_TILE
        self.min_thread_n = GPTQ_MARLIN_MIN_THREAD_N
        self.min_thread_k = GPTQ_MARLIN_MIN_THREAD_K
        self.max_parallel = GPTQ_MARLIN_MAX_PARALLEL

    def __repr__(self) -> str:
        return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act})")

    @classmethod
    def get_name(cls) -> str:
        return "gptq_marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        is_sym = cls.get_from_keys(config, ["sym"])
        return cls(weight_bits, group_size, desc_act, is_sym)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        can_convert = cls.is_marlin_compatible(hf_quant_cfg)

        is_valid_user_quant = (user_quant is None or user_quant == "marlin")

        if can_convert and is_valid_user_quant:
            msg = ("The model is convertible to {} during runtime."
                   " Using {} kernel.".format(cls.get_name(), cls.get_name()))
            logger.info(msg)
            return cls.get_name()

        if can_convert and user_quant == "gptq":
            logger.info("Detected that the model can run with gptq_marlin"
                        ", however you specified quantization=gptq explicitly,"
                        " so forcing gptq. Use quantization=gptq_marlin for"
                        " faster inference")
        return None

    def get_quant_method(
            self,
            layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
        if isinstance(layer, LinearBase):
            return GPTQMarlinLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []

    @classmethod
    def is_marlin_compatible(cls, quant_config: Dict[str, Any]):
        # Extract data from quant config.
        num_bits = quant_config.get("bits", None)
        group_size = quant_config.get("group_size", None)
        sym = quant_config.get("sym", None)
        desc_act = quant_config.get("desc_act", None)

        # If we cannot find the info needed in the config, cannot convert.
        if (num_bits is None or group_size is None or sym is None
                or desc_act is None):
            return False

        # If the capability of the device is too low, cannot convert.
        major, minor = torch.cuda.get_device_capability()
        device_capability = major * 10 + minor
        if device_capability < cls.get_min_capability():
            return False

        # Otherwise, can convert if model satisfies marlin constraints.
        return (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
                and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
                and sym in GPTQ_MARLIN_SUPPORTED_SYM)
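
# Illustrative quantize_config.json accepted by from_config (the field names
# match the keys read above; the values are an example only):
#
#   {
#     "bits": 4,
#     "group_size": 128,
#     "desc_act": false,
#     "sym": true
#   }
#
# which yields GPTQMarlinConfig(weight_bits=4, group_size=128, desc_act=False,
# is_sym=True).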


class GPTQMarlinState(Enum):
    REPACK = enum.auto()
    READY = enum.auto()
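
# Layer weights are loaded in the original GPTQ layout and start in REPACK
# state; GPTQMarlinLinearMethod.apply() repacks them into Marlin's format (and
# permutes the scales) on the first forward pass, then switches the state to
# READY so the conversion runs only once.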


class GPTQMarlinLinearMethod(LinearMethodBase):
    """Linear method for GPTQ Marlin.

    Args:
        quant_config: The GPTQ Marlin quantization config.
    """

    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        del output_size

        # Normalize group_size
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        # Validate dtype
        if params_dtype not in [torch.float16, torch.bfloat16]:
            raise ValueError(f"The params dtype must be float16 "
                             f"or bfloat16, but got {params_dtype}")

        # Validate output_size_per_partition
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.min_thread_n != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"min_thread_n = {self.quant_config.min_thread_n}.")

        # Validate input_size_per_partition
        if input_size_per_partition % self.quant_config.min_thread_k != 0:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible "
                f"by min_thread_k = {self.quant_config.min_thread_k}.")

        if (group_size < input_size
                and input_size_per_partition % group_size != 0):
            raise ValueError(
                f"Weight input_size_per_partition = {input_size_per_partition}"
                f" is not divisible by group_size = {group_size}.")

        # Detect sharding of scales/zp

        # By default, no sharding over "input dim"
        scales_and_zp_size = input_size // group_size
        scales_and_zp_input_dim = None

        if self.quant_config.desc_act:
            # Act-order case
            assert self.quant_config.group_size != -1

            is_k_full = input_size_per_partition == input_size
        else:
            # No act-order case

            # K is always full due to full alignment with
            # group-size and shard of scales/zp
            is_k_full = True

            # If this is a row-parallel case, then shard scales/zp
            if (input_size != input_size_per_partition
                    and self.quant_config.group_size != -1):
                scales_and_zp_size = input_size_per_partition // group_size
                scales_and_zp_input_dim = 0

        # Init buffers

        # Quantized weights
        qweight = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight, {
                **extra_weight_attrs,
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 0,
                "pack_factor": self.quant_config.pack_factor,
            })

        # Activation order
        g_idx = Parameter(
            torch.empty(
                input_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        # Ignore warning from fused linear layers such as QKVParallelLinear.
        set_weight_attrs(g_idx, {
            **extra_weight_attrs, "input_dim": 0,
            "ignore_warning": True
        })

        g_idx_sort_indices = torch.empty(
            g_idx.shape,
            dtype=torch.int32,
        )

        # Scales
        scales = Parameter(
            torch.empty(
                scales_and_zp_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            scales, {
                **extra_weight_attrs,
                "input_dim": scales_and_zp_input_dim,
                "output_dim": 1,
            })

        # Quantized zero-points
        qzeros = Parameter(
            torch.empty(scales_and_zp_size,
                        output_size_per_partition //
                        self.quant_config.pack_factor,
                        dtype=torch.int32,
                        device="meta"),
            requires_grad=False,
        )
        set_weight_attrs(
            qzeros, {
                **extra_weight_attrs,
                "input_dim": scales_and_zp_input_dim,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            })

        # Allocate marlin workspace
        max_workspace_size = (
            output_size_per_partition //
            self.quant_config.min_thread_n) * self.quant_config.max_parallel
        workspace = torch.zeros(max_workspace_size,
                                dtype=torch.int,
                                requires_grad=False)

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("g_idx", g_idx)
        layer.register_parameter("scales", scales)
        layer.register_parameter("qzeros", qzeros)
        layer.g_idx_sort_indices = g_idx_sort_indices
        layer.workspace = workspace
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.input_size = input_size
        layer.is_k_full = is_k_full
        layer.marlin_state = GPTQMarlinState.REPACK
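
    # Buffer-shape sketch for a hypothetical 4-bit layer with
    # input_size_per_partition = input_size = 4096,
    # output_size_per_partition = 4096 and group_size = 128 (pack_factor = 8):
    #   qweight:   (4096 // 8, 4096) = (512, 4096), int32
    #   scales:    (4096 // 128, 4096) = (32, 4096), params_dtype
    #   qzeros:    (32, 4096 // 8) = (32, 512), int32 on the "meta" device
    #              (never materialized; only sym=True is supported, see
    #              GPTQ_MARLIN_SUPPORTED_SYM)
    #   g_idx:     (4096,), int32
    #   workspace: (4096 // 64) * 16 = 1024 zeros, int32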

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        reshaped_x = x.reshape(-1, x.shape[-1])

        size_m = reshaped_x.shape[0]
        part_size_n = layer.output_size_per_partition
        part_size_k = layer.input_size_per_partition
        full_size_k = layer.input_size

        out_shape = x.shape[:-1] + (part_size_n, )

        if layer.marlin_state == GPTQMarlinState.REPACK:
            layer.marlin_state = GPTQMarlinState.READY

            # Newly generated tensors need to replace existing tensors that are
            # already registered as parameters by Aphrodite (and won't be
            # freed)
            def replace_tensor(name, new_t):
                # It is important to use resize_() here since it ensures
                # the same buffer is reused
                getattr(layer, name).resize_(new_t.shape)
                getattr(layer, name).copy_(new_t)
                del new_t

            cur_device = layer.qweight.device

            # Process act_order
            if self.quant_config.desc_act:
                # Get sorting based on g_idx
                g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int)
                sorted_g_idx = layer.g_idx[g_idx_sort_indices]

                replace_tensor("g_idx", sorted_g_idx)
                replace_tensor("g_idx_sort_indices", g_idx_sort_indices)
            else:
                # Reset g_idx related tensors
                layer.g_idx = Parameter(torch.empty(0,
                                                    dtype=torch.int,
                                                    device=cur_device),
                                        requires_grad=False)
                layer.g_idx_sort_indices = Parameter(torch.empty(
                    0, dtype=torch.int, device=cur_device),
                                                     requires_grad=False)

            # Repack weights
            marlin_qweight = ops.gptq_marlin_repack(
                layer.qweight,
                layer.g_idx_sort_indices,
                part_size_k,
                part_size_n,
                self.quant_config.weight_bits,
            )
            replace_tensor("qweight", marlin_qweight)

            # Permute scales
            scales_size_k = part_size_k
            scales_size_n = part_size_n
            if self.quant_config.desc_act:
                scales_size_k = full_size_k

            marlin_scales = marlin_permute_scales(
                layer.scales,
                scales_size_k,
                scales_size_n,
                self.quant_config.group_size,
                self.quant_config.weight_bits,
            )
            replace_tensor("scales", marlin_scales)

        output = ops.gptq_marlin_gemm(
            reshaped_x,
            layer.qweight,
            layer.scales,
            layer.g_idx,
            layer.g_idx_sort_indices,
            layer.workspace,
            self.quant_config.weight_bits,
            size_m,
            part_size_n,
            part_size_k,
            layer.is_k_full,
        )

        if bias is not None:
            output.add_(bias)  # In-place add

        return output.reshape(out_shape)
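

# Usage sketch (hypothetical layer sizes; assumes an Aphrodite LinearBase layer
# whose GPTQ checkpoint tensors get loaded into the parameters created below):
#
#   config = GPTQMarlinConfig(weight_bits=4, group_size=128, desc_act=False,
#                             is_sym=True)
#   method = config.get_quant_method(layer)      # -> GPTQMarlinLinearMethod
#   method.create_weights(layer,
#                         input_size_per_partition=4096,
#                         output_partition_sizes=[4096],
#                         input_size=4096,
#                         output_size=4096,
#                         params_dtype=torch.float16)
#   # ... the weight loader fills layer.qweight / layer.scales / layer.qzeros /
#   #     layer.g_idx from the checkpoint ...
#   out = method.apply(layer, x)   # first call repacks to Marlin format,
#                                  # subsequent calls hit the READY fast path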