@@ -7,7 +7,10 @@ from torch.nn.parameter import Parameter
 from aphrodite import _custom_ops as ops
 from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
 from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
-from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.modeling.parameter import (BaseAphroditeParameter,
+                                          ChannelQuantScaleParameter,
+                                          GroupQuantScaleParameter,
+                                          PackedAphroditeParameter)
 from aphrodite.quantization.base_config import QuantizationConfig
@@ -29,7 +32,8 @@ class MarlinConfig(QuantizationConfig):
             raise ValueError(
                 "Currently, only group size 128 and -1 (channelwise) "
                 "is supported for Marlin, but got group_size of "
-                f"{self.group_size}")
+                f"{self.group_size}"
+            )
 
         # 4 Bits packed into 32 bit datatype.
         self.pack_factor = 32 // 4
@@ -51,8 +55,10 @@ class MarlinConfig(QuantizationConfig):
         self.perm_len = 1024
 
     def __repr__(self) -> str:
-        return (f"MarlinConfig(group_size={self.group_size}, "
-                f"lm_head_quantized={self.lm_head_quantized})")
+        return (
+            f"MarlinConfig(group_size={self.group_size}, "
+            f"lm_head_quantized={self.lm_head_quantized})"
+        )
 
     @classmethod
     def get_name(cls) -> str:
@@ -74,33 +80,42 @@ class MarlinConfig(QuantizationConfig):
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
         group_size = cls.get_from_keys(config, ["group_size"])
-        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
-                                                 default=False)
+        lm_head_quantized = cls.get_from_keys_or(
+            config, ["lm_head"], default=False
+        )
         return cls(group_size, lm_head_quantized)
 
     @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
+    def override_quantization_method(
+        cls, hf_quant_cfg, user_quant
+    ) -> Optional[str]:
         # compat: autogptq >=0.8.0 use checkpoint_format: str
         # compat: autogptq <=0.7.1 is_marlin_format: bool
-        is_marlin_format = (hf_quant_cfg.get("checkpoint_format") == "marlin"
-                            or hf_quant_cfg.get("is_marlin_format", False))
+        is_marlin_format = hf_quant_cfg.get(
+            "checkpoint_format"
+        ) == "marlin" or hf_quant_cfg.get("is_marlin_format", False)
 
-        is_valid_user_quant = (user_quant is None or user_quant == "gptq"
-                               or user_quant == "marlin")
+        is_valid_user_quant = (
+            user_quant is None or user_quant == "gptq" or user_quant == "marlin"
+        )
 
         if is_marlin_format and is_valid_user_quant:
-            msg = ("The model is serialized in {} format. Using {} kernel.".
-                   format(cls.get_name(), cls.get_name()))
+            msg = (
+                "The model is serialized in {} format. Using {} kernel.".format(
+                    cls.get_name(), cls.get_name()
+                )
+            )
             logger.info(msg)
             return cls.get_name()
 
         return None
 
-    def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["MarlinLinearMethod"]:
-        if (isinstance(layer, LinearBase) or
-                (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["MarlinLinearMethod"]:
+        if isinstance(layer, LinearBase) or (
+            isinstance(layer, ParallelLMHead) and self.lm_head_quantized
+        ):
             return MarlinLinearMethod(self)
         return None
 
@@ -129,10 +144,12 @@ class MarlinLinearMethod(LinearMethodBase):
         **extra_weight_attrs,
     ):
         del output_size  # Unused.
+        weight_loader = extra_weight_attrs["weight_loader"]
 
         if params_dtype != torch.float16:
             raise ValueError(
-                f"The params dtype must be float16, but got {params_dtype}")
+                f"The params dtype must be float16, but got {params_dtype}"
+            )
 
         # Validate output_size_per_partition
         output_size_per_partition = sum(output_partition_sizes)
@@ -140,91 +157,104 @@ class MarlinLinearMethod(LinearMethodBase):
             raise ValueError(
                 f"Weight output_size_per_partition = "
                 f"{output_size_per_partition} is not divisible by "
-                f"min_n_threads = {self.quant_config.min_n_threads}.")
+                f"min_n_threads = {self.quant_config.min_n_threads}."
+            )
         if output_size_per_partition % self.quant_config.pack_factor != 0:
             raise ValueError(
                 f"Weight output_size_per_partition = "
                 f"{output_size_per_partition} is not divisible by "
-                f"pack_factor = {self.quant_config.pack_factor}.")
+                f"pack_factor = {self.quant_config.pack_factor}."
+            )
 
         # Validate input_size_per_partition
         if input_size_per_partition % self.quant_config.min_k_threads != 0:
             raise ValueError(
                 f"Weight input_size_per_partition = "
                 f"{input_size_per_partition} is not divisible by "
-                f"min_k_threads = {self.quant_config.min_k_threads}.")
-        if (self.quant_config.group_size != -1 and
-                input_size_per_partition % self.quant_config.group_size != 0):
-            raise ValueError(f"Weight input_size_per_partition = "
-                             f"{input_size_per_partition} is not divisible by "
-                             f"group_size = {self.quant_config.group_size}.")
+                f"min_k_threads = {self.quant_config.min_k_threads}."
+            )
+        if (
+            self.quant_config.group_size != -1
+            and input_size_per_partition % self.quant_config.group_size != 0
+        ):
+            raise ValueError(
+                f"Weight input_size_per_partition = "
+                f"{input_size_per_partition} is not divisible by "
+                f"group_size = {self.quant_config.group_size}."
+            )
 
         # Check that we have at least 4 tiles horizontally in the shard
         num_tiles_per_perm = self.quant_config.perm_len // (
-            self.quant_config.tile_size**2)
+            self.quant_config.tile_size**2
+        )
         if output_size_per_partition % num_tiles_per_perm != 0:
             raise ValueError(
-                "Each permutation group must reside on the same gpu")
+                "Each permutation group must reside on the same gpu"
+            )
 
         # Quantized 4Bit weights packed into Int32.
-        qweight = Parameter(
-            torch.empty(
+        qweight = PackedAphroditeParameter(
+            data=torch.empty(
                 input_size_per_partition // self.quant_config.tile_size,
-                output_size_per_partition * self.quant_config.tile_size //
-                self.quant_config.pack_factor,
+                output_size_per_partition
+                * self.quant_config.tile_size
+                // self.quant_config.pack_factor,
                 device="cuda",
                 dtype=torch.int32,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qweight,
-            {
-                "input_dim": 0,
-                "output_dim": 1,
-                "packed_dim": 1,
-                "pack_factor": self.quant_config.pack_factor,
-                "marlin_tile_size": self.quant_config.tile_size,
-            },
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            marlin_tile_size=self.quant_config.tile_size,
+            weight_loader=weight_loader,
         )
 
         # Determine if channelwise or not
-        input_groups = (1 if self.quant_config.group_size == -1 else
-                        input_size_per_partition //
-                        self.quant_config.group_size)
+        input_groups = (
+            1
+            if self.quant_config.group_size == -1
+            else input_size_per_partition // self.quant_config.group_size
+        )
 
-        scales = Parameter(
-            torch.empty(
+        weight_scale_args = {
+            "data": torch.empty(
                 input_groups,
                 output_size_per_partition,
                 device="cuda",
                 dtype=params_dtype,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            scales,
-            {
-                "input_dim": None if input_groups == 1 else 0,
-                "output_dim": 1,
-            },
-        )
+            "weight_loader": weight_loader,
+        }
+        if input_groups == 1:
+            scales = ChannelQuantScaleParameter(
+                output_dim=1, **weight_scale_args
+            )
+        else:
+            scales = GroupQuantScaleParameter(
+                output_dim=1, input_dim=0, **weight_scale_args
+            )
 
         # Allocate workspace (Used for internal locking mechanism)
         max_workspace_size = (
-            output_size_per_partition //
-            self.quant_config.min_n_threads) * self.quant_config.max_parallel
-        workspace = Parameter(torch.zeros(max_workspace_size,
-                                          device="cuda",
-                                          dtype=torch.int),
-                              requires_grad=False)
+            output_size_per_partition // self.quant_config.min_n_threads
+        ) * self.quant_config.max_parallel
+        workspace = BaseAphroditeParameter(
+            data=torch.zeros(
+                max_workspace_size, device="cuda", dtype=torch.int
+            ),
+            weight_loader=weight_loader,
+        )
 
         layer.register_parameter("B", qweight)
-        set_weight_attrs(qweight, extra_weight_attrs)
         layer.register_parameter("s", scales)
-        set_weight_attrs(scales, extra_weight_attrs)
         layer.register_parameter("workspace", workspace)
-        set_weight_attrs(workspace, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # required by torch.compile
+        layer.B = Parameter(layer.B.data, requires_grad=False)
+        layer.s = Parameter(layer.s.data, requires_grad=False)
+        layer.workspace = Parameter(layer.workspace.data, requires_grad=False)
 
     def apply(
         self,
@@ -242,10 +272,11 @@ class MarlinLinearMethod(LinearMethodBase):
         size_k = x_2d.shape[1]
         size_n = scales.shape[1]
 
-        output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m,
-                                    size_n, size_k)
+        output_2d = ops.marlin_gemm(
+            x_2d, qweight, scales, workspace, size_m, size_n, size_k
+        )
 
-        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
+        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1],))
 
         if bias is not None:
             output.add_(bias)  # In-place add