marlin.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. from typing import Any, Dict, List, Optional
  2. import torch
  3. from loguru import logger
  4. from torch.nn.parameter import Parameter
  5. from aphrodite import _custom_ops as ops
  6. from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
  7. from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
  8. from aphrodite.modeling.utils import set_weight_attrs
  9. from aphrodite.quantization.base_config import QuantizationConfig
  10. class MarlinConfig(QuantizationConfig):
  11. """Config class for Marlin.
  12. Reference: https://github.com/IST-DASLab/marlin/tree/master
  13. """
  14. def __init__(
  15. self,
  16. group_size: int,
  17. lm_head_quantized: bool,
  18. ) -> None:
  19. # Group size for the quantization.
  20. self.group_size = group_size
  21. self.lm_head_quantized = lm_head_quantized
  22. if self.group_size != 128 and self.group_size != -1:
  23. raise ValueError(
  24. "Currently, only group size 128 and -1 (channelwise) "
  25. "is supported for Marlin, but got group_size of "
  26. f"{self.group_size}")
  27. # 4 Bits packed into 32 bit datatype.
  28. self.pack_factor = 32 // 4
  29. # Tile size used by marlin kernels.
  30. self.tile_size = 16
  31. # Min out_features dim
  32. self.min_n_threads = 64
  33. # Min in_features dim
  34. self.min_k_threads = 128
  35. # Max parallel problems to solve at once (improves large
  36. # batch performance)
  37. self.max_parallel = 16
  38. # Permutation length used by the marlin kernels.
  39. self.perm_len = 1024
  40. def __repr__(self) -> str:
  41. return (f"MarlinConfig(group_size={self.group_size}, "
  42. f"lm_head_quantized={self.lm_head_quantized})")
  43. @classmethod
  44. def get_name(cls) -> str:
  45. return "marlin"
  46. @classmethod
  47. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  48. return [torch.half]
  49. @classmethod
  50. # Need to figure it out
  51. def get_min_capability(cls) -> int:
  52. return 80
  53. @classmethod
  54. def get_config_filenames(cls) -> List[str]:
  55. return ["quantize_config.json"]
  56. @classmethod
  57. def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
  58. group_size = cls.get_from_keys(config, ["group_size"])
  59. lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
  60. default=False)
  61. return cls(group_size, lm_head_quantized)
  62. @classmethod
  63. def override_quantization_method(cls, hf_quant_cfg,
  64. user_quant) -> Optional[str]:
  65. # compat: autogptq >=0.8.0 use checkpoint_format: str
  66. # compat: autogptq <=0.7.1 is_marlin_format: bool
  67. is_marlin_format = (hf_quant_cfg.get("checkpoint_format") == "marlin"
  68. or hf_quant_cfg.get("is_marlin_format", False))
  69. is_valid_user_quant = (user_quant is None or user_quant == "gptq"
  70. or user_quant == "marlin")
  71. if is_marlin_format and is_valid_user_quant:
  72. msg = ("The model is serialized in {} format. Using {} kernel.".
  73. format(cls.get_name(), cls.get_name()))
  74. logger.info(msg)
  75. return cls.get_name()
  76. return None
  77. def get_quant_method(self, layer: torch.nn.Module,
  78. prefix: str) -> Optional["MarlinLinearMethod"]:
  79. if (isinstance(layer, LinearBase) or
  80. (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
  81. return MarlinLinearMethod(self)
  82. return None
  83. def get_scaled_act_names(self) -> List[str]:
  84. return []
  85. class MarlinLinearMethod(LinearMethodBase):
  86. """Linear method for Marlin.
  87. Args:
  88. quant_config: The Marlin quantization config.
  89. """
  90. def __init__(self, quant_config: MarlinConfig):
  91. self.quant_config = quant_config
  92. def create_weights(
  93. self,
  94. layer: torch.nn.Module,
  95. input_size_per_partition: int,
  96. output_partition_sizes: List[int],
  97. input_size: int,
  98. output_size: int,
  99. params_dtype: torch.dtype,
  100. **extra_weight_attrs,
  101. ):
  102. del output_size # Unused.
  103. if params_dtype != torch.float16:
  104. raise ValueError(
  105. f"The params dtype must be float16, but got {params_dtype}")
  106. # Validate output_size_per_partition
  107. output_size_per_partition = sum(output_partition_sizes)
  108. if output_size_per_partition % self.quant_config.min_n_threads != 0:
  109. raise ValueError(
  110. f"Weight output_size_per_partition = "
  111. f"{output_size_per_partition} is not divisible by "
  112. f"min_n_threads = {self.quant_config.min_n_threads}.")
  113. if output_size_per_partition % self.quant_config.pack_factor != 0:
  114. raise ValueError(
  115. f"Weight output_size_per_partition = "
  116. f"{output_size_per_partition} is not divisible by "
  117. f"pack_factor = {self.quant_config.pack_factor}.")
  118. # Validate input_size_per_partition
  119. if input_size_per_partition % self.quant_config.min_k_threads != 0:
  120. raise ValueError(
  121. f"Weight input_size_per_partition = "
  122. f"{input_size_per_partition} is not divisible by "
  123. f"min_k_threads = {self.quant_config.min_k_threads}.")
  124. if (self.quant_config.group_size != -1 and
  125. input_size_per_partition % self.quant_config.group_size != 0):
  126. raise ValueError(f"Weight input_size_per_partition = "
  127. f"{input_size_per_partition} is not divisible by "
  128. f"group_size = {self.quant_config.group_size}.")
  129. # Check that we have at least 4 tiles horizontally in the shard
  130. num_tiles_per_perm = self.quant_config.perm_len // (
  131. self.quant_config.tile_size**2)
  132. if output_size_per_partition % num_tiles_per_perm != 0:
  133. raise ValueError(
  134. "Each permutation group must reside on the same gpu")
  135. # Quantized 4Bit weights packed into Int32.
  136. qweight = Parameter(
  137. torch.empty(
  138. input_size_per_partition // self.quant_config.tile_size,
  139. output_size_per_partition * self.quant_config.tile_size //
  140. self.quant_config.pack_factor,
  141. device="cuda",
  142. dtype=torch.int32,
  143. ),
  144. requires_grad=False,
  145. )
  146. set_weight_attrs(
  147. qweight,
  148. {
  149. "input_dim": 0,
  150. "output_dim": 1,
  151. "packed_dim": 1,
  152. "pack_factor": self.quant_config.pack_factor,
  153. "marlin_tile_size": self.quant_config.tile_size,
  154. },
  155. )
  156. # Determine if channelwise or not
  157. input_groups = (1 if self.quant_config.group_size == -1 else
  158. input_size_per_partition //
  159. self.quant_config.group_size)
  160. scales = Parameter(
  161. torch.empty(
  162. input_groups,
  163. output_size_per_partition,
  164. device="cuda",
  165. dtype=params_dtype,
  166. ),
  167. requires_grad=False,
  168. )
  169. set_weight_attrs(
  170. scales,
  171. {
  172. "input_dim": None if input_groups == 1 else 0,
  173. "output_dim": 1,
  174. },
  175. )
  176. # Allocate workspace (Used for internal locking mechanism)
  177. max_workspace_size = (
  178. output_size_per_partition //
  179. self.quant_config.min_n_threads) * self.quant_config.max_parallel
  180. workspace = Parameter(torch.zeros(max_workspace_size,
  181. device="cuda",
  182. dtype=torch.int),
  183. requires_grad=False)
  184. layer.register_parameter("B", qweight)
  185. set_weight_attrs(qweight, extra_weight_attrs)
  186. layer.register_parameter("s", scales)
  187. set_weight_attrs(scales, extra_weight_attrs)
  188. layer.register_parameter("workspace", workspace)
  189. set_weight_attrs(workspace, extra_weight_attrs)
  190. def apply(
  191. self,
  192. layer: torch.nn.Module,
  193. x: torch.Tensor,
  194. bias: Optional[torch.Tensor] = None,
  195. ) -> torch.Tensor:
  196. qweight = layer.B
  197. scales = layer.s
  198. workspace = layer.workspace
  199. x_2d = x.view(-1, x.shape[-1])
  200. size_m = x_2d.shape[0]
  201. size_k = x_2d.shape[1]
  202. size_n = scales.shape[1]
  203. output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m,
  204. size_n, size_k)
  205. output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
  206. if bias is not None:
  207. output.add_(bias) # In-place add
  208. return output