123456789101112131415161718192021222324252627282930313233 |
- from typing import Union
- import torch
def update_tensor_inplace(dst: torch.Tensor, src: torch.Tensor):
    """Make ``dst`` take on the shape, strides and values of ``src`` in place.

    ``dst`` keeps its Python identity and its existing storage, so anything
    already holding a reference to ``dst`` observes the new layout and data.
    The storage of ``dst`` must be large enough for ``src``'s layout
    (``replace_parameter`` only calls this when storage sizes match).

    Args:
        dst: Tensor to overwrite. It is restrided to ``src.shape`` /
            ``src.stride()`` and, unless it then starts at the same address
            as ``src``, its data is copied from ``src``.
        src: Tensor supplying the new shape, strides and values.

    Raises:
        AssertionError: if ``dst`` and ``src`` have different dtypes.
    """
    assert dst.dtype == src.dtype, "Tensors must have the same dtype"

    target_shape, target_stride = src.shape, src.stride()
    # Reinterpret dst's existing storage with src's layout (metadata only).
    dst.as_strided_(target_shape, target_stride)

    # With identical shape/stride, equal data pointers mean dst already views
    # exactly src's elements; only copy when they point at different memory.
    already_aliased = dst.data_ptr() == src.data_ptr()
    if not already_aliased:
        dst.copy_(src)
def replace_parameter(mod: torch.nn.Module, name: str,
                      new: Union[torch.Tensor, torch.nn.Parameter]) -> None:
    """Replace the parameter ``name`` of ``mod`` with the tensor ``new``.

    Newly generated tensors need to replace existing tensors that Aphrodite
    has already registered as parameters (which would otherwise not be
    freed).

    If ``new`` has the same dtype and the same underlying storage size as
    the current parameter, the existing parameter object is updated in place
    with ``update_tensor_inplace``, which avoids re-registering it. Otherwise
    a fresh ``torch.nn.Parameter`` wrapping ``new`` is registered under
    ``name``.

    Args:
        mod: Module that owns the parameter.
        name: Attribute name of the parameter on ``mod``.
        new: Replacement tensor or parameter.
    """
    old = getattr(mod, name)
    same_dtype = old.dtype == new.dtype
    same_nbytes = (old.untyped_storage().nbytes() ==
                   new.untyped_storage().nbytes())
    if same_dtype and same_nbytes:
        # Update in place to avoid re-registering; can be faster when the
        # underlying storage size is unchanged.
        update_tensor_inplace(old, new)
        return

    # Fallback: re-register as a fresh Parameter, wrapping exactly once.
    # nn.Parameter defaults to requires_grad=True, which would flip frozen
    # weights to trainable and raises for integer dtypes (e.g. packed
    # quantized weights), so carry over ``new``'s own flag instead.
    mod.register_parameter(
        name, torch.nn.Parameter(new, requires_grad=new.requires_grad))
|