# qqq.py (9.6 KB)
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from aphrodite import _custom_ops as ops
from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.modeling.parameter import (BaseAphroditeParameter,
                                          ChannelQuantScaleParameter,
                                          GroupQuantScaleParameter,
                                          PackedAphroditeParameter)
from aphrodite.quantization.base_config import QuantizationConfig
  11. MARLIN_QQQ_TILE = 16
  12. MARLIN_QQQ_MIN_THREAD_N = 64
  13. MARLIN_QQQ_MIN_THREAD_K = 128
  14. MARLIN_QQQ_MAX_PARALLEL = 16
  15. MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
  16. MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128]
  17. MARLIN_QQQ_SUPPORTED_SYM = [True]
  18. class QQQConfig(QuantizationConfig):
  19. """Config class for QQQ
  20. Reference: https://arxiv.org/pdf/2406.09904
  21. """
  22. def __init__(
  23. self,
  24. weight_bits: int,
  25. group_size: int,
  26. is_sym: bool = True,
  27. ) -> None:
  28. self.weight_bits = weight_bits
  29. self.group_size = group_size
  30. self.is_sym = is_sym
  31. # Verify
  32. if self.weight_bits not in MARLIN_QQQ_SUPPORTED_NUM_BITS:
  33. raise ValueError(
  34. f"QQQ does not support weight_bits = {self.weight_bits}. "
  35. f"Only weight_bits = {MARLIN_QQQ_SUPPORTED_NUM_BITS} "
  36. "are supported.")
  37. if self.group_size not in MARLIN_QQQ_SUPPORTED_GROUP_SIZES:
  38. raise ValueError(
  39. f"QQQ does not support group_size = {self.group_size}. "
  40. f"Only group_sizes = {MARLIN_QQQ_SUPPORTED_GROUP_SIZES} "
  41. "are supported.")
  42. if self.is_sym not in MARLIN_QQQ_SUPPORTED_SYM:
  43. raise ValueError(
  44. f"QQQ does not support is_sym = {self.is_sym}. "
  45. f"Only sym = {MARLIN_QQQ_SUPPORTED_SYM} are supported.")
  46. # 4 Bits packed into 32 bit datatype.
  47. self.pack_factor = 32 // self.weight_bits
  48. # Tile size used by QQQ kernels.
  49. self.tile_size = MARLIN_QQQ_TILE
  50. # Min out_features dim
  51. self.min_n_threads = MARLIN_QQQ_MIN_THREAD_N
  52. # Min in_features dim
  53. self.min_k_threads = MARLIN_QQQ_MIN_THREAD_K
  54. # Max parallel problems to solve at once (improves large
  55. # batch performance)
  56. self.max_parallel = MARLIN_QQQ_MAX_PARALLEL
  57. # Permutation length used by the QQQ kernels.
  58. self.perm_len = 1024
  59. def __repr__(self) -> str:
  60. return "QQQConfig(weight_bits={}, group_size={})".format(
  61. self.weight_bits, self.group_size)
  62. @classmethod
  63. def get_name(cls) -> str:
  64. return "qqq"
  65. @classmethod
  66. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  67. return [torch.half]
  68. @classmethod
  69. def get_min_capability(cls) -> int:
  70. return 80
  71. @classmethod
  72. def get_config_filenames(cls) -> List[str]:
  73. """List of filenames to search for in the model directory."""
  74. return [
  75. "quant_config.json",
  76. "quantize_config.json",
  77. ]
  78. @classmethod
  79. def from_config(cls, config: Dict[str, Any]) -> "QQQConfig":
  80. weight_bits = cls.get_from_keys(config, ["wbits"])
  81. group_size = cls.get_from_keys(config, ["group_size"])
  82. return cls(weight_bits, group_size)
  83. def get_quant_method(self, layer: torch.nn.Module,
  84. prefix: str) -> Optional["QQQLinearMethod"]:
  85. if isinstance(layer, LinearBase):
  86. return QQQLinearMethod(self)
  87. return None
  88. def get_scaled_act_names(self) -> List[str]:
  89. return []
  90. class QQQLinearMethod(LinearMethodBase):
  91. """Linear method for QQQ.
  92. Args:
  93. quant_config: The QQQ quantization config.
  94. """
  95. def __init__(self, quant_config: QQQConfig):
  96. self.quant_config = quant_config
  97. def create_weights(
  98. self,
  99. layer: torch.nn.Module,
  100. input_size_per_partition: int,
  101. output_partition_sizes: List[int],
  102. input_size: int,
  103. output_size: int,
  104. params_dtype: torch.dtype,
  105. **extra_weight_attrs,
  106. ):
  107. weight_loader = extra_weight_attrs["weight_loader"]
  108. if params_dtype != torch.float16:
  109. raise ValueError(
  110. f"The params dtype must be float16, but got {params_dtype}")
  111. # Validate output_size_per_partition
  112. output_size_per_partition = sum(output_partition_sizes)
  113. if output_size_per_partition % self.quant_config.min_n_threads != 0:
  114. raise ValueError(
  115. f"Weight output_size_per_partition = "
  116. f"{output_size_per_partition} is not divisible by "
  117. f"min_n_threads = {self.quant_config.min_n_threads}.")
  118. if output_size_per_partition % self.quant_config.pack_factor != 0:
  119. raise ValueError(
  120. f"Weight output_size_per_partition = "
  121. f"{output_size_per_partition} is not divisible by "
  122. f"pack_factor = {self.quant_config.pack_factor}.")
  123. # Validate input_size_per_partition
  124. if input_size_per_partition % self.quant_config.min_k_threads != 0:
  125. raise ValueError(
  126. f"Weight input_size_per_partition = "
  127. f"{input_size_per_partition} is not divisible by "
  128. f"min_k_threads = {self.quant_config.min_k_threads}.")
  129. if (self.quant_config.group_size != -1 and
  130. input_size_per_partition % self.quant_config.group_size != 0):
  131. raise ValueError(f"Weight input_size_per_partition = "
  132. f"{input_size_per_partition} is not divisible by "
  133. f"group_size = {self.quant_config.group_size}.")
  134. # Check that we have at least 4 tiles horizontally in the shard
  135. num_tiles_per_perm = self.quant_config.perm_len // (
  136. self.quant_config.tile_size**2)
  137. if output_size_per_partition % num_tiles_per_perm != 0:
  138. raise ValueError(
  139. "Each permutation group must reside on the same gpu")
  140. # Quantized 4Bit weights packed into Int32.
  141. qweight = PackedAphroditeParameter(
  142. data=torch.empty(
  143. input_size_per_partition // self.quant_config.tile_size,
  144. output_size_per_partition * self.quant_config.tile_size //
  145. self.quant_config.pack_factor,
  146. device="cuda",
  147. dtype=torch.int32,
  148. ),
  149. input_dim=0,
  150. output_dim=1,
  151. packed_dim=1,
  152. packed_factor=self.quant_config.pack_factor,
  153. marlin_tile_size=self.quant_config.tile_size,
  154. weight_loader=weight_loader)
  155. s_channel = ChannelQuantScaleParameter(data=torch.empty(
  156. 1,
  157. output_size_per_partition,
  158. device="cuda",
  159. dtype=torch.float,
  160. ),
  161. weight_loader=weight_loader,
  162. output_dim=1)
  163. if self.quant_config.group_size == -1:
  164. s_group_data = torch.tensor(
  165. [],
  166. device="cuda",
  167. dtype=torch.half,
  168. )
  169. else:
  170. s_group_data = torch.empty(
  171. input_size_per_partition // self.quant_config.group_size,
  172. output_size_per_partition,
  173. device="cuda",
  174. dtype=torch.half,
  175. )
  176. s_group_attr = {"data": s_group_data, "weight_loader": weight_loader}
  177. if self.quant_config.group_size == -1:
  178. s_group = BaseAphroditeParameter(**s_group_attr)
  179. else:
  180. s_group = GroupQuantScaleParameter(output_dim=1,
  181. input_dim=0,
  182. **s_group_attr)
  183. # Allocate workspace (Used for internal locking mechanism)
  184. max_workspace_size = (
  185. output_size_per_partition //
  186. self.quant_config.min_n_threads) * self.quant_config.max_parallel
  187. workspace = BaseAphroditeParameter(data=torch.zeros(max_workspace_size,
  188. device="cuda",
  189. dtype=torch.int),
  190. weight_loader=weight_loader)
  191. layer.register_parameter("B", qweight)
  192. layer.register_parameter("s_channel", s_channel)
  193. layer.register_parameter("s_group", s_group)
  194. layer.register_parameter("workspace", workspace)
  195. def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
  196. # required by torch.compile
  197. layer.B = Parameter(layer.B.data, requires_grad=False)
  198. layer.s_channel = Parameter(layer.s_channel.data, requires_grad=False)
  199. layer.s_group = Parameter(layer.s_group.data, requires_grad=False)
  200. layer.workspace = Parameter(layer.workspace.data, requires_grad=False)
  201. def apply(
  202. self,
  203. layer: torch.nn.Module,
  204. x: torch.Tensor,
  205. bias: Optional[torch.Tensor] = None,
  206. ) -> torch.Tensor:
  207. qweight = layer.B
  208. s_ch = layer.s_channel
  209. s_group = layer.s_group
  210. workspace = layer.workspace
  211. x_2d = x.view(-1, x.shape[-1])
  212. size_m = x_2d.shape[0]
  213. size_k = x_2d.shape[1]
  214. size_n = s_ch.shape[1]
  215. x_int8, s_tok, _ = ops.scaled_int8_quant(x_2d)
  216. output_2d = ops.marlin_qqq_gemm(x_int8, qweight, s_tok, s_ch, s_group,
  217. workspace, size_m, size_n, size_k)
  218. output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
  219. if bias is not None:
  220. output.add_(bias) # In-place add
  221. return output