marlin.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. from typing import Any, Dict, List, Optional
  2. import torch
  3. from loguru import logger
  4. from torch.nn.parameter import Parameter
  5. from aphrodite import _custom_ops as ops
  6. from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
  7. from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
  8. from aphrodite.modeling.parameter import (BaseAphroditeParameter,
  9. ChannelQuantScaleParameter,
  10. GroupQuantScaleParameter,
  11. PackedAphroditeParameter)
  12. from aphrodite.quantization.base_config import QuantizationConfig
  13. class MarlinConfig(QuantizationConfig):
  14. """Config class for Marlin.
  15. Reference: https://github.com/IST-DASLab/marlin/tree/master
  16. """
  17. def __init__(
  18. self,
  19. group_size: int,
  20. lm_head_quantized: bool,
  21. ) -> None:
  22. # Group size for the quantization.
  23. self.group_size = group_size
  24. self.lm_head_quantized = lm_head_quantized
  25. if self.group_size != 128 and self.group_size != -1:
  26. raise ValueError(
  27. "Currently, only group size 128 and -1 (channelwise) "
  28. "is supported for Marlin, but got group_size of "
  29. f"{self.group_size}"
  30. )
  31. # 4 Bits packed into 32 bit datatype.
  32. self.pack_factor = 32 // 4
  33. # Tile size used by marlin kernels.
  34. self.tile_size = 16
  35. # Min out_features dim
  36. self.min_n_threads = 64
  37. # Min in_features dim
  38. self.min_k_threads = 128
  39. # Max parallel problems to solve at once (improves large
  40. # batch performance)
  41. self.max_parallel = 16
  42. # Permutation length used by the marlin kernels.
  43. self.perm_len = 1024
  44. def __repr__(self) -> str:
  45. return (
  46. f"MarlinConfig(group_size={self.group_size}, "
  47. f"lm_head_quantized={self.lm_head_quantized})"
  48. )
  49. @classmethod
  50. def get_name(cls) -> str:
  51. return "marlin"
  52. @classmethod
  53. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  54. return [torch.half]
  55. @classmethod
  56. # Need to figure it out
  57. def get_min_capability(cls) -> int:
  58. return 80
  59. @classmethod
  60. def get_config_filenames(cls) -> List[str]:
  61. return ["quantize_config.json"]
  62. @classmethod
  63. def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
  64. group_size = cls.get_from_keys(config, ["group_size"])
  65. lm_head_quantized = cls.get_from_keys_or(
  66. config, ["lm_head"], default=False
  67. )
  68. return cls(group_size, lm_head_quantized)
  69. @classmethod
  70. def override_quantization_method(
  71. cls, hf_quant_cfg, user_quant
  72. ) -> Optional[str]:
  73. # compat: autogptq >=0.8.0 use checkpoint_format: str
  74. # compat: autogptq <=0.7.1 is_marlin_format: bool
  75. is_marlin_format = hf_quant_cfg.get(
  76. "checkpoint_format"
  77. ) == "marlin" or hf_quant_cfg.get("is_marlin_format", False)
  78. is_valid_user_quant = (
  79. user_quant is None or user_quant == "gptq" or user_quant == "marlin"
  80. )
  81. if is_marlin_format and is_valid_user_quant:
  82. msg = (
  83. "The model is serialized in {} format. Using {} kernel.".format(
  84. cls.get_name(), cls.get_name()
  85. )
  86. )
  87. logger.info(msg)
  88. return cls.get_name()
  89. return None
  90. def get_quant_method(
  91. self, layer: torch.nn.Module, prefix: str
  92. ) -> Optional["MarlinLinearMethod"]:
  93. if isinstance(layer, LinearBase) or (
  94. isinstance(layer, ParallelLMHead) and self.lm_head_quantized
  95. ):
  96. return MarlinLinearMethod(self)
  97. return None
  98. def get_scaled_act_names(self) -> List[str]:
  99. return []
  100. class MarlinLinearMethod(LinearMethodBase):
  101. """Linear method for Marlin.
  102. Args:
  103. quant_config: The Marlin quantization config.
  104. """
  105. def __init__(self, quant_config: MarlinConfig):
  106. self.quant_config = quant_config
  107. def create_weights(
  108. self,
  109. layer: torch.nn.Module,
  110. input_size_per_partition: int,
  111. output_partition_sizes: List[int],
  112. input_size: int,
  113. output_size: int,
  114. params_dtype: torch.dtype,
  115. **extra_weight_attrs,
  116. ):
  117. del output_size # Unused.
  118. weight_loader = extra_weight_attrs["weight_loader"]
  119. if params_dtype != torch.float16:
  120. raise ValueError(
  121. f"The params dtype must be float16, but got {params_dtype}"
  122. )
  123. # Validate output_size_per_partition
  124. output_size_per_partition = sum(output_partition_sizes)
  125. if output_size_per_partition % self.quant_config.min_n_threads != 0:
  126. raise ValueError(
  127. f"Weight output_size_per_partition = "
  128. f"{output_size_per_partition} is not divisible by "
  129. f"min_n_threads = {self.quant_config.min_n_threads}."
  130. )
  131. if output_size_per_partition % self.quant_config.pack_factor != 0:
  132. raise ValueError(
  133. f"Weight output_size_per_partition = "
  134. f"{output_size_per_partition} is not divisible by "
  135. f"pack_factor = {self.quant_config.pack_factor}."
  136. )
  137. # Validate input_size_per_partition
  138. if input_size_per_partition % self.quant_config.min_k_threads != 0:
  139. raise ValueError(
  140. f"Weight input_size_per_partition = "
  141. f"{input_size_per_partition} is not divisible by "
  142. f"min_k_threads = {self.quant_config.min_k_threads}."
  143. )
  144. if (
  145. self.quant_config.group_size != -1
  146. and input_size_per_partition % self.quant_config.group_size != 0
  147. ):
  148. raise ValueError(
  149. f"Weight input_size_per_partition = "
  150. f"{input_size_per_partition} is not divisible by "
  151. f"group_size = {self.quant_config.group_size}."
  152. )
  153. # Check that we have at least 4 tiles horizontally in the shard
  154. num_tiles_per_perm = self.quant_config.perm_len // (
  155. self.quant_config.tile_size**2
  156. )
  157. if output_size_per_partition % num_tiles_per_perm != 0:
  158. raise ValueError(
  159. "Each permutation group must reside on the same gpu"
  160. )
  161. # Quantized 4Bit weights packed into Int32.
  162. qweight = PackedAphroditeParameter(
  163. data=torch.empty(
  164. input_size_per_partition // self.quant_config.tile_size,
  165. output_size_per_partition
  166. * self.quant_config.tile_size
  167. // self.quant_config.pack_factor,
  168. device="cuda",
  169. dtype=torch.int32,
  170. ),
  171. input_dim=0,
  172. output_dim=1,
  173. packed_dim=1,
  174. packed_factor=self.quant_config.pack_factor,
  175. marlin_tile_size=self.quant_config.tile_size,
  176. weight_loader=weight_loader,
  177. )
  178. # Determine if channelwise or not
  179. input_groups = (
  180. 1
  181. if self.quant_config.group_size == -1
  182. else input_size_per_partition // self.quant_config.group_size
  183. )
  184. weight_scale_args = {
  185. "data": torch.empty(
  186. input_groups,
  187. output_size_per_partition,
  188. device="cuda",
  189. dtype=params_dtype,
  190. ),
  191. "weight_loader": weight_loader,
  192. }
  193. if input_groups == 1:
  194. scales = ChannelQuantScaleParameter(
  195. output_dim=1, **weight_scale_args
  196. )
  197. else:
  198. scales = GroupQuantScaleParameter(
  199. output_dim=1, input_dim=0, **weight_scale_args
  200. )
  201. # Allocate workspace (Used for internal locking mechanism)
  202. max_workspace_size = (
  203. output_size_per_partition // self.quant_config.min_n_threads
  204. ) * self.quant_config.max_parallel
  205. workspace = BaseAphroditeParameter(
  206. data=torch.zeros(
  207. max_workspace_size, device="cuda", dtype=torch.int
  208. ),
  209. weight_loader=weight_loader,
  210. )
  211. layer.register_parameter("B", qweight)
  212. layer.register_parameter("s", scales)
  213. layer.register_parameter("workspace", workspace)
  214. def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
  215. # required by torch.compile
  216. layer.B = Parameter(layer.B.data, requires_grad=False)
  217. layer.s = Parameter(layer.s.data, requires_grad=False)
  218. layer.workspace = Parameter(layer.workspace.data, requires_grad=False)
  219. def apply(
  220. self,
  221. layer: torch.nn.Module,
  222. x: torch.Tensor,
  223. bias: Optional[torch.Tensor] = None,
  224. ) -> torch.Tensor:
  225. qweight = layer.B
  226. scales = layer.s
  227. workspace = layer.workspace
  228. x_2d = x.view(-1, x.shape[-1])
  229. size_m = x_2d.shape[0]
  230. size_k = x_2d.shape[1]
  231. size_n = scales.shape[1]
  232. output_2d = ops.marlin_gemm(
  233. x_2d, qweight, scales, workspace, size_m, size_n, size_k
  234. )
  235. output = output_2d.view(x.shape[:-1] + (output_2d.shape[1],))
  236. if bias is not None:
  237. output.add_(bias) # In-place add
  238. return output