2
0

lr_scheduler.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334
  1. import numpy as np
  2. class LambdaWarmUpCosineScheduler:
  3. """
  4. note: use with a base_lr of 1.0
  5. """
  6. def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
  7. self.lr_warm_up_steps = warm_up_steps
  8. self.lr_start = lr_start
  9. self.lr_min = lr_min
  10. self.lr_max = lr_max
  11. self.lr_max_decay_steps = max_decay_steps
  12. self.last_lr = 0.
  13. self.verbosity_interval = verbosity_interval
  14. def schedule(self, n):
  15. if self.verbosity_interval > 0:
  16. if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
  17. if n < self.lr_warm_up_steps:
  18. lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
  19. self.last_lr = lr
  20. return lr
  21. else:
  22. t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
  23. t = min(t, 1.0)
  24. lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
  25. 1 + np.cos(t * np.pi))
  26. self.last_lr = lr
  27. return lr
  28. def __call__(self, n):
  29. return self.schedule(n)