# marlin.py
from contextlib import suppress
from typing import Any, Dict, List, Optional

import torch
from loguru import logger
from torch.nn.parameter import Parameter

from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig

HAS_QUANTS = False
with suppress(ImportError):
    from aphrodite._quant_C import quant_ops as ops
    HAS_QUANTS = True


class MarlinConfig(QuantizationConfig):
    """Config class for Marlin.

    Reference: https://github.com/IST-DASLab/marlin/tree/master
    """

    def __init__(
        self,
        group_size: int,
    ) -> None:
        # Group size for the quantization.
        self.group_size = group_size
        if self.group_size != 128 and self.group_size != -1:
            raise ValueError(
                "Currently, only group sizes 128 and -1 (channelwise) "
                "are supported for Marlin, but got group_size of "
                f"{self.group_size}")

        # 4 bits packed into a 32-bit datatype.
        self.pack_factor = 32 // 4

        # Tile size used by the marlin kernels.
        self.tile_size = 16

        # Min out_features dim
        self.min_n_threads = 64

        # Min in_features dim
        self.min_k_threads = 128

        # Max parallel problems to solve at once (improves large
        # batch performance)
        self.max_parallel = 16

        # Permutation length used by the marlin kernels.
        self.perm_len = 1024

    def __repr__(self) -> str:
        return f"MarlinConfig(group_size={self.group_size})"

    @classmethod
    def get_name(cls) -> str:
        return "marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # Marlin kernels require Ampere (compute capability 8.0) or newer.
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(group_size)
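
    # Illustrative usage: the checkpoint's quantize_config.json (see
    # get_config_filenames) is expected to provide a "group_size" key, so
    # loading amounts to e.g.
    #     MarlinConfig.from_config({"group_size": 128})
    # which is equivalent to MarlinConfig(group_size=128); any value other
    # than 128 or -1 raises the ValueError in __init__.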

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        # compat: autogptq >= 0.8.0 uses checkpoint_format: str
        # compat: autogptq <= 0.7.1 uses is_marlin_format: bool
        is_marlin_format = (hf_quant_cfg.get("checkpoint_format") == "marlin"
                            or hf_quant_cfg.get("is_marlin_format", False))

        is_valid_user_quant = (user_quant is None or user_quant == "gptq"
                               or user_quant == "marlin")

        if is_marlin_format and is_valid_user_quant:
            msg = ("The model is serialized in {} format. Using {} kernel.".
                   format(cls.get_name(), cls.get_name()))
            logger.info(msg)
            return cls.get_name()

        return None
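
    # Illustrative behaviour: a checkpoint whose HF quantization config
    # contains {"checkpoint_format": "marlin"} (autogptq >= 0.8.0) or
    # {"is_marlin_format": True} (autogptq <= 0.7.1) is routed to the
    # "marlin" kernel whenever the user asked for "gptq", "marlin", or left
    # the quantization method unset; otherwise no override happens.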

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]:
        if isinstance(layer, LinearBase):
            return MarlinLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class MarlinLinearMethod(LinearMethodBase):
    """Linear method for Marlin.

    Args:
        quant_config: The Marlin quantization config.
    """

    def __init__(self, quant_config: MarlinConfig):
        if not HAS_QUANTS:
            raise ImportError("Could not find the quantization kernels.")
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del output_size  # Unused.

        if params_dtype != torch.float16:
            raise ValueError(
                f"The params dtype must be float16, but got {params_dtype}")

        # Validate output_size_per_partition
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.min_n_threads != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"min_n_threads = {self.quant_config.min_n_threads}.")
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"pack_factor = {self.quant_config.pack_factor}.")

        # Validate input_size_per_partition
        if input_size_per_partition % self.quant_config.min_k_threads != 0:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible by "
                f"min_k_threads = {self.quant_config.min_k_threads}.")
        if (self.quant_config.group_size != -1 and
                input_size_per_partition % self.quant_config.group_size != 0):
            raise ValueError(f"Weight input_size_per_partition = "
                             f"{input_size_per_partition} is not divisible by "
                             f"group_size = {self.quant_config.group_size}.")

        # Check that we have at least 4 tiles horizontally in the shard
        num_tiles_per_perm = self.quant_config.perm_len // (
            self.quant_config.tile_size**2)
        if output_size_per_partition % num_tiles_per_perm != 0:
            raise ValueError(
                "Each permutation group must reside on the same gpu")

        # Quantized 4-bit weights packed into int32.
        qweight = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.tile_size,
                output_size_per_partition * self.quant_config.tile_size //
                self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight,
            {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
                "marlin_tile_size": self.quant_config.tile_size,
            },
        )

        # Determine if channelwise or not
        input_groups = (1 if self.quant_config.group_size == -1 else
                        input_size_per_partition //
                        self.quant_config.group_size)

        scales = Parameter(
            torch.empty(
                input_groups,
                output_size_per_partition,
                device="cuda",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            scales,
            {
                "input_dim": None if input_groups == 1 else 0,
                "output_dim": 1,
            },
        )

        # Allocate workspace (used for the kernel's internal locking
        # mechanism).
        max_workspace_size = (
            output_size_per_partition //
            self.quant_config.min_n_threads) * self.quant_config.max_parallel
        workspace = Parameter(torch.zeros(max_workspace_size,
                                          device="cuda",
                                          dtype=torch.int),
                              requires_grad=False)

        layer.register_parameter("B", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("s", scales)
        set_weight_attrs(scales, extra_weight_attrs)
        layer.register_parameter("workspace", workspace)
        set_weight_attrs(workspace, extra_weight_attrs)
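
    # Illustrative shapes implied by create_weights, assuming a 4096 x 4096
    # weight shard with group_size == 128:
    #   qweight:   (4096 // 16, 4096 * 16 // 8) == (256, 8192)  int32
    #   scales:    (4096 // 128, 4096)          == (32, 4096)   float16
    #   workspace: (4096 // 64) * 16            == 1024 zeros   int32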

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        qweight = layer.B
        scales = layer.s
        workspace = layer.workspace

        x_2d = x.view(-1, x.shape[-1])

        size_m = x_2d.shape[0]
        size_k = x_2d.shape[1]
        size_n = scales.shape[1]

        output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m,
                                    size_n, size_k)

        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))

        if bias is not None:
            output.add_(bias)  # In-place add

        return output
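

# A minimal sketch of the reshape logic in MarlinLinearMethod.apply, with an
# ordinary float16 matmul standing in for ops.marlin_gemm (the tensor shapes,
# not the kernel, are the point):
#
#     x = torch.randn(2, 7, 4096, dtype=torch.float16)       # (batch, seq, k)
#     w = torch.randn(4096, 11008, dtype=torch.float16)      # (k, n)
#     x_2d = x.view(-1, x.shape[-1])                          # (14, 4096)
#     out_2d = x_2d @ w                                       # (14, 11008)
#     out = out_2d.view(x.shape[:-1] + (out_2d.shape[1], ))   # (2, 7, 11008)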