@@ -112,7 +112,7 @@ def apply_fp8_linear(
-    if bias is None and cutlass_fp8_supported:
+    if cutlass_fp8_supported:
         qinput, x_scale = ops.scaled_fp8_quant(input, input_scale)
@@ -120,7 +120,8 @@ def apply_fp8_linear(
                                        weight,
                                        out_dtype=input.dtype,
                                        scale_a=x_scale,
-                                       scale_b=weight_scale)
+                                       scale_b=weight_scale,
+                                       bias=bias)
     else:
         qinput, x_scale = ops.scaled_fp8_quant(input,