layer_utils.py 1.2 KB

123456789101112131415161718192021222324252627282930313233
  1. from typing import Union
  2. import torch
  def update_tensor_inplace(dst: torch.Tensor, src: torch.Tensor) -> None:
      """Make ``dst`` adopt the shape, strides and contents of ``src``.

      ``dst`` is re-strided in place over its existing storage to match
      ``src.shape`` / ``src.stride()``. The element data is copied from
      ``src`` only when the two tensors do not already start at the same
      memory address.

      Args:
          dst: Tensor that is modified in place and keeps its identity.
          src: Tensor providing the target layout and values.

      Raises:
          AssertionError: If ``dst`` and ``src`` have different dtypes.

      NOTE(review): nothing here checks that ``dst``'s storage is large
      enough for ``src``'s layout; the caller is presumably expected to
      guarantee that before calling.
      """
      if dst.dtype != src.dtype:
          # Raised explicitly instead of via ``assert`` so the check still
          # runs under ``python -O``; the exception type stays the same.
          raise AssertionError("Tensors must have the same dtype")

      # update tensor shape and stride
      dst.as_strided_(src.shape, src.stride())

      # If not the same underlying storage move tensor data
      if dst.data_ptr() != src.data_ptr():
          dst.copy_(src)
  # Newly generated tensors need to replace existing tensors that are
  # already registered as parameters by Aphrodite (and won't be freed)
  def replace_parameter(mod: torch.nn.Module, name: str,
                        new: Union[torch.Tensor, torch.nn.Parameter]) -> None:
      """Replace the parameter ``name`` on ``mod`` with ``new``.

      When the currently registered tensor has the same dtype and the same
      underlying storage size in bytes as ``new``, it is updated in place
      via ``update_tensor_inplace`` so the parameter object is kept and
      nothing is re-registered. Otherwise ``new`` is registered on the
      module under ``name``.

      Args:
          mod: Module that owns the parameter.
          name: Attribute name of the parameter on ``mod``.
          new: Replacement tensor or parameter.
      """
      old = getattr(mod, name)

      if old.dtype == new.dtype and \
              old.untyped_storage().nbytes() == new.untyped_storage().nbytes():
          # If we can just update in-place to avoid re-registering
          # can be faster if the underlying storage is the same
          update_tensor_inplace(old, new)
          return

      # Fallback re-register parameter
      if not isinstance(new, torch.nn.Parameter):
          # Carry over the tensor's own ``requires_grad`` flag rather than
          # relying on the ``Parameter`` default of ``True``: the default
          # raises for integer dtypes, which cannot require gradients.
          new = torch.nn.Parameter(new, requires_grad=new.requires_grad)

      # Register the parameter as-is; wrapping it in ``Parameter`` a second
      # time would reset ``requires_grad`` back to the default.
      mod.register_parameter(name, new)