No special treatment for ShardingRoundRobinDispatch
🚀 The feature
MPRS currently looks specifically for ShardingRoundRobinDispatch to determine the non-replicable part of the graph that gets executed in the main process before passing work to worker processes.
I would like to have the same synchronization behavior, but without the round-robin or sharding aspect: a way to declare a part of the graph as non-replicable and have it executed before any worker processes run. In other words, ShardingRoundRobinDispatch shouldn't get any special treatment, and features should be introduced that allow users to replicate its behavior with their own code if they want to.
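For context, here is a minimal sketch of how this currently looks as I understand it (the map stage is just an illustrative placeholder):

import torchdata.datapipes as dp
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES

# Today, the only way to keep a prefix of the graph in the main process is to end it
# with sharding_round_robin_dispatch, which also forces round-robin distribution of
# its output across the workers.
pipe = dp.iter.IterableWrapper(range(10))                                # main process
pipe = pipe.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
pipe = pipe.map(lambda x: x * 2)                                         # worker processes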
Motivation, pitch
I'm preloading bigger chunks of data, from which the worker processes then produce samples in a sharded way. Currently, however, each worker process loads every chunk separately, since all workers execute the same graph up to the sharding point (which comes later, after the preloading). As a result, memory consumption scales linearly with the number of worker processes.
Instead, I would like each worker process to access the chunks via shared memory (this should be possible with POSIX's fork): the preloading of the chunks would be done once in the main process, and the same chunks would then be passed to all worker processes, without any sharding or round-robin behavior. This is also conceptually much simpler for me to deal with, since the preloading is threaded as well; see the sketch below.
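To illustrate the fork-based sharing this relies on, here is a minimal standalone sketch, independent of torchdata (the numpy arrays, chunk sizes, and worker count are placeholders): arrays preloaded in the parent are inherited copy-on-write by forked workers instead of being reloaded per worker.

import multiprocessing as mp
import numpy as np

CHUNKS = None  # filled in the main process before forking

def worker(worker_id, num_workers):
    # Forked workers see the parent's CHUNKS pages without copying or reloading them.
    samples = [chunk[worker_id::num_workers] for chunk in CHUNKS]
    print(worker_id, sum(int(s.sum()) for s in samples))

if __name__ == '__main__':
    CHUNKS = [np.arange(1_000_000) for _ in range(4)]   # "preloading" once, in the main process
    ctx = mp.get_context('fork')                        # POSIX fork keeps this memory shared
    procs = [ctx.Process(target=worker, args=(i, 2)) for i in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()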
Alternatives
No response
Additional context
No response
For my specific use case, .repeat(n_workers).sharding_round_robin_dispatch() can be used as a workaround for now (I believe).
No, this workaround actually does not work, as this example demonstrates:
# Imports assumed for this snippet (torchdata DataLoader2-style API):
import torchdata.datapipes as dp
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, SequentialReadingService

my_worker_info = None

def abc(x):
    # Runs in the worker processes; tags each element with the worker id.
    return f'Worker {my_worker_info.worker_id}: {x}'

def worker_init(dp, worker_info):
    global my_worker_info
    my_worker_info = worker_info
    return dp

pipe = dp.iter.IterableWrapper(range(10))
pipe = pipe.repeat(2).sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
pipe = pipe.map(abc)
pipe = pipe.sharding_filter(SHARDING_PRIORITIES.MULTIPROCESSING)

rs = SequentialReadingService(
    MultiProcessingReadingService(num_workers=2, worker_init_fn=worker_init)
)
dl = DataLoader2(pipe, reading_service=rs)
for x in dl:
    print(x)
Output:
Worker 0: 0
Worker 1: 0
Worker 0: 1
Worker 1: 1
Worker 0: 2
Worker 1: 2
Worker 0: 3
Worker 1: 3
Worker 0: 4
Worker 1: 4
Worker 0: 5
Worker 1: 5
Worker 0: 6
Worker 1: 6
Worker 0: 7
Worker 1: 7
Worker 0: 8
Worker 1: 8
Worker 0: 9
Worker 1: 9
Expected:
Worker 0: 0
Worker 1: 1
Worker 0: 2
Worker 1: 3
Worker 0: 4
Worker 1: 5
Worker 0: 6
Worker 1: 7
Worker 0: 8
Worker 1: 9
I believe this is because apply_sharding is no longer called on the sharding_filter in this setup.
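For comparison, here is a minimal sketch (my assumption about the intended mechanism, not a trace of MPRS internals) of what calling apply_sharding per worker on such a sharding_filter would produce; it yields the interleaving listed under "Expected" above.

import torchdata.datapipes as dp
from torch.utils.data.graph_settings import apply_sharding
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES

for worker_id in range(2):
    # Rebuild the filter per "worker" and shard it the way a reading service would.
    shard = dp.iter.IterableWrapper(range(10)).sharding_filter(SHARDING_PRIORITIES.MULTIPROCESSING)
    apply_sharding(shard, 2, worker_id, SHARDING_PRIORITIES.MULTIPROCESSING)
    print(f'Worker {worker_id}:', list(shard))
# Worker 0: [0, 2, 4, 6, 8]
# Worker 1: [1, 3, 5, 7, 9]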