gptq_marlin_24.py

from typing import Any, Dict, List, Optional

import torch
from loguru import logger
from torch.nn.parameter import Parameter

from aphrodite import _custom_ops as ops
from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.scalar_type import scalar_types
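
# The constants below mirror the tiling and launch constraints of the fused
# gptq_marlin_24_gemm kernel: 16x16 tiles, minimum thread counts along the
# N/K dims, and the maximum number of sub-problems solved in parallel.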
GPTQ_MARLIN_24_TILE = 16
GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 64

GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [
    scalar_types.uint4b8, scalar_types.uint8b128
]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]


class GPTQMarlin24Config(QuantizationConfig):
    """Config class for Marlin24."""

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
    ) -> None:
        quant_type = {
            4: scalar_types.uint4b8,
            8: scalar_types.uint8b128,
        }.get(weight_bits)

        self.group_size = group_size

        # Verify supported weight bits and group size.
        if quant_type is None or \
                quant_type not in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES:
            raise ValueError(
                f"Marlin_24 does not support weight_bits = {weight_bits}. "
                f"Only quant types {GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES} "
                "are supported.")
        if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
            raise ValueError(
                f"Marlin_24 does not support group_size = {self.group_size}. "
                f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} "
                "are supported.")

        self.quant_type = quant_type

        # Quantized values are packed into a 32-bit datatype
        # (32 // size_bits values per int32 word).
        self.pack_factor = 32 // self.quant_type.size_bits
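        # e.g. 4-bit weights -> pack_factor == 8, 8-bit -> pack_factor == 4.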

        # Tile size used by marlin kernels.
        self.tile_size = 16

        # Min out_features dim
        self.min_n_threads = GPTQ_MARLIN_24_MIN_THREAD_N

        # Min in_features dim
        self.min_k_threads = GPTQ_MARLIN_24_MIN_THREAD_K

        # Max parallel problems to solve at once (improves large
        # batch performance)
        self.max_parallel = GPTQ_MARLIN_24_MAX_PARALLEL

        # Permutation length used by the marlin kernels.
        self.perm_len = 1024

    def __repr__(self) -> str:
        return "Marlin24Config(quant_type={}, group_size={})".format(
            self.quant_type, self.group_size)

    @classmethod
    def get_name(cls) -> str:
        return "gptq_marlin_24"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    # TODO: verify the minimum compute capability required.
    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlin24Config":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(weight_bits, group_size)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        is_marlin_24_format = (
            hf_quant_cfg.get("checkpoint_format") == "marlin_24")

        is_valid_user_quant = (user_quant is None or user_quant == "gptq"
                               or user_quant == "gptq_marlin_24")

        if is_marlin_24_format and is_valid_user_quant:
            msg = ("The model is serialized in {} format. "
                   "Using {} kernel.".format(cls.get_name(), cls.get_name()))
            logger.info(msg)
            return cls.get_name()

        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["GPTQMarlin24LinearMethod"]:
        if isinstance(layer, LinearBase):
            return GPTQMarlin24LinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class GPTQMarlin24LinearMethod(LinearMethodBase):
    """Linear method for Marlin24.

    Args:
        quant_config: The Marlin24 quantization config.
    """

    def __init__(self, quant_config: GPTQMarlin24Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del output_size  # Unused.

        if params_dtype != torch.float16:
            raise ValueError(
                f"The params dtype must be float16, but got {params_dtype}")

        # Validate output_size_per_partition
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.min_n_threads != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"min_n_threads = {self.quant_config.min_n_threads}.")
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"pack_factor = {self.quant_config.pack_factor}.")

        # Validate input_size_per_partition
        if input_size_per_partition % self.quant_config.min_k_threads != 0:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible by "
                f"min_k_threads = {self.quant_config.min_k_threads}.")
        if (self.quant_config.group_size != -1 and
                input_size_per_partition % self.quant_config.group_size != 0):
            raise ValueError(f"Weight input_size_per_partition = "
                             f"{input_size_per_partition} is not divisible by "
                             f"group_size = {self.quant_config.group_size}.")

        # Check that we have at least 4 tiles horizontally in the shard
        num_tiles_per_perm = self.quant_config.perm_len // (
            self.quant_config.tile_size**2)
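        # perm_len // tile_size**2 = 1024 // 256 = 4 tiles per permutation
        # group; the shard's output size must be a multiple of this so a
        # permutation group never straddles GPUs.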
        if output_size_per_partition % num_tiles_per_perm != 0:
            raise ValueError(
                "Each permutation group must reside on the same GPU")

        # Quantized weights, packed into int32 in the Marlin tile layout.
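        # Shape: (k // tile_size // 2, n * tile_size // pack_factor). The
        # extra // 2 along the input dim comes from 2:4 sparsity (only half
        # of the values along k are stored); tiles are flattened into the
        # output dim with pack_factor quantized values per int32 word.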
        qweight = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.tile_size // 2,
                output_size_per_partition * self.quant_config.tile_size //
                self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight,
            {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
                "marlin_tile_size": self.quant_config.tile_size,
            },
        )

        # 2:4 sparsity metadata (nonzero-position indices consumed by the
        # sparse tensor cores).
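        # Shape: (k // 32, n * 2) int16 words; the exact packing is whatever
        # the gptq_marlin_24_gemm kernel expects for its meta operand.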
        meta = Parameter(
            torch.empty(
                input_size_per_partition // 8 // 2 // 2,
                output_size_per_partition * 2,
                device="cuda",
                dtype=torch.int16,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            meta,
            {
                "input_dim": 0,
                "packed_dim": 1,
                "pack_factor": 1,
                "output_dim": 1,
                "marlin_tile_size": 2,
            },
        )

        # Determine if channelwise or not
        input_groups = (1 if self.quant_config.group_size == -1 else
                        input_size_per_partition //
                        self.quant_config.group_size)
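        # e.g. k = 4096 with group_size = 128 -> 32 scale groups; with
        # group_size = -1 (channelwise) a single group covers all of k.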

        scales = Parameter(
            torch.empty(
                input_groups,
                output_size_per_partition,
                device="cuda",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            scales,
            {
                "input_dim": None if input_groups == 1 else 0,
                "output_dim": 1,
            },
        )

        # Allocate workspace (used by the kernel's internal locking mechanism).
        max_workspace_size = (
            output_size_per_partition //
            self.quant_config.min_n_threads) * self.quant_config.max_parallel
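        # One slot per (n // min_n_threads) slice, replicated max_parallel
        # times; zero-initialized because the kernel uses it for cross-block
        # synchronization.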
        workspace = Parameter(torch.zeros(max_workspace_size,
                                          device="cuda",
                                          dtype=torch.int),
                              requires_grad=False)
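
        # Parameter names below ("B_24", "B_meta", "s") presumably match the
        # tensor names in marlin_24-format checkpoints so the weight loader
        # can map them directly.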
        layer.register_parameter("B_24", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("B_meta", meta)
        set_weight_attrs(meta, extra_weight_attrs)
        layer.register_parameter("s", scales)
        set_weight_attrs(scales, extra_weight_attrs)
        layer.register_parameter("workspace", workspace)
        set_weight_attrs(workspace, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        qweight = layer.B_24
        meta = layer.B_meta
        scales = layer.s
        workspace = layer.workspace

        x_2d = x.view(-1, x.shape[-1])

        size_m = x_2d.shape[0]
        size_k = x_2d.shape[1]
        size_n = scales.shape[1]
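
        # Problem shape for the kernel: (m x k) @ (k x n) -> (m x n), where m
        # is the number of flattened input rows (tokens).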
        output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
                                            workspace,
                                            self.quant_config.quant_type,
                                            size_m, size_n, size_k)

        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))

        if bias is not None:
            output.add_(bias)  # In-place add

        return output