@@ -93,7 +93,7 @@ class MixtralMoE(nn.Module):
                                      params_dtype=self.params_dtype,
                                      quant_config=None)
 
-        if self.use_fp8:
+        if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
 
         self.w13_weight = nn.Parameter(
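
For context, a minimal sketch of the dtype selection this hunk changes, under the assumption (from the condition added above) that expert weights should only be allocated directly in `torch.float8_e4m3fn` when the checkpoint itself stores FP8 tensors, and otherwise stay in the model's native dtype (presumably quantized later, after loading). The helper name `select_moe_params_dtype` is hypothetical, not part of the vLLM code:

```python
import torch

def select_moe_params_dtype(use_fp8: bool,
                            checkpoint_fp8_serialized: bool,
                            default_dtype: torch.dtype) -> torch.dtype:
    # Mirror of the changed condition: allocate FP8 parameters only when
    # FP8 is enabled AND the checkpoint is already serialized in FP8.
    # (torch.float8_e4m3fn requires a PyTorch build that exposes it.)
    if use_fp8 and checkpoint_fp8_serialized:
        return torch.float8_e4m3fn
    return default_dtype

# Example: with FP8 enabled but a BF16 checkpoint, parameters keep BF16;
# with an FP8-serialized checkpoint, they are created in float8_e4m3fn.
assert select_moe_params_dtype(True, False, torch.bfloat16) == torch.bfloat16
assert select_moe_params_dtype(True, True, torch.bfloat16) == torch.float8_e4m3fn
```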