marlin.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. from typing import Any, Dict, List, Optional
  2. import torch
  3. from torch.nn.parameter import Parameter
  4. from aphrodite._C import ops
  5. from aphrodite.modeling.layers.linear import LinearMethodBase, set_weight_attrs
  6. from aphrodite.modeling.layers.quantization.base_config import QuantizationConfig
  7. class MarlinConfig(QuantizationConfig):
  8. """Config class for Marlin.
  9. Reference: https://github.com/IST-DASLab/marlin/tree/master
  10. """
  11. def __init__(
  12. self,
  13. group_size: int,
  14. ) -> None:
  15. # Group size for the quantization.
  16. self.group_size = group_size
  17. if self.group_size != 128 and self.group_size != -1:
  18. raise ValueError(
  19. "Currently, only group size 128 and -1 (channelwise) is supported for "
  20. f"Marlin, but got group_size of {self.group_size}")
  21. # 4 Bits packed into 32 bit datatype.
  22. self.pack_factor = 32 // 4
  23. # Tile size used by marlin kernels.
  24. self.tile_size = 16
  25. # Min out_features dim
  26. self.min_n_threads = 64
  27. # Min in_features dim
  28. self.min_k_threads = 128
  29. # Max parallel problems to solve at once (improves large batch performance)
  30. self.max_parallel = 16
  31. # Permutation length used by the marlin kernels.
  32. self.perm_len = 1024
  33. def __repr__(self) -> str:
  34. return f"MarlinConfig(group_size={self.group_size}"
  35. @classmethod
  36. def get_name(cls) -> str:
  37. return "marlin"
  38. @classmethod
  39. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  40. return [torch.half]
  41. @classmethod
  42. # Need to figure it out
  43. def get_min_capability(cls) -> int:
  44. return 80
  45. @classmethod
  46. def get_config_filenames(cls) -> List[str]:
  47. return ["quantize_config.json"]
  48. @classmethod
  49. def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
  50. group_size = cls.get_from_keys(config, ["group_size"])
  51. return cls(group_size)
  52. def get_linear_method(self) -> "MarlinLinearMethod":
  53. return MarlinLinearMethod(self)
  54. def get_scaled_act_names(self) -> List[str]:
  55. return []
  56. def merge_weight(self) -> bool:
  57. return False
  58. def rope_style(self) -> Optional[bool]:
  59. return None
  60. class MarlinLinearMethod(LinearMethodBase):
  61. """Linear method for Marlin.
  62. Args:
  63. quant_config: The Marlin quantization config.
  64. """
  65. def __init__(self, quant_config: MarlinConfig):
  66. self.quant_config = quant_config
  67. def create_weights(
  68. self,
  69. input_size_per_partition: int,
  70. output_partition_sizes: List[int],
  71. input_size: int,
  72. output_size: int,
  73. params_dtype: torch.dtype,
  74. ) -> Dict[str, Any]:
  75. del output_size # Unused.
  76. if params_dtype != torch.float16:
  77. raise ValueError(
  78. f"The params dtype must be float16, but got {params_dtype}")
  79. output_size_per_partition = sum(output_partition_sizes)
  80. # Validate output_size_per_partition
  81. if output_size_per_partition % self.quant_config.min_n_threads != 0:
  82. raise ValueError(
  83. f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by min_n_threads = {self.quant_config.min_n_threads}."
  84. )
  85. if output_size_per_partition % self.quant_config.pack_factor != 0:
  86. raise ValueError(
  87. f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by pack_factor = {self.quant_config.pack_factor}."
  88. )
  89. # Validate input_size_per_partition
  90. if input_size_per_partition % self.quant_config.min_k_threads != 0:
  91. raise ValueError(
  92. f"Weight input_size_per_partition = {input_size_per_partition} is not divisible by min_k_threads = {self.quant_config.min_k_threads}."
  93. )
  94. if self.quant_config.group_size != -1 and input_size_per_partition % self.quant_config.group_size != 0:
  95. raise ValueError(
  96. f"Weight input_size_per_partition = f{input_size_per_partition} is not divisible by group_size = {self.quant_config.group_size}."
  97. )
  98. # Check that we have at least 4 tiles horizontally in the shard
  99. num_tiles_per_perm = self.quant_config.perm_len // (
  100. self.quant_config.tile_size**2)
  101. if output_size_per_partition % num_tiles_per_perm != 0:
  102. raise ValueError(
  103. "Each permutation group must reside on the same gpu")
  104. # Quantized 4Bit weights packed into Int32.
  105. qweight = Parameter(
  106. torch.empty(
  107. input_size_per_partition // self.quant_config.tile_size,
  108. output_size_per_partition * self.quant_config.tile_size //
  109. self.quant_config.pack_factor,
  110. device="cuda",
  111. dtype=torch.int32,
  112. ),
  113. requires_grad=False,
  114. )
  115. set_weight_attrs(
  116. qweight,
  117. {
  118. "input_dim": 0,
  119. "output_dim": 1,
  120. "packed_dim": 1,
  121. "pack_factor": self.quant_config.pack_factor,
  122. "marlin_tile_size": self.quant_config.tile_size,
  123. },
  124. )
  125. # Determine if channelwise or not
  126. input_groups = 1 if self.quant_config.group_size == -1 else input_size_per_partition // self.quant_config.group_size
  127. scales = Parameter(
  128. torch.empty(
  129. input_groups,
  130. output_size_per_partition,
  131. device="cuda",
  132. dtype=params_dtype,
  133. ),
  134. requires_grad=False,
  135. )
  136. set_weight_attrs(
  137. scales,
  138. {
  139. "input_dim": None if input_groups == 1 else 0,
  140. "output_dim": 1,
  141. },
  142. )
  143. # Allocate workspace (Used for internal locking mechanism)
  144. max_workspace_size = (
  145. output_size_per_partition //
  146. self.quant_config.min_n_threads) * self.quant_config.max_parallel
  147. workspace = Parameter(torch.zeros(max_workspace_size,
  148. device="cuda",
  149. dtype=torch.int),
  150. requires_grad=False)
  151. return {
  152. "B": qweight,
  153. "s": scales,
  154. "workspace": workspace,
  155. }
  156. def apply_weights(
  157. self,
  158. weights: Dict[str, Any],
  159. x: torch.Tensor,
  160. bias: Optional[torch.Tensor] = None,
  161. ) -> torch.Tensor:
  162. qweight = weights["B"]
  163. scales = weights["s"]
  164. workspace = weights["workspace"]
  165. x_2d = x.view(-1, x.shape[-1])
  166. size_m = x_2d.shape[0]
  167. size_k = x_2d.shape[1]
  168. size_n = scales.shape[1]
  169. output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m,
  170. size_n, size_k)
  171. output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
  172. if bias is not None:
  173. output.add_(bias) # In-place add
  174. return output