What does it mean for a DataPipe to be 'replicable'?
📚 The doc issue
The ReadingService docs describe the different sharding options and note that one applies to replicable DataPipes and the other to non-replicable DataPipes, but it's not really explained what that means.
Indirectly related, I'm also confused by the names `ShardingRoundRobinDispatcher` and `ShardingFilter`. The docs for `ShardingFilter` say:

> each instance of the DataPipe (on different workers) will have every n-th element of the original DataPipe, where n equals to the number of instances.

Is that not essentially the definition of round-robin distribution? How is that different from what the DataPipes downstream of a `ShardingRoundRobinDispatcher` on different workers receive?
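To make the question concrete, here's a plain-Python sketch (illustrating only the semantics quoted from the docs, not torchdata's internals) showing that "every n-th element" sharding and round-robin dealing produce the same partition of the data across workers:

```python
# Plain-Python illustration, not torchdata code: compare "every n-th
# element" sharding with dealing elements out round-robin.
data = list(range(10))
n = 2  # number of worker instances

# ShardingFilter-style: instance i keeps elements at indices i, i+n, i+2n, ...
filter_shards = [data[i::n] for i in range(n)]

# Round-robin dealing: hand each element to the workers in turn.
rr_shards = [[] for _ in range(n)]
for idx, x in enumerate(data):
    rr_shards[idx % n].append(x)

print(filter_shards)  # [[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]]
print(rr_shards)      # identical partition
```

So in terms of which elements each worker ends up with, the two schemes look the same; the question is whether they differ in *where* the upstream work happens.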
Suggest a potential alternative/fix
Clarify the difference between `ShardingRoundRobinDispatcher` and `ShardingFilter`, and explain what 'replicable' means in that context. If the answers to my questions above are 'yes', possibly consider renaming `ShardingRoundRobinDispatcher` and `ShardingFilter` to something more meaningful.
`replicable` means the `DataPipe` can be copied multiple times for multiprocessing workers. If it's not replicable, it will either be kept in a dispatching process when `ShardingRoundRobinDispatcher` is used, or kept in the main process at the end, connected to all replicated `DataPipe`s from each worker process.
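As I understand that answer, the practical difference is how many times the upstream pipe runs. Here is a conceptual sketch in plain Python (made-up class and variable names, not torchdata's actual mechanism): a replicable pipe is copied and iterated once per worker, while a non-replicable pipe is iterated exactly once in a single dispatching pass whose output is dealt to the workers:

```python
# Conceptual sketch only -- TinyPipe is a stand-in, not a torchdata class.
import copy

class TinyPipe:
    def __init__(self, data):
        self.data = list(data)
        self.times_iterated = 0  # track how often the pipe is consumed
    def __iter__(self):
        self.times_iterated += 1
        return iter(self.data)

num_workers = 2
source = TinyPipe(range(6))

# Replicable: each worker gets its own copy of the whole pipe, then a
# ShardingFilter-style filter drops elements so none are duplicated.
replicas = [copy.deepcopy(source) for _ in range(num_workers)]
replicated = [[x for i, x in enumerate(r) if i % num_workers == w]
              for w, r in enumerate(replicas)]

# Non-replicable: one pass over the single instance in a "dispatching"
# role, handing elements to workers in turn (round-robin dispatch style).
dispatched = [[] for _ in range(num_workers)]
for i, x in enumerate(source):
    dispatched[i % num_workers].append(x)

print(replicated)                # [[0, 2, 4], [1, 3, 5]]
print(dispatched)                # same elements per worker...
print(source.times_iterated)     # ...but the source ran only once here,
                                 # vs. once per replica in the first scheme
```

The per-worker data is identical either way; what differs is that the replicable scheme runs the upstream pipe `num_workers` times (once per copy), which is only safe if the pipe can be duplicated.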
I agree the docs for `ShardingFilter` and `ShardingRoundRobinDispatcher` are confusing.
I created code (see further below for the code and example output) to test my understanding of `ShardingRoundRobinDispatcher`. Based on the `ShardingRoundRobinDispatcher` docs, I expected that `dp.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING).map(increment)` would result in `increment` running on each worker process, but that `dp.map(increment).sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)` would result in `increment` running on a single dispatching process totally separate from the `DataLoader` worker processes. But as you can see in the output below, `increment` is still being called across multiple processes.
Code:

```python
import os

from torchdata.datapipes.iter import IterableWrapper
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES


def increment(x):
    print(f"processs ID {os.getpid()} called increment for {x+1}")
    return x + 1


def create_datapipe_round_robin_before_increment(i):
    dp = IterableWrapper(range(i))
    dp = dp.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING).map(increment)
    return dp


def create_datapipe_round_robin_after_increment(i):
    dp = IterableWrapper(range(i))
    dp = dp.map(increment).sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
    return dp


if __name__ == "__main__":
    print(f"parent process ID: {os.getpid()}")
    N = 5

    print("sharding_round_robin_dispatch BEFORE map(increment):")
    dp1 = create_datapipe_round_robin_before_increment(N)
    for data in DataLoader(dp1, num_workers=2):
        print(int(data))

    print()
    print("sharding_round_robin_dispatch AFTER map(increment):")
    dp2 = create_datapipe_round_robin_after_increment(N)
    for data in DataLoader(dp2, num_workers=2):
        print(int(data))
```
Output with torchdata 0.6.1 on macOS Ventura 13.3.1 (Intel):

```
$ python shard_round_robin_test.py
parent process ID: 13495
sharding_round_robin_dispatch BEFORE map(increment):
processs ID 13502 called increment for 1
processs ID 13503 called increment for 1
processs ID 13503 called increment for 2
processs ID 13502 called increment for 2
1
processs ID 13502 called increment for 3
1
processs ID 13503 called increment for 3
2
processs ID 13502 called increment for 4
2
processs ID 13503 called increment for 4
3
processs ID 13502 called increment for 5
3
processs ID 13503 called increment for 5
4
4
5
5

sharding_round_robin_dispatch AFTER map(increment):
processs ID 13510 called increment for 1
processs ID 13511 called increment for 1
processs ID 13510 called increment for 2
processs ID 13511 called increment for 2
1
processs ID 13510 called increment for 3
1
processs ID 13511 called increment for 3
2
processs ID 13510 called increment for 4
2
processs ID 13511 called increment for 4
3
processs ID 13510 called increment for 5
3
processs ID 13511 called increment for 5
4
4
5
5
```
What's the expected behavior here?
@JohnHBrock I think you need to be using torchdata's `DataLoader2`, not `DataLoader`.

@lendle You're right, I just tested and it works with `DataLoader2`. Here's the `DataLoader2` version of the above code for comparison:
```python
import os

from torchdata.datapipes.iter import IterableWrapper
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES


def increment(x):
    print(f"processs ID {os.getpid()} called increment for {x+1}")
    return x + 1


def create_datapipe_round_robin_before_increment(i):
    dp = IterableWrapper(range(i))
    dp = dp.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING).map(increment)
    return dp


def create_datapipe_round_robin_after_increment(i):
    dp = IterableWrapper(range(i))
    dp = dp.map(increment).sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
    return dp


if __name__ == "__main__":
    print(f"parent process ID: {os.getpid()}")
    N = 5

    print("sharding_round_robin_dispatch BEFORE map(increment):")
    dp1 = create_datapipe_round_robin_before_increment(N)
    mp_reading_service = MultiProcessingReadingService(num_workers=2)
    for data in DataLoader2(dp1, reading_service=mp_reading_service):
        print(int(data))

    print()
    print("sharding_round_robin_dispatch AFTER map(increment):")
    dp2 = create_datapipe_round_robin_after_increment(N)
    mp_reading_service = MultiProcessingReadingService(num_workers=2)
    for data in DataLoader2(dp2, reading_service=mp_reading_service):
        print(int(data))
```
and here's the output I see:

```
parent process ID: 88637
sharding_round_robin_dispatch BEFORE map(increment):
processs ID 88646 called increment for 2
processs ID 88645 called increment for 1
processs ID 88646 called increment for 4
1
processs ID 88645 called increment for 3
2
processs ID 88645 called increment for 5
3
4
5

sharding_round_robin_dispatch AFTER map(increment):
processs ID 88650 called increment for 1
processs ID 88650 called increment for 2
processs ID 88650 called increment for 3
processs ID 88650 called increment for 4
1
2
processs ID 88650 called increment for 5
3
4
5
```
I had initially posted that this didn't work with `DataLoader2` either, but I realized there was a bug in my code.