@@ -21,13 +21,17 @@ def replace_parameter(mod: torch.nn.Module, name: str,
                       new: Union[torch.Tensor, torch.nn.Parameter]) -> None:
 
     old = getattr(mod, name)
-    if old.dtype == new.dtype and \
+    if type(old) is type(new) and old.dtype == new.dtype and \
             old.untyped_storage().nbytes() == new.untyped_storage().nbytes():
         # If we can just update in-place to avoid re-registering
         #   can be faster if the underlying storage is the same
         update_tensor_inplace(old, new)
     else:
-        # Fallback re-register parameter
+        # Fallback re-register parameter, convert to Parameter if necessary
+        # this not only ensures we don't register a tensor as a parameter, but
+        # also ensures that all parameter subclasses get re-registered as
+        # parameters for `torch.compile` compatibility
         if not isinstance(new, torch.nn.Parameter):
-            new = torch.nn.Parameter(new)
-        mod.register_parameter(name, torch.nn.Parameter(new))
+            new = torch.nn.Parameter(new, requires_grad=False)
+        mod.register_parameter(name,
+                               torch.nn.Parameter(new, requires_grad=False))
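
A minimal usage sketch of the patched behavior, assuming replace_parameter and update_tensor_inplace are importable from the surrounding module and that update_tensor_inplace copies `new`'s data into `old`'s existing storage; the Linear layer and tensors below are hypothetical, chosen only to exercise both branches:

import torch

layer = torch.nn.Linear(16, 16, bias=False)

# Same Parameter type, dtype, and storage size: takes the in-place branch,
# so the Parameter object already registered on `layer` is reused.
scaled = torch.nn.Parameter(layer.weight.data * 2.0, requires_grad=False)
replace_parameter(layer, "weight", scaled)

# Plain tensor with a different dtype (e.g. repacked int8 weights): type(old)
# is not type(new), so it falls through to the else branch and re-registers a
# fresh requires_grad=False Parameter.
packed = torch.randint(-128, 127, (16, 16), dtype=torch.int8)
replace_parameter(layer, "weight", packed)
assert isinstance(layer.weight, torch.nn.Parameter)
assert not layer.weight.requires_grad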