|
@@ -426,6 +426,10 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
|
|
|
def default_weight_loader(param: torch.Tensor,
                          loaded_weight: torch.Tensor) -> None:
    """Default weight loader: copy ``loaded_weight`` into ``param`` in place.

    Args:
        param: Destination tensor; its ``.data`` is overwritten.
        loaded_weight: Source tensor read from the checkpoint. It must have
            the same size as ``param``, except that a 0-dim (shapeless)
            tensor is accepted for a ``param`` of shape ``(1,)``.

    Raises:
        AssertionError: If the sizes still differ after the scalar
            adjustment described above.
    """
    # If the weight on disk does not have a shape, give it one
    # (such as scales for AutoFp8). Only do so when the destination actually
    # has a dimension: a 0-dim param already matches a 0-dim weight, and
    # reshaping would make the size check below fail spuriously.
    if loaded_weight.dim() == 0 and param.dim() != 0:
        loaded_weight = loaded_weight.reshape(1)

    assert param.size() == loaded_weight.size(), (
        f"Attempted to load weight of size {tuple(loaded_weight.size())} "
        f"into parameter of size {tuple(param.size())}")
    param.data.copy_(loaded_weight)
|
|
|
|