12345678910111213141516171819202122232425262728293031323334353637383940414243444546 |
- import numpy as np
- import torch
- import os
- import random
- # from transformers import set_seed
def set_seed(seed: int = 0) -> int:
    """Seed the Python, NumPy and PyTorch RNGs and configure cuDNN determinism.

    Behaviour depends on the value of ``seed``:

    * ``seed == 0``: generate a random seed, apply it and return it.
    * ``seed == -1``: disable cuDNN deterministic mode (and enable cuDNN
      benchmarking), then seed the RNGs with a randomly generated value.
    * ``0 < seed < 2**32``: enable cuDNN deterministic mode and apply ``seed``.
    * any other negative value: cuDNN deterministic mode is enabled and a
      randomly generated seed is applied.

    The seed is applied to ``numpy.random``, ``random``, ``torch`` (CPU and all
    CUDA devices) and is also written to the ``PYTHONHASHSEED`` environment
    variable.

    Args:
        seed: integer to use as seed.

    Returns:
        The generated seed (as a plain ``int``) when ``seed`` is 0, otherwise
        the ``seed`` argument unchanged (so -1 is returned as -1).

    Raises:
        AssertionError: if ``seed`` is not smaller than ``2**32``.
    """
    original_seed = seed

    # For more information see: https://pytorch.org/docs/stable/notes/randomness.html
    # Only -1 turns determinism off; every other value turns it on.
    deterministic = seed != -1
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic

    if seed <= 0:
        # Generate a random seed.
        # Use default_rng() because it is independent of np.random.seed().
        # Convert to a built-in int: Generator.integers() returns a NumPy
        # integer scalar, which random.seed() rejects with TypeError on
        # Python 3.11+ and which would otherwise leak to the caller.
        seed = int(np.random.default_rng().integers(1, 2**32 - 1))

    # Explicit raise instead of a bare `assert` so the range check is not
    # stripped under `python -O`; the exception type is kept unchanged.
    if not 0 < seed < 2**32:
        raise AssertionError(f"seed must satisfy 0 < seed < 2**32, got {seed}")

    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # NOTE: Python reads PYTHONHASHSEED at interpreter start-up, so this is
    # expected to affect only processes launched after this call.
    os.environ["PYTHONHASHSEED"] = str(seed)

    return original_seed if original_seed != 0 else seed
|