| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| |
|
import math
from typing import Iterator, Optional, Sequence, TypeVar

import torch
import torch.distributed as dist
from torch.utils.data import Dataset, Sampler

__all__ = ["DistributedWeightedSampler"]


T_co = TypeVar("T_co", covariant=True)


class DistributedWeightedSampler(Sampler[T_co]):
    """Weighted random sampler that splits its draws across distributed replicas.

    When ``shuffle`` is True, ``total_size`` indices are drawn with
    ``torch.multinomial`` over ``weights`` (with replacement), seeded by
    ``seed + epoch`` so every replica generates the identical stream and
    then keeps a disjoint interleaved slice of it.  When ``shuffle`` is
    False, ``weights`` and the requested ``num_samples`` are ignored and
    the sampler degenerates to a plain sequential split of
    ``range(len(dataset))``, mirroring
    ``torch.utils.data.DistributedSampler`` without shuffling.

    Args:
        dataset: Dataset to sample for.  Only ``len(dataset)`` is used,
            and only on the non-shuffle path.
        weights: 1-d sequence of sampling weights, one per candidate
            index.  Presumably one entry per dataset element — this is
            not validated here; TODO confirm against callers.
        num_samples: Total number of samples to draw across all replicas
            per epoch (shuffle mode only).  Must be a positive ``int``.
        num_replicas: Number of participating processes.  Defaults to
            the current distributed world size.
        rank: Rank of this process in ``[0, num_replicas)``.  Defaults
            to the current distributed rank.
        shuffle: If True (default), draw indices by weight; if False,
            split the dataset's index range sequentially.
        seed: Base RNG seed; combined with the epoch set via
            :meth:`set_epoch`.
        drop_last: Non-shuffle mode only — drop the uneven tail instead
            of padding with repeated indices.

    Raises:
        ValueError: If ``num_samples`` is not a positive int, if
            ``weights`` is not 1-d, or if ``rank`` is out of range.
        RuntimeError: If ``num_replicas`` or ``rank`` must be inferred
            but ``torch.distributed`` is unavailable.
    """

    def __init__(
        self,
        dataset: Dataset,
        weights: Sequence[float],
        num_samples: int,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ) -> None:
        # bool is a subclass of int, so it must be rejected explicitly.
        if not isinstance(num_samples, int) or isinstance(num_samples, bool) or num_samples <= 0:
            raise ValueError(f"num_samples should be a positive integer value, but got num_samples={num_samples}")

        weights_tensor = torch.as_tensor(weights, dtype=torch.float)
        if len(weights_tensor.shape) != 1:
            raise ValueError(
                "weights should be a 1d sequence but given " f"weights have shape {tuple(weights_tensor.shape)}"
            )

        self.weights = weights_tensor
        self.num_samples = num_samples

        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        self.shuffle = shuffle  # fix: was assigned twice in the original

        if self.shuffle:
            # Spread the requested total evenly over replicas, rounding
            # up so every replica yields the same number of indices.
            # (math.ceil already returns int in Python 3.)
            self.num_samples = math.ceil(self.num_samples / self.num_replicas)
        else:
            # Sequential mode: per-replica length is derived from the
            # dataset size (the user-supplied num_samples is ignored
            # here), matching torch's DistributedSampler semantics.
            if self.drop_last and len(self.dataset) % self.num_replicas != 0:
                # Split to the nearest length that divides evenly: drop
                # at most num_replicas - 1 trailing samples.
                self.num_samples = math.ceil((len(self.dataset) - self.num_replicas) / self.num_replicas)
            else:
                self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)

        self.total_size = self.num_samples * self.num_replicas
        self.seed = seed

    def __iter__(self) -> Iterator[T_co]:
        """Yield this replica's indices for the current epoch."""
        if self.shuffle:
            # Deterministic per (seed, epoch): every replica draws the
            # same total_size indices, then keeps a disjoint slice below.
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.multinomial(
                input=self.weights, num_samples=self.total_size, replacement=True, generator=g
            ).tolist()
        else:
            # Weights play no role on this path: a plain sequential split.
            indices = list(range(len(self.dataset)))
            if not self.drop_last:
                # Pad by repeating leading indices until the list divides
                # evenly among replicas.
                padding_size = self.total_size - len(indices)
                if padding_size <= len(indices):
                    indices += indices[:padding_size]
                else:
                    indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
            else:
                # Remove the uneven tail.
                indices = indices[: self.total_size]
            assert len(indices) == self.total_size

        # Interleaved subsample: rank r takes positions r, r+R, r+2R, ...
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self) -> int:
        """Return the number of indices this replica yields per epoch."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch used to derive the RNG seed.

        Call at the start of every epoch; otherwise shuffle mode draws
        the identical sequence of indices each epoch.
        """
        self.epoch = epoch
| |
|