|
@@ -502,11 +502,17 @@ def default_weight_loader(param: torch.Tensor,
|
|
|
"""Default weight loader."""
|
|
|
|
|
|
try:
|
|
|
- assert param.size() == loaded_weight.size(), (
|
|
|
- f"Attempted to load weight ({loaded_weight.size()}) "
|
|
|
- f"into parameter ({param.size()})")
|
|
|
-
|
|
|
- param.data.copy_(loaded_weight)
|
|
|
+ if param.numel() == 1 and loaded_weight.numel() == 1:
|
|
|
+ # Sometimes scalar values aren't considered tensors with shapes
|
|
|
+ # so if both param and loaded_weight are a scalar,
|
|
|
+ # "broadcast" instead of copy
|
|
|
+ param.data.fill_(loaded_weight.item())
|
|
|
+ else:
|
|
|
+ assert param.size() == loaded_weight.size(), (
|
|
|
+ f"Attempted to load weight ({loaded_weight.size()}) "
|
|
|
+ f"into parameter ({param.size()})")
|
|
|
+
|
|
|
+ param.data.copy_(loaded_weight)
|
|
|
except Exception:
|
|
|
# NOTE: This exception is added for the purpose of setting breakpoint to
|
|
|
# debug weight loading issues.
|