marlin.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. from typing import Any, Dict, List, Optional
  2. import torch
  3. from torch.nn.parameter import Parameter
  4. from aphrodite._quant_C import quant_ops as ops
  5. from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
  6. from aphrodite.modeling.utils import set_weight_attrs
  7. from aphrodite.quantization.base_config import QuantizationConfig
  8. class MarlinConfig(QuantizationConfig):
  9. """Config class for Marlin.
  10. Reference: https://github.com/IST-DASLab/marlin/tree/master
  11. """
  12. def __init__(
  13. self,
  14. group_size: int,
  15. ) -> None:
  16. # Group size for the quantization.
  17. self.group_size = group_size
  18. if self.group_size != 128 and self.group_size != -1:
  19. raise ValueError(
  20. "Currently, only group size 128 and -1 (channelwise) "
  21. "is supported for Marlin, but got group_size of "
  22. f"{self.group_size}")
  23. # 4 Bits packed into 32 bit datatype.
  24. self.pack_factor = 32 // 4
  25. # Tile size used by marlin kernels.
  26. self.tile_size = 16
  27. # Min out_features dim
  28. self.min_n_threads = 64
  29. # Min in_features dim
  30. self.min_k_threads = 128
  31. # Max parallel problems to solve at once (improves large
  32. # batch performance)
  33. self.max_parallel = 16
  34. # Permutation length used by the marlin kernels.
  35. self.perm_len = 1024
  36. def __repr__(self) -> str:
  37. return f"MarlinConfig(group_size={self.group_size})"
  38. @classmethod
  39. def get_name(cls) -> str:
  40. return "marlin"
  41. @classmethod
  42. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  43. return [torch.half]
  44. @classmethod
  45. # Need to figure it out
  46. def get_min_capability(cls) -> int:
  47. return 80
  48. @classmethod
  49. def get_config_filenames(cls) -> List[str]:
  50. return ["quantize_config.json"]
  51. @classmethod
  52. def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
  53. group_size = cls.get_from_keys(config, ["group_size"])
  54. return cls(group_size)
  55. def get_quant_method(
  56. self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]:
  57. if isinstance(layer, LinearBase):
  58. return MarlinLinearMethod(self)
  59. return None
  60. def get_scaled_act_names(self) -> List[str]:
  61. return []
  62. class MarlinLinearMethod(LinearMethodBase):
  63. """Linear method for Marlin.
  64. Args:
  65. quant_config: The Marlin quantization config.
  66. """
  67. def __init__(self, quant_config: MarlinConfig):
  68. self.quant_config = quant_config
  69. def create_weights(
  70. self,
  71. layer: torch.nn.Module,
  72. input_size_per_partition: int,
  73. output_partition_sizes: List[int],
  74. input_size: int,
  75. output_size: int,
  76. params_dtype: torch.dtype,
  77. **extra_weight_attrs,
  78. ):
  79. del output_size # Unused.
  80. if params_dtype != torch.float16:
  81. raise ValueError(
  82. f"The params dtype must be float16, but got {params_dtype}")
  83. # Validate output_size_per_partition
  84. output_size_per_partition = sum(output_partition_sizes)
  85. if output_size_per_partition % self.quant_config.min_n_threads != 0:
  86. raise ValueError(
  87. f"Weight output_size_per_partition = "
  88. f"{output_size_per_partition} is not divisible by "
  89. f"min_n_threads = {self.quant_config.min_n_threads}.")
  90. if output_size_per_partition % self.quant_config.pack_factor != 0:
  91. raise ValueError(
  92. f"Weight output_size_per_partition = "
  93. f"{output_size_per_partition} is not divisible by "
  94. f"pack_factor = {self.quant_config.pack_factor}.")
  95. # Validate input_size_per_partition
  96. if input_size_per_partition % self.quant_config.min_k_threads != 0:
  97. raise ValueError(
  98. f"Weight input_size_per_partition = "
  99. f"{input_size_per_partition} is not divisible by "
  100. f"min_k_threads = {self.quant_config.min_k_threads}.")
  101. if (self.quant_config.group_size != -1 and
  102. input_size_per_partition % self.quant_config.group_size != 0):
  103. raise ValueError(f"Weight input_size_per_partition = "
  104. f"{input_size_per_partition} is not divisible by "
  105. f"group_size = {self.quant_config.group_size}.")
  106. # Check that we have at least 4 tiles horizontally in the shard
  107. num_tiles_per_perm = self.quant_config.perm_len // (
  108. self.quant_config.tile_size**2)
  109. if output_size_per_partition % num_tiles_per_perm != 0:
  110. raise ValueError(
  111. "Each permutation group must reside on the same gpu")
  112. # Quantized 4Bit weights packed into Int32.
  113. qweight = Parameter(
  114. torch.empty(
  115. input_size_per_partition // self.quant_config.tile_size,
  116. output_size_per_partition * self.quant_config.tile_size //
  117. self.quant_config.pack_factor,
  118. device="cuda",
  119. dtype=torch.int32,
  120. ),
  121. requires_grad=False,
  122. )
  123. set_weight_attrs(
  124. qweight,
  125. {
  126. "input_dim": 0,
  127. "output_dim": 1,
  128. "packed_dim": 1,
  129. "pack_factor": self.quant_config.pack_factor,
  130. "marlin_tile_size": self.quant_config.tile_size,
  131. },
  132. )
  133. # Determine if channelwise or not
  134. input_groups = (1 if self.quant_config.group_size == -1 else
  135. input_size_per_partition //
  136. self.quant_config.group_size)
  137. scales = Parameter(
  138. torch.empty(
  139. input_groups,
  140. output_size_per_partition,
  141. device="cuda",
  142. dtype=params_dtype,
  143. ),
  144. requires_grad=False,
  145. )
  146. set_weight_attrs(
  147. scales,
  148. {
  149. "input_dim": None if input_groups == 1 else 0,
  150. "output_dim": 1,
  151. },
  152. )
  153. # Allocate workspace (Used for internal locking mechanism)
  154. max_workspace_size = (
  155. output_size_per_partition //
  156. self.quant_config.min_n_threads) * self.quant_config.max_parallel
  157. workspace = Parameter(torch.zeros(max_workspace_size,
  158. device="cuda",
  159. dtype=torch.int),
  160. requires_grad=False)
  161. layer.register_parameter("B", qweight)
  162. set_weight_attrs(qweight, extra_weight_attrs)
  163. layer.register_parameter("s", scales)
  164. set_weight_attrs(scales, extra_weight_attrs)
  165. layer.register_parameter("workspace", workspace)
  166. set_weight_attrs(workspace, extra_weight_attrs)
  167. def apply(
  168. self,
  169. layer: torch.nn.Module,
  170. x: torch.Tensor,
  171. bias: Optional[torch.Tensor] = None,
  172. ) -> torch.Tensor:
  173. qweight = layer.B
  174. scales = layer.s
  175. workspace = layer.workspace
  176. x_2d = x.view(-1, x.shape[-1])
  177. size_m = x_2d.shape[0]
  178. size_k = x_2d.shape[1]
  179. size_n = scales.shape[1]
  180. output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m,
  181. size_n, size_k)
  182. output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
  183. if bias is not None:
  184. output.add_(bias) # In-place add
  185. return output