__init__.py

import torch

_output_ref = None
_replicas_ref = None


def data_parallel_workaround(model, *input):
    """Replicate `model` across all visible GPUs, scatter `input`, run the
    replicas in parallel, and gather the results onto GPU 0."""
    global _output_ref
    global _replicas_ref
    # Assumes at least one visible CUDA device.
    device_ids = list(range(torch.cuda.device_count()))
    output_device = device_ids[0]
    replicas = torch.nn.parallel.replicate(model, device_ids)
    # input.shape = (num_args, batch, ...)
    inputs = torch.nn.parallel.scatter(input, device_ids)
    # inputs.shape = (num_gpus, num_args, batch/num_gpus, ...)
    replicas = replicas[:len(inputs)]
    outputs = torch.nn.parallel.parallel_apply(replicas, inputs)
    y_hat = torch.nn.parallel.gather(outputs, output_device)
    # Hold module-level references so the replicas and their outputs (and the
    # autograd graph built on them) are not garbage-collected before backward().
    _output_ref = outputs
    _replicas_ref = replicas
    return y_hat
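
# A minimal usage sketch, not part of the original module: `model` and `x`
# below are hypothetical stand-ins, and the guard assumes the function is only
# exercised when at least one CUDA device is visible. The call replaces a
# torch.nn.DataParallel forward pass.
if __name__ == "__main__" and torch.cuda.is_available():
    model = torch.nn.Linear(16, 4).cuda()
    x = torch.randn(32, 16, device="cuda")
    y_hat = data_parallel_workaround(model, x)
    print(y_hat.shape)  # torch.Size([32, 4]), gathered back onto GPU 0
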
class ValueWindow:
    """Running window over the last `window_size` values (e.g. for smoothing
    a training-loss readout)."""

    def __init__(self, window_size=100):
        self._window_size = window_size
        self._values = []

    def append(self, x):
        # Drop the oldest entries so at most `window_size` values are kept.
        self._values = self._values[-(self._window_size - 1):] + [x]

    @property
    def sum(self):
        return sum(self._values)

    @property
    def count(self):
        return len(self._values)

    @property
    def average(self):
        # max(1, count) guards against division by zero on an empty window.
        return self.sum / max(1, self.count)

    def reset(self):
        self._values = []
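
# Minimal usage sketch, not part of the original module: the loss values are
# illustrative. ValueWindow smooths noisy per-step metrics by averaging only
# the most recent `window_size` entries.
if __name__ == "__main__":
    loss_window = ValueWindow(window_size=3)
    for loss in (4.0, 2.0, 1.0, 1.0):
        loss_window.append(loss)
    print(loss_window.count)    # 3 -- only the last three values are kept
    print(loss_window.average)  # (2.0 + 1.0 + 1.0) / 3 ~= 1.33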