utils.py 928 B

1234567891011121314151617181920212223242526272829303132333435
  1. """Utils for model executor."""
  2. import random
  3. from typing import Any, Dict, Optional
  4. import numpy as np
  5. import torch
  6. def set_random_seed(seed: int) -> None:
  7. random.seed(seed)
  8. np.random.seed(seed)
  9. torch.manual_seed(seed)
  10. if torch.cuda.is_available():
  11. torch.cuda.manual_seed_all(seed)
  12. def set_weight_attrs(
  13. weight: torch.Tensor,
  14. weight_attrs: Optional[Dict[str, Any]],
  15. ):
  16. """Set attributes on a weight tensor.
  17. This method is used to set attributes on a weight tensor. This method
  18. will not overwrite existing attributes.
  19. Args:
  20. weight: The weight tensor.
  21. weight_attrs: A dictionary of attributes to set on the weight tensor.
  22. """
  23. if weight_attrs is None:
  24. return
  25. for key, value in weight_attrs.items():
  26. assert not hasattr(
  27. weight, key), (f"Overwriting existing tensor attribute: {key}")
  28. setattr(weight, key, value)