gptq_marlin.py

import enum
from contextlib import suppress
from enum import Enum
from typing import Any, Dict, List, Optional

import numpy
import torch
from torch.nn.parameter import Parameter

from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
                                              set_weight_attrs)
from aphrodite.quantization.base_config import QuantizationConfig

HAS_QUANTS = False
with suppress(ImportError):
    from aphrodite._quant_C import quant_ops as ops
    HAS_QUANTS = True

GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16

GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]


# Precompute permutations for Marlin weight and scale shuffling
#
# Marlin works on [16, 64] tiles. The goal of the permutations is to reorder
# the weight data so that it is compatible with the tensor-core format that is
# described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the kernel will get
# the data as it is needed for tensor-core (without the need to use ldmatrix
# instructions).
def _get_perms():
    perm = []
    for i in range(32):
        perm1 = []
        col = i // 4
        for block in [0, 1]:
            for row in [
                    2 * (i % 4),
                    2 * (i % 4) + 1,
                    2 * (i % 4 + 4),
                    2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col + 8 * block)
        for j in range(4):
            perm.extend([p + 256 * j for p in perm1])

    perm = numpy.array(perm)
    interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    perm = perm.reshape((-1, 8))[:, interleave].ravel()  # type: ignore
    perm = torch.from_numpy(perm)

    scale_perm = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    scale_perm_single = []
    for i in range(4):
        scale_perm_single.extend(
            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
    return perm, scale_perm, scale_perm_single


_perm, _scale_perm, _scale_perm_single = _get_perms()
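
# For reference (counts derived from the loops above): _perm has 16 * 64 =
# 1024 entries, one index per element of a single [16, 64] Marlin tile, while
# _scale_perm has 64 entries and _scale_perm_single has 32.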


def get_pack_factor(num_bits):
    assert num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS, (
        f"Unsupported num_bits = {num_bits}")
    return 32 // num_bits
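
# Worked example: get_pack_factor(4) == 32 // 4 == 8, i.e. eight 4-bit
# quantized weights are packed into each 32-bit storage word.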


def marlin_permute_scales(s, size_k, size_n, group_size):
    if group_size < size_k and group_size != -1:
        s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
    else:
        s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
    s = s.reshape((-1, size_n)).contiguous()
    return s
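
# Shape sketch (illustrative sizes, not from the original file): with
# size_k = 4096, size_n = 4096 and group_size = 128, the scales tensor s
# enters and leaves this function as (32, 4096); only the ordering of its
# elements changes, to match the layout the Marlin kernel reads.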


class GPTQMarlinConfig(QuantizationConfig):
    """Config class for GPTQ Marlin"""

    def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
                 is_sym: bool) -> None:
        if desc_act and group_size == -1:
            # In this case, act_order == True is the same as act_order == False
            # (since we have only one group per output channel)
            desc_act = False

        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
        self.is_sym = is_sym

        # Verify
        if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
            raise ValueError(
                f"Marlin does not support weight_bits = {self.weight_bits}. "
                f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} "
                "are supported.")
        if self.group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
            raise ValueError(
                f"Marlin does not support group_size = {self.group_size}. "
                f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} "
                "are supported.")
        if self.is_sym not in GPTQ_MARLIN_SUPPORTED_SYM:
            raise ValueError(
                f"Marlin does not support is_sym = {self.is_sym}. "
                f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.")

        # Init
        self.pack_factor = get_pack_factor(weight_bits)
        self.tile_size = GPTQ_MARLIN_TILE
        self.min_thread_n = GPTQ_MARLIN_MIN_THREAD_N
        self.min_thread_k = GPTQ_MARLIN_MIN_THREAD_K
        self.max_parallel = GPTQ_MARLIN_MAX_PARALLEL

    def __repr__(self) -> str:
        return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act})")

    @classmethod
    def get_name(cls) -> str:
        return "gptq_marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        is_sym = cls.get_from_keys(config, ["sym"])
        return cls(weight_bits, group_size, desc_act, is_sym)
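
    # Example (illustrative) quantize_config.json contents accepted by
    # from_config, using values permitted by the checks in __init__:
    #   {"bits": 4, "group_size": 128, "desc_act": false, "sym": true}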

    def get_quant_method(
            self,
            layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
        if isinstance(layer, LinearBase):
            return GPTQMarlinLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []

    @classmethod
    def is_marlin_compatible(cls, quant_config: Dict[str, Any]):
        # Extract data from quant config.
        num_bits = quant_config.get("bits", None)
        group_size = quant_config.get("group_size", None)
        sym = quant_config.get("sym", None)
        desc_act = quant_config.get("desc_act", None)

        # If we cannot find the info needed in the config, cannot convert.
        if (num_bits is None or group_size is None or sym is None
                or desc_act is None):
            return False

        # If the capability of the device is too low, cannot convert.
        major, minor = torch.cuda.get_device_capability()
        device_capability = major * 10 + minor
        if device_capability < cls.get_min_capability():
            return False

        # Otherwise, can convert if model satisfies marlin constraints.
        return (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
                and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
                and sym in GPTQ_MARLIN_SUPPORTED_SYM)


class GPTQMarlinState(Enum):
    REPACK = enum.auto()
    READY = enum.auto()


class GPTQMarlinLinearMethod(LinearMethodBase):
    """Linear method for GPTQ Marlin.

    Args:
        quant_config: The GPTQ Marlin quantization config.
    """

    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
        if not HAS_QUANTS:
            raise ImportError("Could not find the quantization kernels.")
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        del output_size

        # Normalize group_size
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        # Validate dtype
        if params_dtype != torch.float16:
            raise ValueError(
                f"The params dtype must be float16, but got {params_dtype}")

        # Validate output_size_per_partition
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.min_thread_n != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"min_thread_n = {self.quant_config.min_thread_n}.")

        # Validate input_size_per_partition
        if input_size_per_partition % self.quant_config.min_thread_k != 0:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible "
                f"by min_thread_k = {self.quant_config.min_thread_k}.")

        if (group_size < input_size
                and input_size_per_partition % group_size != 0):
            raise ValueError(
                f"Weight input_size_per_partition = {input_size_per_partition}"
                f" is not divisible by group_size = {group_size}.")

        # Detect sharding of scales/zp

        # By default, no sharding over "input dim"
        scales_and_zp_size = input_size // group_size
        scales_and_zp_input_dim = None

        if self.quant_config.desc_act:
            # Act-order case
            assert self.quant_config.group_size != -1

            is_k_full = input_size_per_partition == input_size
        else:
            # No act-order case

            # K is always full due to full alignment with
            # group-size and shard of scales/zp
            is_k_full = True

            # If this is a row-parallel case, then shard scales/zp
            if (input_size != input_size_per_partition
                    and self.quant_config.group_size != -1):
                scales_and_zp_size = input_size_per_partition // group_size
                scales_and_zp_input_dim = 0

        # Init buffers

        # Quantized weights
        qweight = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight, {
                **extra_weight_attrs,
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 0,
                "pack_factor": self.quant_config.pack_factor,
            })

        # Activation order
        g_idx = Parameter(
            torch.empty(
                input_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        # Ignore warning from fused linear layers such as QKVParallelLinear.
        set_weight_attrs(g_idx, {
            **extra_weight_attrs, "input_dim": 0,
            "ignore_warning": True
        })

        g_idx_sort_indices = Parameter(
            torch.empty(
                g_idx.shape,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(g_idx_sort_indices, extra_weight_attrs)

        # Scales
        scales = Parameter(
            torch.empty(
                scales_and_zp_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            scales, {
                **extra_weight_attrs,
                "input_dim": scales_and_zp_input_dim,
                "output_dim": 1,
            })

        # Quantized zero-points
        qzeros = Parameter(
            torch.empty(scales_and_zp_size,
                        output_size_per_partition //
                        self.quant_config.pack_factor,
                        dtype=torch.int32,
                        device="meta"),
            requires_grad=False,
        )
        set_weight_attrs(
            qzeros, {
                **extra_weight_attrs,
                "input_dim": scales_and_zp_input_dim,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            })

        # Allocate marlin workspace
        max_workspace_size = (
            output_size_per_partition //
            self.quant_config.min_thread_n) * self.quant_config.max_parallel
        workspace = torch.zeros(max_workspace_size,
                                dtype=torch.int,
                                requires_grad=False)

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("g_idx", g_idx)
        layer.register_parameter("g_idx_sort_indices", g_idx_sort_indices)
        layer.register_parameter("scales", scales)
        layer.register_parameter("qzeros", qzeros)
        layer.workspace = workspace
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.input_size = input_size
        layer.is_k_full = is_k_full
        layer.marlin_state = GPTQMarlinState.REPACK
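
    # Illustrative buffer shapes for a single, unsharded partition (assumed
    # sizes: K = input_size_per_partition = 4096, N = output_size_per_partition
    # = 4096, group_size = 128, pack_factor = 8):
    #   qweight: (K // 8, N)        = (512, 4096)  int32, GPTQ layout
    #   g_idx:   (K,)               = (4096,)      int32
    #   scales:  (K // 128, N)      = (32, 4096)   float16
    #   qzeros:  (K // 128, N // 8) = (32, 512)    int32, on the meta device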

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        reshaped_x = x.reshape(-1, x.shape[-1])

        size_m = reshaped_x.shape[0]
        part_size_n = layer.output_size_per_partition
        part_size_k = layer.input_size_per_partition
        full_size_k = layer.input_size

        out_shape = x.shape[:-1] + (part_size_n, )

        if layer.marlin_state == GPTQMarlinState.REPACK:
            layer.marlin_state = GPTQMarlinState.READY

            # Newly generated tensors need to replace existing tensors that are
            # already registered as parameters by vLLM (and won't be freed)
            def replace_tensor(name, new_t):
                # It is important to use resize_() here since it ensures
                # the same buffer is reused
                getattr(layer, name).resize_(new_t.shape)
                getattr(layer, name).copy_(new_t)
                del new_t

            cur_device = layer.qweight.device

            # Process act_order
            if self.quant_config.desc_act:
                # Get sorting based on g_idx
                g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int)
                sorted_g_idx = layer.g_idx[g_idx_sort_indices]

                replace_tensor("g_idx", sorted_g_idx)
                replace_tensor("g_idx_sort_indices", g_idx_sort_indices)
            else:
                # Reset g_idx related tensors
                layer.g_idx = Parameter(torch.empty(0,
                                                    dtype=torch.int,
                                                    device=cur_device),
                                        requires_grad=False)
                layer.g_idx_sort_indices = Parameter(torch.empty(
                    0, dtype=torch.int, device=cur_device),
                                                     requires_grad=False)

            # Repack weights
            marlin_qweight = ops.gptq_marlin_repack(
                layer.qweight,
                layer.g_idx_sort_indices,
                part_size_k,
                part_size_n,
            )
            replace_tensor("qweight", marlin_qweight)

            # Permute scales
            scales_size_k = part_size_k
            scales_size_n = part_size_n
            if self.quant_config.desc_act:
                scales_size_k = full_size_k

            marlin_scales = marlin_permute_scales(layer.scales, scales_size_k,
                                                  scales_size_n,
                                                  self.quant_config.group_size)
            replace_tensor("scales", marlin_scales)

        output = ops.gptq_marlin_gemm(reshaped_x, layer.qweight, layer.scales,
                                      layer.g_idx, layer.g_idx_sort_indices,
                                      layer.workspace, size_m, part_size_n,
                                      part_size_k, layer.is_k_full)

        if bias is not None:
            output.add_(bias)  # In-place add

        return output.reshape(out_shape)
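
# Usage sketch (illustrative only; `layer` is assumed to be a LinearBase
# module built elsewhere): build the config straight from a GPTQ checkpoint's
# quantize_config.json and check Marlin compatibility before wiring in the
# linear method.
#
#   quant_dict = {"bits": 4, "group_size": 128, "desc_act": False, "sym": True}
#   if GPTQMarlinConfig.is_marlin_compatible(quant_dict):
#       quant_config = GPTQMarlinConfig.from_config(quant_dict)
#       linear_method = quant_config.get_quant_method(layer)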