set_seed.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. import numpy as np
  2. import torch
  3. import os
  4. import random
  5. # from transformers import set_seed
  6. def set_seed(seed: int = 0):
  7. """Set the seed
  8. seed = 0 Generate a random seed
  9. seed = -1 Disable deterministic algorithms
  10. 0 < seed < 2**32 Set the seed
  11. Args:
  12. seed: integer to use as seed
  13. Returns:
  14. integer used as seed
  15. """
  16. original_seed = seed
  17. # See for more informations: https://pytorch.org/docs/stable/notes/randomness.html
  18. if seed == -1:
  19. # Disable deterministic
  20. torch.backends.cudnn.deterministic = False
  21. torch.backends.cudnn.benchmark = True
  22. else:
  23. # Enable deterministic
  24. torch.backends.cudnn.deterministic = True
  25. torch.backends.cudnn.benchmark = False
  26. if seed <= 0:
  27. # Generate random seed
  28. # Use default_rng() because it is independent of np.random.seed()
  29. seed = np.random.default_rng().integers(1, 2**32 - 1)
  30. assert 0 < seed < 2**32
  31. np.random.seed(seed)
  32. random.seed(seed)
  33. torch.manual_seed(seed)
  34. torch.cuda.manual_seed_all(seed)
  35. os.environ["PYTHONHASHSEED"] = str(seed)
  36. return original_seed if original_seed != 0 else seed