Dataloader WeightedRandomSampler + Distributed Training
System Info
accelerate 0.31.0
Ubuntu 22.04 (WSL)
python=3.10.14
Information
- [ ] The official example scripts
- [X] My own modified scripts
Tasks
- [ ] One of the scripts in the examples/ folder of Accelerate or an officially supported `no_trainer` script in the `examples` folder of the `transformers` repo (such as `run_no_trainer_glue.py`)
- [X] My own task or dataset (give details below)
Reproduction
I would like to combine distributed training with a weighted random sampler. In order to do that, I:
- Create my Dataset inheriting from torch.utils.data.Dataset
- Compute weights specific to my classes and data
- Create my DataLoader with the random sampler
- Prepare my dataloader with accelerate
But it seems that this is not working, because we have data leaks between processes.
I would like to make sure that each process uses different data, i.e. that the ranks draw disjoint samples.
I developed an example script in order to understand what happens:
```python
from accelerate import Accelerator
import argparse
import os
import torch.distributed as dist
import torch
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import WeightedRandomSampler, BatchSampler

WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
MAIN_PROCESS = not int(os.getenv('RANK', 0))

parser = argparse.ArgumentParser()
parser.add_argument('--dataset_count', default=12800)
parser.add_argument('--epochs', default=20)
parser.add_argument('--batch_size', default=64)
parser.add_argument('--balance', action='store_true', default=False)


def is_even(number):
    return not number % 2  # example 10 => 10%2 == 0


class DummyDataset(Dataset):
    def __init__(self, dataset_count: int):
        self.data = range(dataset_count)

    def __len__(self):
        return len(self.data)

    def dataloader(self, batch_size, balance: bool = False, seed=42, batch_sampler=False, drop_last: bool = False):
        generator = torch.Generator().manual_seed(seed)

        def get_weight(num):
            if is_even(num):
                # even
                return 1.0
            else:
                # odd (impair)
                return 0.1

        if balance:
            weights = [get_weight(i) for i in self.data]
            sampler = WeightedRandomSampler(weights, len(self), replacement=True, generator=generator)
        else:
            sampler = None
        if batch_sampler:
            return DataLoader(self, batch_sampler=BatchSampler(sampler, batch_size, drop_last))
        else:
            return DataLoader(self, batch_size, sampler=sampler, drop_last=drop_last)

    def __getitem__(self, idx):
        row_index = self.data[idx]
        return row_index


def main(
        dataset_count: int,
        epochs: int,
        batch_size: int,
        balance: bool = True):
    if int(os.environ.get('WORLD_SIZE', 1)) > 1:
        dist.init_process_group(backend='gloo')
    accelerator = Accelerator(cpu=True)
    # We mount the right storage...
    # We get the path
    dataset = DummyDataset(dataset_count)

    # Dataloader without Accelerate...
    dataloader = dataset.dataloader(batch_size, balance)
    batched_data = []
    if MAIN_PROCESS:
        print(f'Running {epochs*len(dataloader)} iterations')
    for epoch in range(epochs):
        for batch in dataloader:
            batch: torch.Tensor
            batched_data.extend(batch.tolist())
    count_even = len([v for v in batched_data if is_even(v)])
    count_odd = len([v for v in batched_data if not is_even(v)])
    ratio_odd = count_odd / (count_even + count_odd)
    if MAIN_PROCESS:
        print('Get proportion of Odd data without accelerate')
        print(f'Ratio Odd = {ratio_odd}')

    # Dataloader with Accelerate...
    dataloader = accelerator.prepare(dataloader)
    # We increase learning rate when multiGPU
    batched_data = []
    if MAIN_PROCESS:
        print(f'Running {epochs*len(dataloader)} iterations')
    for epoch in range(epochs):
        for batch in dataloader:
            batch: torch.Tensor
            batched_data.extend(batch.tolist())
    count_even = len([v for v in batched_data if is_even(v)])
    count_odd = len([v for v in batched_data if not is_even(v)])
    ratio_odd = count_odd / (count_even + count_odd)
    if MAIN_PROCESS:
        print('Get proportion of Odd data with accelerate')
        print(f'Ratio Odd = {ratio_odd}')

    # We save to a file for further processing...
    suffix = '_balanced' if balance else '_unbalanced'
    rank = str(os.environ.get('RANK', 0))
    with open(f'test_{rank}{suffix}.json', 'w') as jsf:
        import json
        json.dump(sorted(batched_data), jsf, indent=4)
    accelerator.wait_for_everyone()

    seen_data = set(batched_data)
    if WORLD_SIZE > 1:
        # Now every one will open the other...
        other_rank = str(int(not int(os.environ.get('RANK', 0))))
        with open(f'test_{other_rank}{suffix}.json', 'r') as jsf:
            import json
            other_data = json.load(jsf)
        # We get unique ids in order to check that we don't have leaks...
        other_data = set(other_data)
        batched_data = set(batched_data)
        unique_in_rank = batched_data.difference(other_data)
        if MAIN_PROCESS:
            print('Verify the unicity of the data on each rank...\n')
        print(f'{len(unique_in_rank)}/{len(batched_data)} data only are not leaking from rank {rank} to rank {other_rank}')
        seen_data = unique_in_rank.union(other_data)

    # Unseen data
    unseen_data = set(dataset.data).difference(seen_data)
    if MAIN_PROCESS:
        print("Unseen Data")
        print(f'{len(unseen_data)}/{len(dataset)} have not been seen...')


if __name__ == '__main__':
    params = vars(parser.parse_args())
    print('----------------------------------------')
    [print(f'{k}: {v}') for k, v in params.items()]
    print('----------------------------------------')
    main(**params)
```
You can try to run this script in different ways:
Single node without "balance"
```
----------------------------------------
dataset_count: 12800
epochs: 20
batch_size: 64
balance: False
----------------------------------------
Running 4000 iterations
Get proportion of Odd data without accelerate
Ratio Odd = 0.5
Running 4000 iterations
Get proportion of Odd data with accelerate
Ratio Odd = 0.5
Unseen Data
0/12800 have not been seen...
```
Multiple nodes (2) without "balance"
```
----------------------------------------
dataset_count: 12800
epochs: 20
batch_size: 64
balance: False
----------------------------------------
Running 4000 iterations
Get proportion of Odd data without accelerate
Ratio Odd = 0.5
Running 2000 iterations
Get proportion of Odd data with accelerate
Ratio Odd = 0.5
Verify the unicity of the data on each rank...
Verify the unicity of the data on each rank...
6400/6400 data only are not leaking from rank 0 to rank 1
6400/6400 data only are not leaking from rank 1 to rank 0
Unseen Data
0/12800 have not been seen...
```
We see that we do not have any leak; all data are seen.
Single node with "balance"
```
----------------------------------------
dataset_count: 12800
epochs: 20
batch_size: 64
balance: True
----------------------------------------
Running 4000 iterations
Get proportion of Odd data without accelerate
Ratio Odd = 0.09179296875
Running 4000 iterations
Get proportion of Odd data with accelerate
Ratio Odd = 0.09139453125
Unseen Data
167/12800 have not been seen...
```
We see that a few samples have not been seen. This is expected, because odd samples are drawn with a very low weight.
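As a rough back-of-envelope check (my own, not part of the original script): with 6400 even samples of weight 1.0 and 6400 odd samples of weight 0.1, the probability that a given odd sample is never drawn over 20 epochs of 12800 draws with replacement is about (1 - 0.1/7040)^256000 ≈ 2.6%, i.e. roughly 169 of the 6400 odd samples, which is consistent with the 167 unseen samples above.

```python
# Back-of-envelope estimate of how many odd samples we expect to never see
# (my own sanity check, not part of the reproduction script).
total_weight = 6400 * 1.0 + 6400 * 0.1      # = 7040
p_one_odd_per_draw = 0.1 / total_weight     # probability of drawing one specific odd index
n_draws = 12800 * 20                        # num_samples per epoch * epochs
p_never_seen = (1 - p_one_odd_per_draw) ** n_draws   # ≈ 0.026
print(6400 * p_never_seen)                  # ≈ 169, close to the 167 unseen samples observed
```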
Multiple nodes (2) with "balance"
```
----------------------------------------
dataset_count: 12800
epochs: 20
batch_size: 64
balance: True
----------------------------------------
Running 4000 iterations
Get proportion of Odd data without accelerate
Ratio Odd = 0.09179296875
Running 2000 iterations
Get proportion of Odd data with accelerate
Ratio Odd = 0.0917890625
Verify the unicity of the data on each rank...
895/11760 data only are not leaking from rank 0 to rank 1
873/11738 data only are not leaking from rank 1 to rank 0
Unseen Data
167/12800 have not been seen...
```
We see that data are leaking from one rank to the other, as if there were an issue with the distributed sampler. How can this be fixed?
Expected behavior
I would like the weighted sampler to be used, and I would like nothing to leak from one rank to the other, just as in the case without the weighted sampler.
Do you have any idea how to get this result?
Thanks!
PyTorch currently doesn't support this:
https://github.com/pytorch/pytorch/issues/77154
https://github.com/pytorch/pytorch/issues/23430
So at this time we don't plan on implementing this, until they have support underneath
(Labeling as enhancement and feature request so this can stay open)
Thanks for the answer.
I saw in the references you sent me that we can use a proxy sampler that can do this:
```python
from torch.utils.data.distributed import DistributedSampler


class DistributedProxySampler(DistributedSampler):
    """Sampler that restricts data loading to a subset of input sampler indices.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.

    .. note::
        Input sampler is assumed to be of constant size.

    Arguments:
        sampler: Input data sampler.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
    """
```
I think I can easily do something like that: "eat" my original sampler and keep only the weights/data that belong to a subset of the dataset (see the sketch below).
The problem, if I do this, is that I can no longer use accelerator.prepare() on the dataloader, because it would divide the number of iterations by num_processes a second time.
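Roughly, I imagine something like the sketch below (my own illustration, not the implementation from the linked issue; the class name is made up). Note that its `__len__` is `len(sampler) // num_replicas`, which is exactly why accelerator.prepare() would then shard the iterations a second time.

```python
import torch.distributed as dist
from torch.utils.data import Sampler


class SamplerShardProxy(Sampler):
    """Sketch: restrict an arbitrary sampler to the indices assigned to the current rank."""

    def __init__(self, sampler, num_replicas=None, rank=None):
        self.sampler = sampler
        self.num_replicas = num_replicas or (dist.get_world_size() if dist.is_initialized() else 1)
        self.rank = rank if rank is not None else (dist.get_rank() if dist.is_initialized() else 0)

    def __iter__(self):
        # Materialize the wrapped sampler once per epoch, then keep every num_replicas-th index.
        indices = list(self.sampler)
        return iter(indices[self.rank::self.num_replicas])

    def __len__(self):
        # Per-rank length: this is what accelerator.prepare() would divide a second time.
        return len(self.sampler) // self.num_replicas
```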
Let's imagine I only prepare my model/scheduler/optimizers.
I like the ability of accelerator.prepare() to move the data to the right device. Is this operation done on the dataloader side or on the model side?
If it is done directly in the forward pass, I can use a custom wrapper and bypass accelerator.prepare() for my dataloader. Thanks
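To make the idea concrete, here is a minimal sketch of what I have in mind, assuming the DummyDataset from the script above and a toy model (the Linear model and loss are just placeholders): only the model and optimizer go through prepare(), and each batch is moved manually via accelerator.device.

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator(cpu=True)

model = torch.nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
# The dataloader is deliberately NOT passed to prepare().
model, optimizer = accelerator.prepare(model, optimizer)

dataset = DummyDataset(12800)
dataloader = dataset.dataloader(64, balance=True)  # my own (distributed-aware) sampling

for batch in dataloader:
    batch = batch.to(accelerator.device)            # manual device placement
    loss = model(batch.float().unsqueeze(1)).mean()
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
```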
If you do the sampling yourself, you can build an accelerate.dataloader.DispatchDataLoader instead of a DataLoader and pass in everything you would normally, I believe.
Otherwise, I can quickly spin up a DeviceDataLoader that basically will not do any distributed sampling etc., and instead leave that to you. It will simply move the data to the right device.
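For what it's worth, such a wrapper could look roughly like the sketch below; DeviceDataLoader here is a hypothetical class illustrating the concept described above, not an existing accelerate API.

```python
import torch


class DeviceDataLoader:
    """Sketch: wrap any DataLoader and only move batches to a device, leaving sampling to the caller."""

    def __init__(self, dataloader, device):
        self.dataloader = dataloader
        self.device = device

    def __iter__(self):
        for batch in self.dataloader:
            if isinstance(batch, torch.Tensor):
                yield batch.to(self.device)
            elif isinstance(batch, (list, tuple)):
                yield type(batch)(b.to(self.device) if isinstance(b, torch.Tensor) else b for b in batch)
            else:
                yield batch  # dicts / custom structures are left to the caller

    def __len__(self):
        return len(self.dataloader)
```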
It's not completely clear to me how the Dispatch DataLoader works.
I have coded this custom DistributedWeightedRandomSampler, which seems to respect process exclusivity with accelerator.prepare():
```python
import torch
from torch.utils.data import Sampler
from torch import Tensor
import torch.distributed as dist
from collections.abc import Sequence, Iterator


class DistributedWeightedRandomSampler(Sampler[int]):
    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).

    Args:
        weights (sequence): a sequence of weights, not necessary summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.
        generator (Generator): Generator used in sampling.

    Example:
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [4, 4, 1, 4, 5]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
        [0, 1, 4, 3, 2]
    """

    weights: Tensor
    num_samples: int
    replacement: bool

    def __init__(self, weights: Sequence[float], num_samples: int,
                 replacement: bool = True, generator: torch.Generator = None, seed: int = 42) -> None:
        if not isinstance(num_samples, int) or isinstance(num_samples, bool) or \
                num_samples <= 0:
            raise ValueError(f"num_samples should be a positive integer value, but got num_samples={num_samples}")
        if not isinstance(replacement, bool):
            raise ValueError(f"replacement should be a boolean value, but got replacement={replacement}")

        # We generate a random permutation of indices.
        self.indices = torch.randperm(num_samples, generator=generator)
        # We generate the weight tensor.
        weights_tensor = torch.as_tensor(weights, dtype=torch.double)[self.indices]
        if len(weights_tensor.shape) != 1:
            raise ValueError("weights should be a 1d sequence but given "
                             f"weights have shape {tuple(weights_tensor.shape)}")

        self.mask = torch.ones_like(weights_tensor).bool()
        if dist.is_initialized():
            num_processes = dist.get_world_size()
            if num_processes > 1:
                assert generator is not None, "A generator should be set when num_processes > 1"
                # We reset the mask to zero for all processes
                self.mask = torch.zeros_like(weights_tensor)
                # We want the mask to select only indices for the current process
                # => We cut our indices in num_processes parts and we set the mask to 1 where the rank is matching
                rank_indices = [i for i in range(len(self.mask)) if i % num_processes == dist.get_rank()]
                self.mask[rank_indices] = 1
                self.mask = self.mask.bool()
        else:
            num_processes = 1

        # Set parameters...
        self.weights = weights_tensor
        self.num_samples = num_samples
        self.replacement = replacement
        self.generator = generator

    def __iter__(self) -> Iterator[int]:
        # We sample "num_samples" indices from the weights tensor "masked" on current process weights
        rand_tensor = torch.multinomial(self.weights[self.mask], self.num_samples, self.replacement, generator=self.generator)
        # We get the corresponding indices
        rank_indices = self.indices[self.mask]
        rand_indices = rank_indices[rand_tensor]
        rand_indices: torch.Tensor
        # We sample only from these indices.
        yield from iter(rand_indices.tolist())

    def __len__(self) -> int:
        return self.num_samples
```
If I replace the default WeightedRandomSampler with this one, it seems to have the right behavior over 1 epoch, even with accelerator.prepare()!
```
Running 200 iterations undistributed
Get proportion of Odd data without accelerate
Ratio Odd = 0.094453125
Running 100 iterations on 2 ranks
Get proportion of Odd data with accelerate
Ratio Odd = 0.083125
Verify the unicity of the data on each rank...
0.0% data are leaking from rank 0 to rank 1
0.0% data are leaking from rank 1 to rank 0
Unseen Data
6429/12800 have not been seen...
```
If I increase the number of epochs to 20, it also seems to work:
```
----------------------------------------
dataset_count: 12800
epochs: 20
batch_size: 64
balance: True
----------------------------------------
Running 4000 iterations undistributed
Get proportion of Odd data without accelerate
Ratio Odd = 0.090625
Running 2000 iterations on 2 ranks
Get proportion of Odd data with accelerate
Ratio Odd = 0.090046875
Verify the unicity of the data on each rank...
0.0% data are leaking from rank 1 to rank 0
0.0% data are leaking from rank 0 to rank 1
Unseen Data
184/12800 have not been seen...
```
How does it work
I hope (it is not yet tested enough) that it creates subsets of indices on the fly, one per process, and samples from those instead of from the whole dataset. Each set of indices is process dependent, but the length of the sampler remains the full dataset size, which ensures accelerate will not reduce the total number of iterations again.
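For clarity, here is a usage sketch of how it plugs into the reproduction script above (my reconstruction, assuming the process group is already initialized as in that script; the exact wiring may differ slightly from what I actually ran):

```python
import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator

accelerator = Accelerator(cpu=True)

dataset = DummyDataset(12800)
weights = [1.0 if is_even(i) else 0.1 for i in dataset.data]

generator = torch.Generator().manual_seed(42)
sampler = DistributedWeightedRandomSampler(weights, num_samples=len(dataset),
                                           replacement=True, generator=generator)
dataloader = DataLoader(dataset, batch_size=64, sampler=sampler)

# len(sampler) is still the full dataset size, so prepare() only splits the iterations once.
dataloader = accelerator.prepare(dataloader)
```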
Do you see any side effects that I could have missed?