from typing import Union

import torch


def update_tensor_inplace(dst: torch.Tensor, src: torch.Tensor):
    assert dst.dtype == src.dtype, "Tensors must have the same dtype"

    # Update the destination tensor's shape and stride to match the source.
    dst.as_strided_(src.shape, src.stride())

    # If the tensors do not share the same underlying storage, move the data.
    if dst.data_ptr() != src.data_ptr():
        dst.copy_(src)
        del src


# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by Aphrodite (and won't be freed).
def replace_parameter(mod: torch.nn.Module, name: str,
                      new: Union[torch.Tensor, torch.nn.Parameter]) -> None:
    old = getattr(mod, name)
    if old.dtype == new.dtype and \
            old.untyped_storage().nbytes() == new.untyped_storage().nbytes():
        # Update in place to avoid re-registering; this can be faster when the
        # underlying storage is the same.
        update_tensor_inplace(old, new)
    else:
        # Fallback: re-register the parameter, converting to nn.Parameter if
        # necessary. Re-wrapping on registration replaces any Parameter
        # subclass with a plain nn.Parameter.
        if not isinstance(new, torch.nn.Parameter):
            new = torch.nn.Parameter(new)
        mod.register_parameter(name, torch.nn.Parameter(new))
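
# Minimal usage sketch (illustrative only; the linear layer and the "repacked"
# tensor below are assumptions, not part of the original module). When the new
# tensor has the same dtype and storage size as the registered parameter, the
# existing Parameter object and its storage are reused instead of being
# re-registered.
if __name__ == "__main__":
    linear = torch.nn.Linear(8, 8, bias=False).requires_grad_(False)
    original_ptr = linear.weight.data_ptr()

    # A transformed copy of the weight (same dtype and storage size).
    repacked = linear.weight.clone() * 2.0

    replace_parameter(linear, "weight", repacked)

    # The original Parameter object and its storage were updated in place.
    assert isinstance(linear.weight, torch.nn.Parameter)
    assert linear.weight.data_ptr() == original_ptr
    assert torch.equal(linear.weight, repacked)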