gptq_marlin_24.py

from typing import Any, Dict, List, Optional
from contextlib import suppress

import torch
from torch.nn.parameter import Parameter
from loguru import logger

from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.modeling.utils import set_weight_attrs

HAS_QUANTS = False
with suppress(ImportError):
    from aphrodite._quant_C import quant_ops as ops
    HAS_QUANTS = True

GPTQ_MARLIN_24_TILE = 16
GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 64

GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
GPTQ_MARLIN_24_SUPPORTED_SYM = [True]
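
# Note (explanatory, not from the original file): the "24" in the name refers
# to 2:4 structured sparsity -- the kernel expects GPTQ-quantized weights
# pruned so that only two of every four consecutive values along the input
# dimension are non-zero. The tile size and thread minimums above mirror the
# launch constraints of the Marlin 2:4 kernel.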


class GPTQMarlin24Config(QuantizationConfig):
    """Config class for Marlin24."""

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size

        # Verify
        if self.weight_bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
            raise ValueError(
                f"Marlin_24 does not support weight_bits = {self.weight_bits}. "
                f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_NUM_BITS} "
                "are supported.")
        if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
            raise ValueError(
                f"Marlin_24 does not support group_size = {self.group_size}. "
                f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} "
                "are supported.")

        # Number of quantized weights packed into a 32-bit integer.
        self.pack_factor = 32 // self.weight_bits
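
        # For example, weight_bits = 4 gives pack_factor = 32 // 4 = 8
        # quantized values per int32 word, and weight_bits = 8 gives 4.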

        # Tile size used by marlin kernels.
        self.tile_size = 16

        # Min out_features dim
        self.min_n_threads = GPTQ_MARLIN_24_MIN_THREAD_N

        # Min in_features dim
        self.min_k_threads = GPTQ_MARLIN_24_MIN_THREAD_K

        # Max parallel problems to solve at once (improves large
        # batch performance)
        self.max_parallel = GPTQ_MARLIN_24_MAX_PARALLEL

        # Permutation length used by the marlin kernels.
        self.perm_len = 1024

    def __repr__(self) -> str:
        return "Marlin24Config(weight_bits={}, group_size={})".format(
            self.weight_bits, self.group_size)

    @classmethod
    def get_name(cls) -> str:
        return "gptq_marlin_24"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # Requires compute capability 8.0 (Ampere) or newer.
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlin24Config":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(weight_bits, group_size)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        is_marlin_24_format = (
            hf_quant_cfg.get("checkpoint_format") == "marlin_24")

        is_valid_user_quant = (user_quant is None or user_quant == "gptq"
                               or user_quant == "gptq_marlin_24")

        if is_marlin_24_format and is_valid_user_quant:
            msg = ("The model is serialized in {} format. "
                   "Using {} kernel.".format(cls.get_name(), cls.get_name()))
            logger.info(msg)
            return cls.get_name()

        return None

    def get_quant_method(
            self,
            layer: torch.nn.Module) -> Optional["GPTQMarlin24LinearMethod"]:
        if isinstance(layer, LinearBase):
            return GPTQMarlin24LinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []
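
# Example (illustrative): a checkpoint whose quantize_config.json contains
#   {"bits": 4, "group_size": 128, "checkpoint_format": "marlin_24"}
# is parsed by `from_config` into GPTQMarlin24Config(weight_bits=4,
# group_size=128), and `override_quantization_method` then selects the
# "gptq_marlin_24" kernel for it.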


class GPTQMarlin24LinearMethod(LinearMethodBase):
    """Linear method for Marlin24.

    Args:
        quant_config: The Marlin24 quantization config.
    """

    def __init__(self, quant_config: GPTQMarlin24Config):
        if not HAS_QUANTS:
            raise ImportError("Could not find the quantization kernels.")
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del output_size  # Unused.

        if params_dtype != torch.float16:
            raise ValueError(
                f"The params dtype must be float16, but got {params_dtype}")

        # Validate output_size_per_partition
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.min_n_threads != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"min_n_threads = {self.quant_config.min_n_threads}.")
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"pack_factor = {self.quant_config.pack_factor}.")

        # Validate input_size_per_partition
        if input_size_per_partition % self.quant_config.min_k_threads != 0:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible by "
                f"min_k_threads = {self.quant_config.min_k_threads}.")
        if (self.quant_config.group_size != -1 and
                input_size_per_partition % self.quant_config.group_size != 0):
            raise ValueError(f"Weight input_size_per_partition = "
                             f"{input_size_per_partition} is not divisible by "
                             f"group_size = {self.quant_config.group_size}.")

        # Check that we have at least 4 tiles horizontally in the shard
        num_tiles_per_perm = self.quant_config.perm_len // (
            self.quant_config.tile_size**2)
        if output_size_per_partition % num_tiles_per_perm != 0:
            raise ValueError(
                "Each permutation group must reside on the same gpu")
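
        # With the values above (perm_len = 1024, tile_size = 16),
        # num_tiles_per_perm = 1024 // 256 = 4, so this check requires
        # output_size_per_partition to be a multiple of 4.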

        # Quantized weights packed into int32 words.
        qweight = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.tile_size // 2,
                output_size_per_partition * self.quant_config.tile_size //
                self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight,
            {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
                "marlin_tile_size": self.quant_config.tile_size,
            },
        )
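
        # Shape note: the leading dim is input_size // tile_size // 2 -- the
        # extra // 2 reflects the 2:4 sparse format storing only half of the
        # original weights -- while the trailing dim holds tile_size quantized
        # values per output column packed pack_factor-at-a-time into int32.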

        # Meta
        meta = Parameter(
            torch.empty(
                input_size_per_partition // 8 // 2 // 2,
                output_size_per_partition * 2,
                device="cuda",
                dtype=torch.int16,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            meta,
            {
                "input_dim": 0,
                "packed_dim": 1,
                "pack_factor": 1,
                "output_dim": 1,
                "marlin_tile_size": 2,
            },
        )
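
        # Note: `meta` holds the 2:4 sparsity metadata (positions of the
        # retained values) consumed by the sparse tensor-core path; its exact
        # packing is defined by the marlin_24 kernel/checkpoint format.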

        # Determine if channelwise or not
        input_groups = (1 if self.quant_config.group_size == -1 else
                        input_size_per_partition //
                        self.quant_config.group_size)

        scales = Parameter(
            torch.empty(
                input_groups,
                output_size_per_partition,
                device="cuda",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            scales,
            {
                "input_dim": None if input_groups == 1 else 0,
                "output_dim": 1,
            },
        )

        # Allocate workspace (used for internal locking mechanism)
        max_workspace_size = (
            output_size_per_partition //
            self.quant_config.min_n_threads) * self.quant_config.max_parallel
        workspace = Parameter(torch.zeros(max_workspace_size,
                                          device="cuda",
                                          dtype=torch.int),
                              requires_grad=False)

        layer.register_parameter("B_24", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("B_meta", meta)
        set_weight_attrs(meta, extra_weight_attrs)
        layer.register_parameter("s", scales)
        set_weight_attrs(scales, extra_weight_attrs)
        layer.register_parameter("workspace", workspace)
        set_weight_attrs(workspace, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        qweight = layer.B_24
        meta = layer.B_meta
        scales = layer.s
        workspace = layer.workspace

        x_2d = x.view(-1, x.shape[-1])

        size_m = x_2d.shape[0]
        size_k = x_2d.shape[1]
        size_n = scales.shape[1]

        output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
                                            workspace,
                                            self.quant_config.weight_bits,
                                            size_m, size_n, size_k)

        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
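
        # Shape walk-through (illustrative): for x of shape
        # [batch, seq_len, hidden], x_2d is [batch * seq_len, hidden]
        # (size_m x size_k), the kernel produces [size_m, size_n], and the
        # view above restores [batch, seq_len, size_n].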

        if bias is not None:
            output.add_(bias)  # In-place add

        return output