@@ -11,6 +11,8 @@ from aphrodite.common.utils import is_hip, print_warning_once
 from aphrodite.modeling.layers.fused_moe import FusedMoE, FusedMoEMethodBase
 from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
+from aphrodite.modeling.parameter import (ModelWeightParameter,
+                                          PerTensorScaleParameter)
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.platforms import current_platform
 from aphrodite.quantization.base_config import (QuantizationConfig,
@@ -21,8 +23,7 @@ from aphrodite.quantization.utils.marlin_utils_fp8 import (
 from aphrodite.quantization.utils.quant_utils import is_layer_skipped
 from aphrodite.quantization.utils.w8a8_utils import (
     all_close_1d, apply_fp8_linear, convert_to_channelwise,
-    create_per_tensor_scale_param, cutlass_fp8_supported,
-    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
+    cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
     requantize_with_max_scale)
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
@@ -136,6 +137,7 @@ class Fp8LinearMethod(LinearMethodBase):
     ):
         del input_size, output_size
         output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
 
         layer.logical_widths = output_partition_sizes
 
@@ -147,34 +149,38 @@ class Fp8LinearMethod(LinearMethodBase):
         weight_dtype = (torch.float8_e4m3fn
                         if self.quant_config.is_checkpoint_fp8_serialized else
                         params_dtype)
-        weight = Parameter(torch.empty(output_size_per_partition,
-                                       input_size_per_partition,
-                                       dtype=weight_dtype),
-                           requires_grad=False)
+        weight = ModelWeightParameter(data=torch.empty(
+            output_size_per_partition,
+            input_size_per_partition,
+            dtype=weight_dtype),
+                                      input_dim=1,
+                                      output_dim=0,
+                                      weight_loader=weight_loader)
         layer.register_parameter("weight", weight)
-        set_weight_attrs(weight, {
-            **extra_weight_attrs,
-            "input_dim": 1,
-            "output_dim": 0,
-        })
 
         # If checkpoint is serialized fp8, load them.
         # Otherwise, wait until process_weights_after_loading.
         if self.quant_config.is_checkpoint_fp8_serialized:
             # WEIGHT SCALE
-            scale = create_per_tensor_scale_param(output_partition_sizes,
-                                                  **extra_weight_attrs)
+            scale = PerTensorScaleParameter(data=torch.empty(
+                len(output_partition_sizes), dtype=torch.float32),
+                                            weight_loader=weight_loader)
+            scale[:] = torch.finfo(torch.float32).min
             layer.register_parameter("weight_scale", scale)
 
             # INPUT ACTIVATION SCALE
             if self.quant_config.activation_scheme == "static":
-                scale = create_per_tensor_scale_param(output_partition_sizes,
-                                                      **extra_weight_attrs)
+                scale = PerTensorScaleParameter(data=torch.empty(
+                    len(output_partition_sizes), dtype=torch.float32),
+                                                weight_loader=weight_loader)
+                scale[:] = torch.finfo(torch.float32).min
                 layer.register_parameter("input_scale", scale)
             else:
                 layer.register_parameter("input_scale", None)
 
     def process_weights_after_loading(self, layer: Module) -> None:
+        layer.weight = torch.nn.Parameter(layer.weight.data,
+                                          requires_grad=False)
         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
@@ -196,6 +202,11 @@ class Fp8LinearMethod(LinearMethodBase):
         # If checkpoint is fp8, handle that there are N scales for N
         # shards in a fused module
         else:
+            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
+                                                    requires_grad=False)
+            if self.quant_config.activation_scheme == "static":
+                layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
+                                                       requires_grad=False)
             # If using marlin (w8a16), kernel uses channelwise weights,
             # so extend the weight scales to be channelwise.
             if self.use_marlin:
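
For reference, a minimal sketch of the parameter pattern these hunks adopt: parameter classes that carry shard metadata and the weight_loader at creation time, then get detached back into plain torch.nn.Parameter once loading finishes. The create_weights_sketch / process_weights_after_loading_sketch functions and their arguments below are hypothetical stand-ins for the real Fp8LinearMethod methods, not part of this change.

    # Hypothetical sketch only -- mirrors the pattern in the diff above.
    import torch
    from torch.nn import Module

    from aphrodite.modeling.parameter import (ModelWeightParameter,
                                              PerTensorScaleParameter)


    def create_weights_sketch(layer: Module, output_size: int, input_size: int,
                              output_partition_sizes, weight_loader) -> None:
        # The weight carries shard metadata (input_dim/output_dim) and the
        # loader directly, instead of relying on set_weight_attrs.
        weight = ModelWeightParameter(data=torch.empty(
            output_size, input_size, dtype=torch.float8_e4m3fn),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # One fp32 scale slot per fused shard, pre-filled with float32 min
        # (as in the hunks above) before the real scales are loaded.
        scale = PerTensorScaleParameter(data=torch.empty(
            len(output_partition_sizes), dtype=torch.float32),
                                        weight_loader=weight_loader)
        scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", scale)


    def process_weights_after_loading_sketch(layer: Module) -> None:
        # After loading, swap the wrappers back to plain torch.nn.Parameter
        # so downstream kernels see ordinary tensors.
        layer.weight = torch.nn.Parameter(layer.weight.data,
                                          requires_grad=False)
        layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                requires_grad=False)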