1
0

gptq_marlin_24.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. from typing import Any, Dict, List, Optional
  2. import torch
  3. from loguru import logger
  4. from torch.nn.parameter import Parameter
  5. from aphrodite import _custom_ops as ops
  6. from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
  7. from aphrodite.modeling.utils import set_weight_attrs
  8. from aphrodite.quantization.base_config import QuantizationConfig
  9. GPTQ_MARLIN_24_TILE = 16
  10. GPTQ_MARLIN_24_MIN_THREAD_N = 128
  11. GPTQ_MARLIN_24_MIN_THREAD_K = 128
  12. GPTQ_MARLIN_24_MAX_PARALLEL = 64
  13. GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
  14. GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
  15. GPTQ_MARLIN_24_SUPPORTED_SYM = [True]
  16. class GPTQMarlin24Config(QuantizationConfig):
  17. """Config class for Marlin24.
  18. """
  19. def __init__(
  20. self,
  21. weight_bits: int,
  22. group_size: int,
  23. ) -> None:
  24. self.weight_bits = weight_bits
  25. self.group_size = group_size
  26. # Verify
  27. if self.weight_bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
  28. raise ValueError(
  29. f"Marlin_24 does not support weight_bits = {self.weight_bits}. "
  30. f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_NUM_BITS} "
  31. "are supported.")
  32. if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
  33. raise ValueError(
  34. f"Marlin_24 does not support group_size = {self.group_size}. "
  35. f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} "
  36. "are supported.")
  37. # 4 Bits packed into 32 bit datatype.
  38. self.pack_factor = 32 // self.weight_bits
  39. # Tile size used by marlin kernels.
  40. self.tile_size = 16
  41. # Min out_features dim
  42. self.min_n_threads = GPTQ_MARLIN_24_MIN_THREAD_N
  43. # Min in_features dim
  44. self.min_k_threads = GPTQ_MARLIN_24_MIN_THREAD_K
  45. # Max parallel problems to solve at once (improves large
  46. # batch performance)
  47. self.max_parallel = GPTQ_MARLIN_24_MAX_PARALLEL
  48. # Permutation length used by the marlin kernels.
  49. self.perm_len = 1024
  50. def __repr__(self) -> str:
  51. return "Marlin24Config(weight_bits={}, group_size={})".format(
  52. self.weight_bits, self.group_size)
  53. @classmethod
  54. def get_name(cls) -> str:
  55. return "gptq_marlin_24"
  56. @classmethod
  57. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  58. return [torch.half]
  59. @classmethod
  60. # Need to figure it out
  61. def get_min_capability(cls) -> int:
  62. return 80
  63. @classmethod
  64. def get_config_filenames(cls) -> List[str]:
  65. return ["quantize_config.json"]
  66. @classmethod
  67. def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlin24Config":
  68. weight_bits = cls.get_from_keys(config, ["bits"])
  69. group_size = cls.get_from_keys(config, ["group_size"])
  70. return cls(weight_bits, group_size)
  71. @classmethod
  72. def override_quantization_method(cls, hf_quant_cfg,
  73. user_quant) -> Optional[str]:
  74. is_marlin_24_format = (
  75. hf_quant_cfg.get("checkpoint_format") == "marlin_24")
  76. is_valid_user_quant = (user_quant is None or user_quant == "gptq"
  77. or user_quant == "gptq_marlin_24")
  78. if is_marlin_24_format and is_valid_user_quant:
  79. msg = ("The model is serialized in {} format. "
  80. "Using {} kernel.".format(cls.get_name(), cls.get_name()))
  81. logger.info(msg)
  82. return cls.get_name()
  83. return None
  84. def get_quant_method(
  85. self,
  86. layer: torch.nn.Module) -> Optional["GPTQMarlin24LinearMethod"]:
  87. if isinstance(layer, LinearBase):
  88. return GPTQMarlin24LinearMethod(self)
  89. return None
  90. def get_scaled_act_names(self) -> List[str]:
  91. return []
  92. class GPTQMarlin24LinearMethod(LinearMethodBase):
  93. """Linear method for Marlin24.
  94. Args:
  95. quant_config: The Marlin24 quantization config.
  96. """
  97. def __init__(self, quant_config: GPTQMarlin24Config):
  98. self.quant_config = quant_config
  99. def create_weights(
  100. self,
  101. layer: torch.nn.Module,
  102. input_size_per_partition: int,
  103. output_partition_sizes: List[int],
  104. input_size: int,
  105. output_size: int,
  106. params_dtype: torch.dtype,
  107. **extra_weight_attrs,
  108. ):
  109. del output_size # Unused.
  110. if params_dtype != torch.float16:
  111. raise ValueError(
  112. f"The params dtype must be float16, but got {params_dtype}")
  113. # Validate output_size_per_partition
  114. output_size_per_partition = sum(output_partition_sizes)
  115. if output_size_per_partition % self.quant_config.min_n_threads != 0:
  116. raise ValueError(
  117. f"Weight output_size_per_partition = "
  118. f"{output_size_per_partition} is not divisible by "
  119. f"min_n_threads = {self.quant_config.min_n_threads}.")
  120. if output_size_per_partition % self.quant_config.pack_factor != 0:
  121. raise ValueError(
  122. f"Weight output_size_per_partition = "
  123. f"{output_size_per_partition} is not divisible by "
  124. f"pack_factor = {self.quant_config.pack_factor}.")
  125. # Validate input_size_per_partition
  126. if input_size_per_partition % self.quant_config.min_k_threads != 0:
  127. raise ValueError(
  128. f"Weight input_size_per_partition = "
  129. f"{input_size_per_partition} is not divisible by "
  130. f"min_k_threads = {self.quant_config.min_k_threads}.")
  131. if (self.quant_config.group_size != -1 and
  132. input_size_per_partition % self.quant_config.group_size != 0):
  133. raise ValueError(f"Weight input_size_per_partition = "
  134. f"{input_size_per_partition} is not divisible by "
  135. f"group_size = {self.quant_config.group_size}.")
  136. # Check that we have at least 4 tiles horizontally in the shard
  137. num_tiles_per_perm = self.quant_config.perm_len // (
  138. self.quant_config.tile_size**2)
  139. if output_size_per_partition % num_tiles_per_perm != 0:
  140. raise ValueError(
  141. "Each permutation group must reside on the same gpu")
  142. # Quantized 4Bit weights packed into Int32.
  143. qweight = Parameter(
  144. torch.empty(
  145. input_size_per_partition // self.quant_config.tile_size // 2,
  146. output_size_per_partition * self.quant_config.tile_size //
  147. self.quant_config.pack_factor,
  148. device="cuda",
  149. dtype=torch.int32,
  150. ),
  151. requires_grad=False,
  152. )
  153. set_weight_attrs(
  154. qweight,
  155. {
  156. "input_dim": 0,
  157. "output_dim": 1,
  158. "packed_dim": 1,
  159. "pack_factor": self.quant_config.pack_factor,
  160. "marlin_tile_size": self.quant_config.tile_size,
  161. },
  162. )
  163. # Meta
  164. meta = Parameter(
  165. torch.empty(
  166. input_size_per_partition // 8 // 2 // 2,
  167. output_size_per_partition * 2,
  168. device="cuda",
  169. dtype=torch.int16,
  170. ),
  171. requires_grad=False,
  172. )
  173. set_weight_attrs(
  174. meta,
  175. {
  176. "input_dim": 0,
  177. "packed_dim": 1,
  178. "pack_factor": 1,
  179. "output_dim": 1,
  180. "marlin_tile_size": 2,
  181. },
  182. )
  183. # Determine if channelwise or not
  184. input_groups = (1 if self.quant_config.group_size == -1 else
  185. input_size_per_partition //
  186. self.quant_config.group_size)
  187. scales = Parameter(
  188. torch.empty(
  189. input_groups,
  190. output_size_per_partition,
  191. device="cuda",
  192. dtype=params_dtype,
  193. ),
  194. requires_grad=False,
  195. )
  196. set_weight_attrs(
  197. scales,
  198. {
  199. "input_dim": None if input_groups == 1 else 0,
  200. "output_dim": 1,
  201. },
  202. )
  203. # Allocate workspace (Used for internal locking mechanism)
  204. max_workspace_size = (
  205. output_size_per_partition //
  206. self.quant_config.min_n_threads) * self.quant_config.max_parallel
  207. workspace = Parameter(torch.zeros(max_workspace_size,
  208. device="cuda",
  209. dtype=torch.int),
  210. requires_grad=False)
  211. layer.register_parameter("B_24", qweight)
  212. set_weight_attrs(qweight, extra_weight_attrs)
  213. layer.register_parameter("B_meta", meta)
  214. set_weight_attrs(meta, extra_weight_attrs)
  215. layer.register_parameter("s", scales)
  216. set_weight_attrs(scales, extra_weight_attrs)
  217. layer.register_parameter("workspace", workspace)
  218. set_weight_attrs(workspace, extra_weight_attrs)
  219. def apply(
  220. self,
  221. layer: torch.nn.Module,
  222. x: torch.Tensor,
  223. bias: Optional[torch.Tensor] = None,
  224. ) -> torch.Tensor:
  225. qweight = layer.B_24
  226. meta = layer.B_meta
  227. scales = layer.s
  228. workspace = layer.workspace
  229. x_2d = x.view(-1, x.shape[-1])
  230. size_m = x_2d.shape[0]
  231. size_k = x_2d.shape[1]
  232. size_n = scales.shape[1]
  233. output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
  234. workspace,
  235. self.quant_config.weight_bits,
  236. size_m, size_n, size_k)
  237. output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
  238. if bias is not None:
  239. output.add_(bias) # In-place add
  240. return output