marlin.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. from typing import Any, Dict, List, Optional
  2. from contextlib import suppress
  3. import torch
  4. from torch.nn.parameter import Parameter
  5. from aphrodite.modeling.layers.linear import (LinearMethodBase,
  6. set_weight_attrs)
  7. from aphrodite.quantization.base_config import (QuantizationConfig)
  8. HAS_QUANTS = False
  9. with suppress(ImportError):
  10. from aphrodite._quant_C import quant_ops as ops
  11. HAS_QUANTS = True
  12. class MarlinConfig(QuantizationConfig):
  13. """Config class for Marlin.
  14. Reference: https://github.com/IST-DASLab/marlin/tree/master
  15. """
  16. def __init__(
  17. self,
  18. group_size: int,
  19. ) -> None:
  20. if not HAS_QUANTS:
  21. raise ImportError("Could not find the quantization kernels.")
  22. # Group size for the quantization.
  23. self.group_size = group_size
  24. if self.group_size != 128 and self.group_size != -1:
  25. raise ValueError(
  26. "Currently, only group size 128 and -1 (channelwise) "
  27. "is supported for Marlin, but got group_size of "
  28. f"{self.group_size}")
  29. # 4 Bits packed into 32 bit datatype.
  30. self.pack_factor = 32 // 4
  31. # Tile size used by marlin kernels.
  32. self.tile_size = 16
  33. # Min out_features dim
  34. self.min_n_threads = 64
  35. # Min in_features dim
  36. self.min_k_threads = 128
  37. # Max parallel problems to solve at once (improves large
  38. # batch performance)
  39. self.max_parallel = 16
  40. # Permutation length used by the marlin kernels.
  41. self.perm_len = 1024
  42. def __repr__(self) -> str:
  43. return f"MarlinConfig(group_size={self.group_size})"
  44. @classmethod
  45. def get_name(cls) -> str:
  46. return "marlin"
  47. @classmethod
  48. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  49. return [torch.half]
  50. @classmethod
  51. # Need to figure it out
  52. def get_min_capability(cls) -> int:
  53. return 80
  54. @classmethod
  55. def get_config_filenames(cls) -> List[str]:
  56. return ["quantize_config.json"]
  57. @classmethod
  58. def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
  59. group_size = cls.get_from_keys(config, ["group_size"])
  60. return cls(group_size)
  61. def get_linear_method(self) -> "MarlinLinearMethod":
  62. return MarlinLinearMethod(self)
  63. def get_scaled_act_names(self) -> List[str]:
  64. return []
  65. def merge_weight(self) -> bool:
  66. return True
  67. def quant_vocab(self) -> List[bool]:
  68. return [False, False]
  69. def support_fused_moe(self) -> bool:
  70. return False
  71. def rope_style(self) -> Optional[bool]:
  72. return None
  73. class MarlinLinearMethod(LinearMethodBase):
  74. """Linear method for Marlin.
  75. Args:
  76. quant_config: The Marlin quantization config.
  77. """
  78. def __init__(self, quant_config: MarlinConfig):
  79. self.quant_config = quant_config
  80. def create_weights(
  81. self,
  82. input_size_per_partition: int,
  83. output_partition_sizes: List[int],
  84. input_size: int,
  85. output_size: int,
  86. params_dtype: torch.dtype,
  87. ) -> Dict[str, Any]:
  88. del output_size # Unused.
  89. if params_dtype != torch.float16:
  90. raise ValueError(
  91. f"The params dtype must be float16, but got {params_dtype}")
  92. output_size_per_partition = sum(output_partition_sizes)
  93. # Validate output_size_per_partition
  94. if output_size_per_partition % self.quant_config.min_n_threads != 0:
  95. raise ValueError(
  96. f"Weight output_size_per_partition = "
  97. f"{output_size_per_partition} is not divisible by "
  98. f"min_n_threads = {self.quant_config.min_n_threads}.")
  99. if output_size_per_partition % self.quant_config.pack_factor != 0:
  100. raise ValueError(
  101. f"Weight output_size_per_partition = "
  102. f"{output_size_per_partition} is not divisible by "
  103. f"pack_factor = {self.quant_config.pack_factor}.")
  104. # Validate input_size_per_partition
  105. if input_size_per_partition % self.quant_config.min_k_threads != 0:
  106. raise ValueError(
  107. f"Weight input_size_per_partition = "
  108. f"{input_size_per_partition} is not divisible by "
  109. f"min_k_threads = {self.quant_config.min_k_threads}.")
  110. if (self.quant_config.group_size != -1 and
  111. input_size_per_partition % self.quant_config.group_size != 0):
  112. raise ValueError(f"Weight input_size_per_partition = "
  113. f"{input_size_per_partition} is not divisible by "
  114. f"group_size = {self.quant_config.group_size}.")
  115. # Check that we have at least 4 tiles horizontally in the shard
  116. num_tiles_per_perm = self.quant_config.perm_len // (
  117. self.quant_config.tile_size**2)
  118. if output_size_per_partition % num_tiles_per_perm != 0:
  119. raise ValueError(
  120. "Each permutation group must reside on the same gpu")
  121. # Quantized 4Bit weights packed into Int32.
  122. qweight = Parameter(
  123. torch.empty(
  124. input_size_per_partition // self.quant_config.tile_size,
  125. output_size_per_partition * self.quant_config.tile_size //
  126. self.quant_config.pack_factor,
  127. device="cuda",
  128. dtype=torch.int32,
  129. ),
  130. requires_grad=False,
  131. )
  132. set_weight_attrs(
  133. qweight,
  134. {
  135. "input_dim": 0,
  136. "output_dim": 1,
  137. "packed_dim": 1,
  138. "pack_factor": self.quant_config.pack_factor,
  139. "marlin_tile_size": self.quant_config.tile_size,
  140. },
  141. )
  142. # Determine if channelwise or not
  143. input_groups = (1 if self.quant_config.group_size == -1 else
  144. input_size_per_partition //
  145. self.quant_config.group_size)
  146. scales = Parameter(
  147. torch.empty(
  148. input_groups,
  149. output_size_per_partition,
  150. device="cuda",
  151. dtype=params_dtype,
  152. ),
  153. requires_grad=False,
  154. )
  155. set_weight_attrs(
  156. scales,
  157. {
  158. "input_dim": None if input_groups == 1 else 0,
  159. "output_dim": 1,
  160. },
  161. )
  162. # Allocate workspace (Used for internal locking mechanism)
  163. max_workspace_size = (
  164. output_size_per_partition //
  165. self.quant_config.min_n_threads) * self.quant_config.max_parallel
  166. workspace = Parameter(torch.zeros(max_workspace_size,
  167. device="cuda",
  168. dtype=torch.int),
  169. requires_grad=False)
  170. return {
  171. "B": qweight,
  172. "s": scales,
  173. "workspace": workspace,
  174. }
  175. def apply_weights(
  176. self,
  177. weights: Dict[str, Any],
  178. x: torch.Tensor,
  179. bias: Optional[torch.Tensor] = None,
  180. ) -> torch.Tensor:
  181. qweight = weights["B"]
  182. scales = weights["s"]
  183. workspace = weights["workspace"]
  184. x_2d = x.view(-1, x.shape[-1])
  185. size_m = x_2d.shape[0]
  186. size_k = x_2d.shape[1]
  187. size_n = scales.shape[1]
  188. output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m,
  189. size_n, size_k)
  190. output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
  191. if bias is not None:
  192. output.add_(bias) # In-place add
  193. return output
  194. def apply_moe_weights(self, w1: Dict[str,
  195. torch.Tensor], w2: Dict[str,
  196. torch.Tensor],
  197. x: torch.Tensor, gating_output: torch.Tensor,
  198. topk: int, renormalize: bool) -> torch.Tensor:
  199. raise NotImplementedError