qqq.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. from typing import Any, Dict, List, Optional
  2. import torch
  3. from torch.nn.parameter import Parameter
  4. from aphrodite import _custom_ops as ops
  5. from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
  6. from aphrodite.modeling.utils import set_weight_attrs
  7. from aphrodite.quantization.base_config import QuantizationConfig
  8. MARLIN_QQQ_TILE = 16
  9. MARLIN_QQQ_MIN_THREAD_N = 64
  10. MARLIN_QQQ_MIN_THREAD_K = 128
  11. MARLIN_QQQ_MAX_PARALLEL = 16
  12. MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
  13. MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128]
  14. MARLIN_QQQ_SUPPORTED_SYM = [True]
  15. class QQQConfig(QuantizationConfig):
  16. """Config class for QQQ
  17. Reference: https://arxiv.org/pdf/2406.09904
  18. """
  19. def __init__(
  20. self,
  21. weight_bits: int,
  22. group_size: int,
  23. is_sym: bool = True,
  24. ) -> None:
  25. self.weight_bits = weight_bits
  26. self.group_size = group_size
  27. self.is_sym = is_sym
  28. # Verify
  29. if self.weight_bits not in MARLIN_QQQ_SUPPORTED_NUM_BITS:
  30. raise ValueError(
  31. f"QQQ does not support weight_bits = {self.weight_bits}. "
  32. f"Only weight_bits = {MARLIN_QQQ_SUPPORTED_NUM_BITS} "
  33. "are supported.")
  34. if self.group_size not in MARLIN_QQQ_SUPPORTED_GROUP_SIZES:
  35. raise ValueError(
  36. f"QQQ does not support group_size = {self.group_size}. "
  37. f"Only group_sizes = {MARLIN_QQQ_SUPPORTED_GROUP_SIZES} "
  38. "are supported.")
  39. if self.is_sym not in MARLIN_QQQ_SUPPORTED_SYM:
  40. raise ValueError(
  41. f"QQQ does not support is_sym = {self.is_sym}. "
  42. f"Only sym = {MARLIN_QQQ_SUPPORTED_SYM} are supported.")
  43. # 4 Bits packed into 32 bit datatype.
  44. self.pack_factor = 32 // self.weight_bits
  45. # Tile size used by QQQ kernels.
  46. self.tile_size = MARLIN_QQQ_TILE
  47. # Min out_features dim
  48. self.min_n_threads = MARLIN_QQQ_MIN_THREAD_N
  49. # Min in_features dim
  50. self.min_k_threads = MARLIN_QQQ_MIN_THREAD_K
  51. # Max parallel problems to solve at once (improves large
  52. # batch performance)
  53. self.max_parallel = MARLIN_QQQ_MAX_PARALLEL
  54. # Permutation length used by the QQQ kernels.
  55. self.perm_len = 1024
  56. def __repr__(self) -> str:
  57. return "QQQConfig(weight_bits={}, group_size={})".format(
  58. self.weight_bits, self.group_size)
  59. @classmethod
  60. def get_name(cls) -> str:
  61. return "qqq"
  62. @classmethod
  63. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  64. return [torch.half]
  65. @classmethod
  66. def get_min_capability(cls) -> int:
  67. return 80
  68. @classmethod
  69. def get_config_filenames(cls) -> List[str]:
  70. """List of filenames to search for in the model directory."""
  71. return [
  72. "quant_config.json",
  73. "quantize_config.json",
  74. ]
  75. @classmethod
  76. def from_config(cls, config: Dict[str, Any]) -> "QQQConfig":
  77. weight_bits = cls.get_from_keys(config, ["wbits"])
  78. group_size = cls.get_from_keys(config, ["group_size"])
  79. return cls(weight_bits, group_size)
  80. def get_quant_method(self, layer: torch.nn.Module,
  81. prefix: str) -> Optional["QQQLinearMethod"]:
  82. if isinstance(layer, LinearBase):
  83. return QQQLinearMethod(self)
  84. return None
  85. def get_scaled_act_names(self) -> List[str]:
  86. return []
  87. class QQQLinearMethod(LinearMethodBase):
  88. """Linear method for QQQ.
  89. Args:
  90. quant_config: The QQQ quantization config.
  91. """
  92. def __init__(self, quant_config: QQQConfig):
  93. self.quant_config = quant_config
  94. def create_weights(
  95. self,
  96. layer: torch.nn.Module,
  97. input_size_per_partition: int,
  98. output_partition_sizes: List[int],
  99. input_size: int,
  100. output_size: int,
  101. params_dtype: torch.dtype,
  102. **extra_weight_attrs,
  103. ):
  104. if params_dtype != torch.float16:
  105. raise ValueError(
  106. f"The params dtype must be float16, but got {params_dtype}")
  107. # Validate output_size_per_partition
  108. output_size_per_partition = sum(output_partition_sizes)
  109. if output_size_per_partition % self.quant_config.min_n_threads != 0:
  110. raise ValueError(
  111. f"Weight output_size_per_partition = "
  112. f"{output_size_per_partition} is not divisible by "
  113. f"min_n_threads = {self.quant_config.min_n_threads}.")
  114. if output_size_per_partition % self.quant_config.pack_factor != 0:
  115. raise ValueError(
  116. f"Weight output_size_per_partition = "
  117. f"{output_size_per_partition} is not divisible by "
  118. f"pack_factor = {self.quant_config.pack_factor}.")
  119. # Validate input_size_per_partition
  120. if input_size_per_partition % self.quant_config.min_k_threads != 0:
  121. raise ValueError(
  122. f"Weight input_size_per_partition = "
  123. f"{input_size_per_partition} is not divisible by "
  124. f"min_k_threads = {self.quant_config.min_k_threads}.")
  125. if (self.quant_config.group_size != -1 and
  126. input_size_per_partition % self.quant_config.group_size != 0):
  127. raise ValueError(f"Weight input_size_per_partition = "
  128. f"{input_size_per_partition} is not divisible by "
  129. f"group_size = {self.quant_config.group_size}.")
  130. # Check that we have at least 4 tiles horizontally in the shard
  131. num_tiles_per_perm = self.quant_config.perm_len // (
  132. self.quant_config.tile_size**2)
  133. if output_size_per_partition % num_tiles_per_perm != 0:
  134. raise ValueError(
  135. "Each permutation group must reside on the same gpu")
  136. # Quantized 4Bit weights packed into Int32.
  137. qweight = Parameter(
  138. torch.empty(
  139. input_size_per_partition // self.quant_config.tile_size,
  140. output_size_per_partition * self.quant_config.tile_size //
  141. self.quant_config.pack_factor,
  142. device="cuda",
  143. dtype=torch.int32,
  144. ),
  145. requires_grad=False,
  146. )
  147. set_weight_attrs(
  148. qweight,
  149. {
  150. "input_dim": 0,
  151. "output_dim": 1,
  152. "packed_dim": 1,
  153. "pack_factor": self.quant_config.pack_factor,
  154. "marlin_tile_size": self.quant_config.tile_size,
  155. },
  156. )
  157. s_channel = Parameter(
  158. torch.empty(
  159. 1,
  160. output_size_per_partition,
  161. device="cuda",
  162. dtype=torch.float,
  163. ),
  164. requires_grad=False,
  165. )
  166. set_weight_attrs(
  167. s_channel,
  168. {
  169. "input_dim": None,
  170. "output_dim": 1,
  171. },
  172. )
  173. if self.quant_config.group_size == -1:
  174. s_group = Parameter(
  175. torch.tensor(
  176. [],
  177. device="cuda",
  178. dtype=torch.half,
  179. ),
  180. requires_grad=False,
  181. )
  182. else:
  183. s_group = Parameter(
  184. torch.empty(
  185. input_size_per_partition // self.quant_config.group_size,
  186. output_size_per_partition,
  187. device="cuda",
  188. dtype=torch.half,
  189. ),
  190. requires_grad=False,
  191. )
  192. set_weight_attrs(
  193. s_group,
  194. {
  195. "input_dim": None if self.quant_config.group_size == -1 else 0,
  196. "output_dim":
  197. None if self.quant_config.group_size == -1 else 1,
  198. },
  199. )
  200. # Allocate workspace (Used for internal locking mechanism)
  201. max_workspace_size = (
  202. output_size_per_partition //
  203. self.quant_config.min_n_threads) * self.quant_config.max_parallel
  204. workspace = Parameter(torch.zeros(max_workspace_size,
  205. device="cuda",
  206. dtype=torch.int),
  207. requires_grad=False)
  208. layer.register_parameter("B", qweight)
  209. set_weight_attrs(qweight, extra_weight_attrs)
  210. layer.register_parameter("s_channel", s_channel)
  211. set_weight_attrs(s_channel, extra_weight_attrs)
  212. layer.register_parameter("s_group", s_group)
  213. set_weight_attrs(s_group, extra_weight_attrs)
  214. layer.register_parameter("workspace", workspace)
  215. set_weight_attrs(workspace, extra_weight_attrs)
  216. def apply(
  217. self,
  218. layer: torch.nn.Module,
  219. x: torch.Tensor,
  220. bias: Optional[torch.Tensor] = None,
  221. ) -> torch.Tensor:
  222. qweight = layer.B
  223. s_ch = layer.s_channel
  224. s_group = layer.s_group
  225. workspace = layer.workspace
  226. x_2d = x.view(-1, x.shape[-1])
  227. size_m = x_2d.shape[0]
  228. size_k = x_2d.shape[1]
  229. size_n = s_ch.shape[1]
  230. x_int8, s_tok = ops.scaled_int8_quant(x_2d)
  231. output_2d = ops.marlin_qqq_gemm(x_int8, qweight, s_tok, s_ch, s_group,
  232. workspace, size_m, size_n, size_k)
  233. output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
  234. if bias is not None:
  235. output.add_(bias) # In-place add
  236. return output