Ver código fonte

fix: torch.uniform_ doesn't support FP8, fix for dummy weights

AlpinDale 7 meses atrás
pai
commit
94ba676ee0
1 arquivo alterado com 8 adições e 1 exclusão
  1. 8 1
      aphrodite/modeling/model_loader/weight_utils.py

+ 8 - 1
aphrodite/modeling/model_loader/weight_utils.py

@@ -367,4 +367,11 @@ def initialize_dummy_weights(
     """
     for param in model.state_dict().values():
         if torch.is_floating_point(param):
-            param.data.uniform_(low, high)
+            if torch.finfo(param.data.dtype).bits < 16:
+                # uniform_ doesn't support < 16-bit datatypes (FP8)
+                dtype = param.data.dtype
+                tmp_param = param.data.to(torch.float16)
+                tmp_param = tmp_param.uniform_(low, high).to(dtype)
+                param.data.copy_(tmp_param)
+            else:
+                param.uniform_(low, high)