Shortcuts

Source code for mmflow.datasets.samplers.distributed_sampler

# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Iterator, Optional, Sequence

import torch
import torch.distributed as dist
from torch.utils.data import Dataset
from torch.utils.data import DistributedSampler as _DistributedSampler
from torch.utils.data import Sampler

from mmflow.core.utils import sync_random_seed


class DistributedSampler(_DistributedSampler):
    """DistributedSampler inheriting from
    `torch.utils.data.DistributedSampler`.

    This distributed sampler is compatible with PyTorch==1.5, whose built-in
    sampler has no `seed` argument.

    Args:
        dataset (Dataset): The dataset to be loaded.
        num_replicas (int, optional): Number of processes participating in
            distributed training. By default, world_size is retrieved from
            the current distributed group.
        rank (int, optional): Rank of the current process within
            num_replicas. By default, rank is retrieved from the current
            distributed group.
        shuffle (bool): If True (default), sampler will shuffle the indices.
        seed (int): Random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Default: ``0``.
    """

    def __init__(self,
                 dataset: Dataset,
                 num_replicas: Optional[int] = None,
                 rank: Optional[int] = None,
                 shuffle: bool = True,
                 seed: int = 0) -> None:
        super().__init__(
            dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)

        # In distributed sampling, different ranks should sample
        # non-overlapped data in the dataset. Therefore, this function
        # is used to make sure that each rank shuffles the data indices
        # in the same order based on the same seed. Then different
        # ranks could use different indices to select non-overlapped
        # data from the same data list.
        self.seed = sync_random_seed(seed)

    def __iter__(self) -> Iterator:
        """Iterate over the dataset indices assigned to this rank.

        Yields:
            Iterator: iterator of indices for rank.
        """
        # deterministically shuffle based on epoch
        if self.shuffle:
            g = torch.Generator()
            # When :attr:`shuffle=True`, this ensures all replicas
            # use a different random ordering for each epoch.
            # Otherwise, the next iteration of this sampler will
            # yield the same ordering.
            g.manual_seed(self.epoch + self.seed)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # Add extra samples to make the list evenly divisible. A single
        # slice of `indices` can supply at most `len(indices)` items, so
        # the list is repeated first; this keeps tiny datasets (fewer
        # samples than the required padding) from tripping the assert.
        padding_size = self.total_size - len(indices)
        if padding_size > 0 and indices:
            repeats = math.ceil(padding_size / len(indices))
            indices += (indices * repeats)[:padding_size]
        assert len(indices) == self.total_size

        # subsample: every `num_replicas`-th index, starting at `rank`
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)
class MixedBatchDistributedSampler(Sampler):
    """Distributed Sampler for mixed data batch.

    Each rank is assigned to exactly one dataset, and the number of ranks
    given to a dataset follows ``sample_ratio``. The yielded indices are
    offset by the total length of the preceding datasets, i.e. they index
    into the datasets laid out back to back in the given order.

    Args:
        datasets (list): List of datasets will be loaded.
        sample_ratio (list): List of the ratio of each dataset in a batch,
            e.g. datasets=[DatasetA, DatasetB], sample_ratio=[0.25, 0.75],
            sample_per_gpu=1, gpus=8, it means 2 gpus load DatasetA, and 6
            gpus load DatasetB. The length of datasets must be equal to
            length of sample_ratio.
        num_replicas (int, optional): Number of processes participating in
            distributed training. By default, world_size is retrieved from
            the current distributed group.
        rank (int, optional): Rank of the current process within
            num_replicas. By default, rank is retrieved from the current
            distributed group.
        shuffle (bool): If True (default), sampler will shuffle the indices.
        seed (int): random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Default: ``0``.

    Raises:
        AssertionError: If ``datasets`` and ``sample_ratio`` differ in
            length, or the ratios do not sum to 1 (within floating-point
            tolerance).
        RuntimeError: If ``num_replicas`` or ``rank`` is omitted and the
            distributed package is not available.
        ValueError: If ``rank`` is outside ``[0, num_replicas - 1]``.
    """

    def __init__(self,
                 datasets: Sequence[Dataset],
                 sample_ratio: Sequence[float],
                 num_replicas: Optional[int] = None,
                 rank: Optional[int] = None,
                 shuffle: bool = True,
                 seed: int = 0) -> None:
        # base class `Sampler` do nothing in `__init__` function
        # super().__init__()

        assert len(datasets) == len(sample_ratio)
        # Compare with a tolerance: an exact `== 1.` check rejects valid
        # ratios whose float sum is off by one ulp (e.g. ten 0.1 entries
        # sum to 0.9999999999999999).
        assert math.isclose(sum(sample_ratio), 1.)

        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError(
                    'Requires distributed package to be available')
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError(
                    'Requires distributed package to be available')
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                f'Invalid rank {rank}, rank should be in the interval'
                f' [0, {num_replicas - 1}]')

        self.datasets = datasets
        self.num_replicas = num_replicas

        # Number of ranks reserved for each dataset.
        # NOTE(review): rounding up means these counts can sum to more than
        # `num_replicas` (also through float noise in `num_replicas * r`);
        # kept as-is to preserve the existing rank allocation.
        self.datasets_num_replicas = [
            math.ceil(num_replicas * r) for r in sample_ratio
        ]

        # Running totals of the per-dataset rank counts; `__iter__` uses
        # them to map a rank to its dataset.
        self.cumulative_replicas = []
        t = 0
        for nr in self.datasets_num_replicas:
            t += nr
            self.cumulative_replicas.append(t)

        self.datasets_length = [len(d) for d in datasets]

        # Running totals of dataset lengths; `__iter__` uses them to offset
        # the indices of every dataset after the first.
        self.cumulative_datasets_length = []
        t = 0
        for dl in self.datasets_length:
            t += dl
            self.cumulative_datasets_length.append(t)

        # Samples each rank would draw if its dataset were split evenly
        # across the ranks reserved for it.
        self.datasets_num_samples = [
            math.ceil(length / replicas) for length, replicas in zip(
                self.datasets_length, self.datasets_num_replicas)
        ]
        # the smallest num_sample
        self.num_samples = min(self.datasets_num_samples)

        # the dataset that decides the num_samples and total_size
        self.key_dataset = self.datasets_num_samples.index(self.num_samples)
        self.key_dataset_length = self.datasets_length[self.key_dataset]

        # Number of indices drawn from each dataset per epoch, summed over
        # all the ranks reserved for it.
        self.total_size = [
            self.num_samples * nr for nr in self.datasets_num_replicas
        ]
        self.rank = rank
        self.epoch = 0
        self.shuffle = shuffle

        # In distributed sampling, different ranks should sample
        # non-overlapped data in the dataset. Therefore, this function
        # is used to make sure that each rank shuffles the data indices
        # in the same order based on the same seed. Then different
        # ranks could use different indices to select non-overlapped
        # data from the same data list.
        self.seed = sync_random_seed(seed)

    def __iter__(self) -> Iterator:
        """Iterate over the indices assigned to the current rank.

        Yields:
            Iterator: iterator of indices for current rank.
        """
        # datasets map different rank: pick the first dataset whose
        # cumulative rank count exceeds this rank
        for dataset_idx, cumulative_replicas_ in enumerate(
                self.cumulative_replicas):
            if self.rank < cumulative_replicas_:
                break

        # deterministically shuffle each datasets based on epoch
        if self.shuffle:
            g = torch.Generator()
            # When :attr:`shuffle=True`, this ensures all replicas
            # use a different random ordering for each epoch.
            # Otherwise, the next iteration of this sampler will
            # yield the same ordering.
            g.manual_seed(self.epoch + self.seed)
            indices = torch.randperm(
                self.datasets_length[dataset_idx], generator=g).tolist()
        else:
            indices = torch.arange(self.datasets_length[dataset_idx]).tolist()

        total_size = self.total_size[dataset_idx]
        if total_size > len(indices):
            # Add extra samples to make the list evenly divisible. A single
            # slice can supply at most `len(indices)` items, so the list is
            # repeated first; this keeps a dataset with fewer samples than
            # the required padding from tripping the assert below.
            padding_size = total_size - len(indices)
            if indices:
                repeats = math.ceil(padding_size / len(indices))
                indices += (indices * repeats)[:padding_size]
        else:
            # larger datasets are truncated to the shared epoch size
            indices = indices[:total_size]
        assert len(indices) == total_size

        # subsample: position of this rank among the ranks of its dataset
        last_cumulative_replicas = 0 \
            if dataset_idx == 0 else self.cumulative_replicas[dataset_idx - 1]
        indices = indices[(
            self.rank - last_cumulative_replicas
        ):total_size:self.datasets_num_replicas[dataset_idx]]
        assert len(indices) == self.num_samples

        # find the dataset for this rank: shift by the total length of all
        # preceding datasets
        last_cumulative_length = 0 \
            if dataset_idx == 0 else \
            self.cumulative_datasets_length[dataset_idx - 1]
        indices = [idx + last_cumulative_length for idx in indices]

        return iter(indices)

    def __len__(self) -> int:
        """Return the number of samples each rank draws per epoch."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        r"""Sets the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different
        random ordering for each epoch. Otherwise, the next iteration of
        this sampler will yield the same ordering.

        Arguments:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
Read the Docs v: stable
Versions
latest
stable
1.x
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.