initialize.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. #!/usr/bin/env python3
  2. """Initialize modules for espnet2 neural networks."""
  3. import torch
  4. from typeguard import check_argument_types
  5. def initialize(model: torch.nn.Module, init: str):
  6. """Initialize weights of a neural network module.
  7. Parameters are initialized using the given method or distribution.
  8. Custom initialization routines can be implemented into submodules
  9. as function `espnet_initialization_fn` within the custom module.
  10. Args:
  11. model: Target.
  12. init: Method of initialization.
  13. """
  14. assert check_argument_types()
  15. print("init with", init)
  16. # weight init
  17. for p in model.parameters():
  18. if p.dim() > 1:
  19. if init == "xavier_uniform":
  20. torch.nn.init.xavier_uniform_(p.data)
  21. elif init == "xavier_normal":
  22. torch.nn.init.xavier_normal_(p.data)
  23. elif init == "kaiming_uniform":
  24. torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu")
  25. elif init == "kaiming_normal":
  26. torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu")
  27. else:
  28. raise ValueError("Unknown initialization: " + init)
  29. # bias init
  30. for name, p in model.named_parameters():
  31. if ".bias" in name and p.dim() == 1:
  32. p.data.zero_()