gptq_marlin_24.py

from typing import Any, Dict, List, Optional
from contextlib import suppress

import torch
from torch.nn.parameter import Parameter
from loguru import logger

from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.modeling.utils import set_weight_attrs

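# The compiled CUDA quantization kernels are optional at import time; the
# linear method below raises ImportError if they turn out to be missing.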
HAS_QUANTS = False
with suppress(ImportError):
    from aphrodite._quant_C import quant_ops as ops
    HAS_QUANTS = True


class GPTQMarlin24Config(QuantizationConfig):
    """Config class for Marlin24."""

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size

        if self.weight_bits != 4 and self.weight_bits != 8:
            raise ValueError("weight_bits must be 4 or 8. Got = {}".format(
                self.weight_bits))
        if self.group_size != 128 and self.group_size != -1:
            raise ValueError(
                "Currently, only group size 128 and -1 (channelwise) "
                "is supported for Marlin24, but got group_size of "
                f"{self.group_size}")

        # Quantized weight bits packed into a 32-bit datatype.
        self.pack_factor = 32 // self.weight_bits

        # Tile size used by the marlin kernels.
        self.tile_size = 16

        # Min out_features dim.
        self.min_n_threads = 128

        # Min in_features dim.
        self.min_k_threads = 128

        # Max parallel problems to solve at once (improves large
        # batch performance).
        self.max_parallel = 16

        # Permutation length used by the marlin kernels.
        self.perm_len = 1024
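
        # Example: with 4-bit weights, pack_factor = 32 // 4 = 8, so each
        # int32 in the packed weight holds eight quantized values; with
        # 8-bit weights, pack_factor = 4.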

    def __repr__(self) -> str:
        return "Marlin24Config(weight_bits={}, group_size={})".format(
            self.weight_bits, self.group_size)

    @classmethod
    def get_name(cls) -> str:
        return "gptq_marlin_24"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    # TODO: verify the minimum compute capability actually required.
    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

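    # A typical GPTQ quantize_config.json supplies at least the keys read in
    # from_config below, e.g. {"bits": 4, "group_size": 128} (illustrative
    # values; real configs usually carry additional GPTQ fields).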
    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlin24Config":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(weight_bits, group_size)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        is_marlin_24_format = (
            hf_quant_cfg.get("checkpoint_format") == "marlin_24")

        is_valid_user_quant = (user_quant is None or user_quant == "gptq"
                               or user_quant == "gptq_marlin_24")

        if is_marlin_24_format and is_valid_user_quant:
            msg = ("The model is serialized in {} format. "
                   "Using {} kernel.".format(cls.get_name(), cls.get_name()))
            logger.info(msg)
            return cls.get_name()

        return None

    def get_quant_method(
            self,
            layer: torch.nn.Module) -> Optional["GPTQMarlin24LinearMethod"]:
        if isinstance(layer, LinearBase):
            return GPTQMarlin24LinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class GPTQMarlin24LinearMethod(LinearMethodBase):
    """Linear method for Marlin24.

    Args:
        quant_config: The Marlin24 quantization config.
    """

    def __init__(self, quant_config: GPTQMarlin24Config):
        if not HAS_QUANTS:
            raise ImportError("Could not find the quantization kernels.")
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del output_size  # Unused.

        if params_dtype != torch.float16:
            raise ValueError(
                f"The params dtype must be float16, but got {params_dtype}")

        # Validate output_size_per_partition
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.min_n_threads != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"min_n_threads = {self.quant_config.min_n_threads}.")
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"pack_factor = {self.quant_config.pack_factor}.")

        # Validate input_size_per_partition
        if input_size_per_partition % self.quant_config.min_k_threads != 0:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible by "
                f"min_k_threads = {self.quant_config.min_k_threads}.")
        if (self.quant_config.group_size != -1 and
                input_size_per_partition % self.quant_config.group_size != 0):
            raise ValueError(f"Weight input_size_per_partition = "
                             f"{input_size_per_partition} is not divisible by "
                             f"group_size = {self.quant_config.group_size}.")

        # Check that we have at least 4 tiles horizontally in the shard
        num_tiles_per_perm = self.quant_config.perm_len // (
            self.quant_config.tile_size**2)
        if output_size_per_partition % num_tiles_per_perm != 0:
            raise ValueError(
                "Each permutation group must reside on the same gpu")

        # Quantized weights packed into int32 (the 2:4 sparsity compression
        # halves the K dimension of the packed tensor).
        qweight = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.tile_size // 2,
                output_size_per_partition * self.quant_config.tile_size //
                self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight,
            {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
                "marlin_tile_size": self.quant_config.tile_size,
            },
        )

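        # Shape sketch (4-bit, K = N = 4096): qweight is
        # (4096 // 16 // 2, 4096 * 16 // 8) = (128, 8192) int32 values.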
        # Metadata describing the 2:4 sparsity pattern.
        meta = Parameter(
            torch.empty(
                input_size_per_partition // 8 // 2 // 2,
                output_size_per_partition * 2,
                device="cuda",
                dtype=torch.int16,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            meta,
            {
                "input_dim": 0,
                "packed_dim": 1,
                "pack_factor": 1,
                "output_dim": 1,
                "marlin_tile_size": 2,
            },
        )

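        # Size check: (K // 32) rows x (2 * N) int16 columns is K * N / 16
        # int16 entries, i.e. one bit of metadata per dense weight element.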
        # Determine if channelwise or not
        input_groups = (1 if self.quant_config.group_size == -1 else
                        input_size_per_partition //
                        self.quant_config.group_size)

        scales = Parameter(
            torch.empty(
                input_groups,
                output_size_per_partition,
                device="cuda",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            scales,
            {
                "input_dim": None if input_groups == 1 else 0,
                "output_dim": 1,
            },
        )

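        # e.g. group_size=128 with K=4096 gives 32 rows of scales; with
        # group_size=-1 (channelwise) there is a single row of per-channel
        # scales.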
        # Allocate workspace (used for an internal locking mechanism).
        max_workspace_size = (
            output_size_per_partition //
            self.quant_config.min_n_threads) * self.quant_config.max_parallel
        workspace = Parameter(torch.zeros(max_workspace_size,
                                          device="cuda",
                                          dtype=torch.int),
                              requires_grad=False)

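        # The parameter names below appear to mirror the tensor names used by
        # the serialized marlin_24 checkpoint format, so loaded weights map
        # onto them directly.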
        layer.register_parameter("B_24", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("B_meta", meta)
        set_weight_attrs(meta, extra_weight_attrs)
        layer.register_parameter("s", scales)
        set_weight_attrs(scales, extra_weight_attrs)
        layer.register_parameter("workspace", workspace)
        set_weight_attrs(workspace, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        qweight = layer.B_24
        meta = layer.B_meta
        scales = layer.s
        workspace = layer.workspace

        x_2d = x.view(-1, x.shape[-1])

        size_m = x_2d.shape[0]
        size_k = x_2d.shape[1]
        size_n = scales.shape[1]

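        # m: flattened batch rows, k: input features, n: output features;
        # the kernel computes x_2d @ dequantized(B) with shape (m, n).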
        output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
                                            workspace,
                                            self.quant_config.weight_bits,
                                            size_m, size_n, size_k)

        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))

        if bias is not None:
            output.add_(bias)  # In-place add

        return output
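
# Illustrative usage sketch (not part of the module; assumes an Aphrodite
# loading context where `linear_layer` is a LinearBase instance and the
# partition sizes come from the model's parallelism setup):
#
#     config = GPTQMarlin24Config.from_config({"bits": 4, "group_size": 128})
#     method = config.get_quant_method(linear_layer)
#     method.create_weights(linear_layer, input_size_per_partition,
#                           output_partition_sizes, input_size, output_size,
#                           torch.float16)
#     y = method.apply(linear_layer, x)  # x: (..., input_size_per_partition)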