Sfoglia il codice sorgente

fix: weight loading for scalars (#718)

AlpinDale 6 mesi fa
parent
commit
8e22069c9e
1 ha cambiato i file con 11 aggiunte e 5 eliminazioni
  1. 11 5
      aphrodite/modeling/model_loader/weight_utils.py

+ 11 - 5
aphrodite/modeling/model_loader/weight_utils.py

@@ -502,11 +502,17 @@ def default_weight_loader(param: torch.Tensor,
     """Default weight loader."""
 
     try:
-        assert param.size() == loaded_weight.size(), (
-            f"Attempted to load weight ({loaded_weight.size()}) "
-            f"into parameter ({param.size()})")
-
-        param.data.copy_(loaded_weight)
+        if param.numel() == 1 and loaded_weight.numel() == 1:
+            # Sometimes scalar values aren't considered tensors with shapes
+            # so if both param and loaded_weight are a scalar,
+            # "broadcast" instead of copy
+            param.data.fill_(loaded_weight.item())
+        else:
+            assert param.size() == loaded_weight.size(), (
+                f"Attempted to load weight ({loaded_weight.size()}) "
+                f"into parameter ({param.size()})")
+
+            param.data.copy_(loaded_weight)
     except Exception:
         # NOTE: This exception is added for the purpose of setting breakpoint to
         # debug weight loading issues.