# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/bucketsampler.py
import itertools
import math
import random
from random import shuffle
from typing import Iterator
from typing import Optional
from typing import TypeVar

import torch
import torch.distributed as dist
from torch.utils.data import Dataset
from torch.utils.data import Sampler

__all__ = [
    "DistributedBucketSampler",
]

T_co = TypeVar("T_co", covariant=True)


class DistributedBucketSampler(Sampler[T_co]):
    r"""Distributed sampler that groups samples of similar length.

    The dataset is sorted by input length and divided into buckets; samples
    are shuffled within each bucket, the buckets are split into batches, and
    the batches themselves are shuffled.
    """

    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
        batch_size: int = 32,
    ) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            # Fall back to a single replica when no GPU is available.
            num_replicas = dist.get_world_size() if torch.cuda.is_available() else 1
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank() if torch.cuda.is_available() else 0
        if torch.cuda.is_available():
            torch.cuda.set_device(rank)
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1)
            )
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        # If the dataset length is evenly divisible by the number of replicas,
        # there is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
            # Split to the nearest available length that is evenly divisible,
            # so that each rank receives the same amount of data when using
            # this sampler.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)  # type: ignore[arg-type]
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.batch_size = batch_size
        # Sort samples by length and pre-compute the length buckets.
        self.id_with_length = self._get_sample_lengths()
        self.id_buckets = self.make_buckets(bucket_width=2.0)

    def _get_sample_lengths(self):
        # Collect (index, length) pairs; the dataset is expected to expose a
        # get_sample_length(index) method returning the sample length
        # (a duration in seconds, given the bucketing below).
        id_with_lengths = []
        for i in range(len(self.dataset)):
            id_with_lengths.append((i, self.dataset.get_sample_length(i)))
        # Sort by length so buckets can be formed in a single pass.
        id_with_lengths.sort(key=lambda x: x[1])
        return id_with_lengths

    def make_buckets(self, bucket_width: float = 2.0):
        # Walk through the samples in ascending length order and start a new
        # bucket whenever a sample crosses the current bucket boundary.
        buckets = []
        cur = []
        max_sec = bucket_width
        for sample_id, sec in self.id_with_length:
            if sec < max_sec:
                cur.append(sample_id)
            else:
                buckets.append(cur)
                cur = [sample_id]
                max_sec += bucket_width
        if len(cur) > 0:
            buckets.append(cur)
        return buckets

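    # A small worked example of the bucketing above, assuming the default
    # bucket_width of 2.0 and sorted (id, length) pairs
    # (0, 0.5), (1, 1.2), (2, 1.8), (3, 2.5), (4, 3.1), (5, 4.2):
    # the boundary starts at 2.0 and moves to 4.0 and then 6.0 as samples
    # cross it, giving buckets [[0, 1, 2], [3, 4], [5]].
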
    def __iter__(self) -> Iterator[T_co]:
        if self.shuffle:
            # Deterministically shuffle based on epoch and seed. The torch
            # Generator is seeded for parity with DistributedSampler, while
            # the shuffles below use Python's random module.
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            random.seed(self.epoch + self.seed)
            # Shuffle sample ids within each bucket.
            shuffled_bucket = []
            for buc in self.id_buckets:
                buc_copy = buc.copy()
                shuffle(buc_copy)
                shuffled_bucket.append(buc_copy)
            # Split the concatenated buckets into chunks of
            # batch_size * num_replicas, so that after the strided subsampling
            # below each replica's batch is drawn from the same chunk of
            # similar-length samples.
            grouped_batch_size = self.batch_size * self.num_replicas
            shuffled_bucket = list(itertools.chain(*shuffled_bucket))
            n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
            batches = [
                shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size]
                for b in range(n_batch)
            ]
            shuffle(batches)
            indices = list(itertools.chain(*batches))
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

        if not self.drop_last:
            # Add extra samples to make the index list evenly divisible.
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # Remove tail of data to make the index list evenly divisible.
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # Subsample: each rank takes every num_replicas-th index.
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

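    # Worked example of the grouping and subsampling above, assuming
    # num_replicas=2 and batch_size=2: with a padded index list
    # [a, b, c, d, e, f, g, h] (two grouped chunks of four), rank 0 receives
    # [a, c, e, g] and rank 1 receives [b, d, f, h], so the first batch of
    # size 2 on both ranks is drawn from the first chunk.
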
    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        r"""Set the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different
        random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
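

# A minimal usage sketch, assuming a hypothetical toy dataset; the names
# ToyDataset, loader and num_epochs below are illustrative and not part of the
# sampler. The dataset must implement get_sample_length(index) in addition to
# the usual __len__ / __getitem__.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    class ToyDataset(Dataset):
        def __init__(self):
            # Hypothetical per-sample durations in seconds.
            self.lengths = [0.5, 1.2, 1.8, 2.5, 3.1, 4.2, 4.4, 5.0]

        def __len__(self):
            return len(self.lengths)

        def __getitem__(self, index):
            return index

        def get_sample_length(self, index):
            return self.lengths[index]

    dataset = ToyDataset()
    # num_replicas and rank are passed explicitly so the sketch also runs in a
    # single, non-distributed process.
    sampler = DistributedBucketSampler(dataset, num_replicas=1, rank=0, batch_size=2)
    loader = DataLoader(dataset, batch_size=2, sampler=sampler)

    num_epochs = 2
    for epoch in range(num_epochs):
        # Re-seed the shuffling so each epoch sees a different batch order.
        sampler.set_epoch(epoch)
        for batch in loader:
            print(epoch, batch)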
|