import torch

# Module-level references held between calls, presumably so the replicas and
# their outputs are not garbage-collected prematurely (hence "workaround").
_output_ref = None
_replicas_ref = None


def data_parallel_workaround(model, *input):
    global _output_ref
    global _replicas_ref
    device_ids = list(range(torch.cuda.device_count()))
    output_device = device_ids[0]
    # Copy the model onto every visible GPU.
    replicas = torch.nn.parallel.replicate(model, device_ids)
    # input.shape = (num_args, batch, ...)
    inputs = torch.nn.parallel.scatter(input, device_ids)
    # inputs.shape = (num_gpus, num_args, batch/num_gpus, ...)
    replicas = replicas[:len(inputs)]
    # Run each replica on its shard of the batch, then gather the results
    # back onto the output device.
    outputs = torch.nn.parallel.parallel_apply(replicas, inputs)
    y_hat = torch.nn.parallel.gather(outputs, output_device)
    _output_ref = outputs
    _replicas_ref = replicas
    return y_hat
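
# Usage sketch, not part of the original module: how the workaround might be
# dropped into a training step. The network and tensors here are hypothetical
# stand-ins, and the call degrades to a plain forward pass without multiple GPUs.
def _demo_data_parallel_workaround():
    net = torch.nn.Linear(4, 2)
    x = torch.randn(8, 4)
    if torch.cuda.device_count() > 1:
        net = net.cuda()
        y_hat = data_parallel_workaround(net, x.cuda())
    else:
        y_hat = net(x)  # CPU / single-GPU fallback
    print(y_hat.shape)  # torch.Size([8, 2])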
class ValueWindow:
    """Keeps a sliding window of the most recent values and exposes
    simple running statistics over that window."""

    def __init__(self, window_size=100):
        self._window_size = window_size
        self._values = []

    def append(self, x):
        # Drop the oldest values so at most `window_size` are retained.
        self._values = self._values[-(self._window_size - 1):] + [x]

    @property
    def sum(self):
        return sum(self._values)

    @property
    def count(self):
        return len(self._values)

    @property
    def average(self):
        # Guard against division by zero when the window is empty.
        return self.sum / max(1, self.count)

    def reset(self):
        self._values = []
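
# Usage sketch, not part of the original module: smoothing a training loss
# over the most recent steps. The loss values are stand-ins.
def _demo_value_window():
    loss_window = ValueWindow(window_size=100)
    for step, loss in enumerate([0.9, 0.7, 0.8, 0.6, 0.5]):
        loss_window.append(loss)
        print(f"step {step}: running average = {loss_window.average:.3f}")
    loss_window.reset()  # start a fresh window, e.g. at the end of an epoch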