|
@@ -2,12 +2,12 @@ from typing import Any, Dict, List, Optional
|
|
|
|
|
|
import torch
|
|
|
from loguru import logger
|
|
|
-from torch.nn.parameter import Parameter
|
|
|
|
|
|
from aphrodite import _custom_ops as ops
|
|
|
-from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
|
|
|
- set_weight_attrs)
|
|
|
+from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
|
|
|
from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
|
|
|
+from aphrodite.modeling.parameter import (GroupQuantScaleParameter,
|
|
|
+ PackedAphroditeParameter)
|
|
|
from aphrodite.quantization.base_config import QuantizationConfig
|
|
|
from aphrodite.quantization.utils.marlin_utils import (
|
|
|
apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
|
|
@@ -147,6 +147,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
|
|
|
) -> None:
|
|
|
del output_size
|
|
|
output_size_per_partition = sum(output_partition_sizes)
|
|
|
+ weight_loader = extra_weight_attrs.get("weight_loader")
|
|
|
|
|
|
# Normalize group_size
|
|
|
if self.quant_config.group_size != -1:
|
|
@@ -160,59 +161,44 @@ class AWQMarlinLinearMethod(LinearMethodBase):
|
|
|
input_size=input_size,
|
|
|
group_size=group_size)
|
|
|
|
|
|
- qweight = Parameter(
|
|
|
- torch.empty(
|
|
|
+ qweight = PackedAphroditeParameter(
|
|
|
+ data=torch.empty(
|
|
|
input_size_per_partition,
|
|
|
output_size_per_partition // self.quant_config.pack_factor,
|
|
|
dtype=torch.int32,
|
|
|
),
|
|
|
- requires_grad=False,
|
|
|
- )
|
|
|
- set_weight_attrs(
|
|
|
- qweight, {
|
|
|
- "input_dim": 0,
|
|
|
- "output_dim": 1,
|
|
|
- "packed_dim": 1,
|
|
|
- "pack_factor": self.quant_config.pack_factor,
|
|
|
- })
|
|
|
+ input_dim=0,
|
|
|
+ output_dim=1,
|
|
|
+ packed_dim=1,
|
|
|
+ packed_factor=self.quant_config.pack_factor,
|
|
|
+ weight_loader=weight_loader)
|
|
|
|
|
|
num_groups = input_size_per_partition // group_size
|
|
|
|
|
|
- qzeros = Parameter(
|
|
|
- torch.empty(
|
|
|
+ qzeros = PackedAphroditeParameter(
|
|
|
+ data=torch.empty(
|
|
|
num_groups,
|
|
|
output_size_per_partition // self.quant_config.pack_factor,
|
|
|
dtype=torch.int32,
|
|
|
),
|
|
|
- requires_grad=False,
|
|
|
- )
|
|
|
- set_weight_attrs(
|
|
|
- qzeros, {
|
|
|
- "input_dim": 0,
|
|
|
- "output_dim": 1,
|
|
|
- "packed_dim": 1,
|
|
|
- "pack_factor": self.quant_config.pack_factor,
|
|
|
- })
|
|
|
-
|
|
|
- scales = Parameter(
|
|
|
- torch.empty(
|
|
|
- num_groups,
|
|
|
- output_size_per_partition,
|
|
|
- dtype=params_dtype,
|
|
|
- ),
|
|
|
- requires_grad=False,
|
|
|
- )
|
|
|
- set_weight_attrs(scales, {
|
|
|
- "input_dim": 0,
|
|
|
- "output_dim": 1,
|
|
|
- })
|
|
|
+ input_dim=0,
|
|
|
+ output_dim=1,
|
|
|
+ packed_dim=1,
|
|
|
+ packed_factor=self.quant_config.pack_factor,
|
|
|
+ weight_loader=weight_loader)
|
|
|
+
|
|
|
+ scales = GroupQuantScaleParameter(data=torch.empty(
|
|
|
+ num_groups,
|
|
|
+ output_size_per_partition,
|
|
|
+ dtype=params_dtype,
|
|
|
+ ),
|
|
|
+ input_dim=0,
|
|
|
+ output_dim=1,
|
|
|
+ weight_loader=weight_loader)
|
|
|
|
|
|
layer.register_parameter("qweight", qweight)
|
|
|
- set_weight_attrs(qweight, extra_weight_attrs)
|
|
|
layer.register_parameter("qzeros", qzeros)
|
|
|
- set_weight_attrs(qzeros, extra_weight_attrs)
|
|
|
layer.register_parameter("scales", scales)
|
|
|
- set_weight_attrs(scales, extra_weight_attrs)
|
|
|
|
|
|
layer.input_size_per_partition = input_size_per_partition
|
|
|
layer.output_size_per_partition = output_size_per_partition
|
|
@@ -224,6 +210,12 @@ class AWQMarlinLinearMethod(LinearMethodBase):
|
|
|
# Here, we handle the repacking
|
|
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
|
|
device = layer.qweight.device
|
|
|
+ layer.qweight = torch.nn.Parameter(layer.qweight.data,
|
|
|
+ requires_grad=False)
|
|
|
+ layer.qzeros = torch.nn.Parameter(layer.qzeros.data,
|
|
|
+ requires_grad=False)
|
|
|
+ layer.scales = torch.nn.Parameter(layer.scales.data,
|
|
|
+ requires_grad=False)
|
|
|
|
|
|
# Allocate marlin workspace
|
|
|
layer.workspace = marlin_make_workspace(
|