
No special treatment for ShardingRoundRobinDispatch

Open sehoffmann opened this issue 2 years ago • 2 comments

🚀 The feature

MultiProcessingReadingService (MPRS) currently looks specifically for ShardingRoundRobinDispatch to determine the non-replicable part of the graph, i.e. the part that gets executed in the main process before work is passed to the worker processes.

I would like to have the same synchronization behavior, but without the round-robin or sharding aspect. In other words, I would like a way to declare a part of the graph as non-replicable so that it is executed before any worker processes. ShardingRoundRobinDispatch shouldn't get any special treatment; instead, features should be introduced that allow users to replicate its behavior with their own code if they want to.
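
To make the request concrete, here is a purely illustrative sketch of what such an API could look like. The non_replicable() marker does not exist in torchdata, and chunk_sources / load_chunk are placeholders for user code; this only shows the requested split of the graph:

from torchdata import datapipes as dp

# Hypothetical, non-replicable part: executed once in the main process before
# any worker is started, and the *same* elements are handed to every worker
# (no round robin, no sharding).
pipe = dp.iter.IterableWrapper(chunk_sources)   # chunk_sources: placeholder input
pipe = pipe.map(load_chunk)                     # load_chunk: placeholder preloading fn
pipe = pipe.non_replicable()                    # hypothetical marker, not an existing API

# Replicable part: executed in every worker process and sharded as usual.
pipe = pipe.unbatch()
pipe = pipe.sharding_filter()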

Motivation, pitch

I'm preloading bigger chunks of data from which the worker processes then produce samples in a sharded way. Currently, however, each worker process loads each chunk separately, since every worker executes the same graph up to the sharding point (which only comes later, after the preloading). As a result, memory consumption scales linearly with the number of worker processes.

Instead, I would like each worker process to access the chunks via shared memory (this should be possible with POSIX's fork). For this, I want the preloading of the chunks to be done once in the main process, with the same chunk then passed (without sharding or round-robin behavior) to all worker processes. This is also conceptually much simpler for me to deal with, since the preloading is threaded as well.
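
For context, a rough, runnable sketch of the kind of pipeline I have today (load_chunk and the chunk contents are stand-ins for my own code). Because sharding only happens at the very end, every worker executes the preloading map for every chunk:

from torchdata import datapipes as dp
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService

def load_chunk(idx):
    # Stand-in for the expensive preloading of one big chunk.
    return [f"chunk{idx}-sample{i}" for i in range(3)]

pipe = dp.iter.IterableWrapper(range(4))  # one element per chunk
pipe = pipe.map(load_chunk)               # preloading: currently replicated into every worker
pipe = pipe.unbatch()                     # flatten chunks into individual samples
pipe = pipe.sharding_filter()             # sharding only happens here, after the preloading

if __name__ == "__main__":
    dl = DataLoader2(
        pipe,
        reading_service=MultiProcessingReadingService(num_workers=2),
    )
    for sample in dl:
        print(sample)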

Alternatives

No response

Additional context

No response

sehoffmann · Mar 15 '23 15:03

For my specific use case, .repeat(n_workers).sharding_round_robin_dispatch() can be used as a workaround for now (I believe).

sehoffmann · Mar 15 '23 15:03

No, this workaround actually does not work, as this example demonstrates:

from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
from torchdata import datapipes as dp
from torchdata.dataloader2 import (
    DataLoader2,
    MultiProcessingReadingService,
    SequentialReadingService,
)

my_worker_info = None

def abc(x):
    # Label each element with the id of the worker that processed it.
    return f'Worker {my_worker_info.worker_id}: {x}'

def worker_init(datapipe, worker_info):
    # Stash the WorkerInfo in a global so abc() can access the worker id.
    global my_worker_info
    my_worker_info = worker_info
    return datapipe


pipe = dp.iter.IterableWrapper(range(10))
pipe = pipe.repeat(2).sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
pipe = pipe.map(abc)
pipe = pipe.sharding_filter(SHARDING_PRIORITIES.MULTIPROCESSING)

rs = SequentialReadingService(
    MultiProcessingReadingService(num_workers=2, worker_init_fn=worker_init)
)
dl = DataLoader2(pipe, reading_service=rs)

for x in dl:
    print(x)

Output:

Worker 0: 0
Worker 1: 0
Worker 0: 1
Worker 1: 1
Worker 0: 2
Worker 1: 2
Worker 0: 3
Worker 1: 3
Worker 0: 4
Worker 1: 4
Worker 0: 5
Worker 1: 5
Worker 0: 6
Worker 1: 6
Worker 0: 7
Worker 1: 7
Worker 0: 8
Worker 1: 8
Worker 0: 9
Worker 1: 9

Expected:

Worker 0: 0
Worker 1: 1
Worker 0: 2
Worker 1: 3
Worker 0: 4
Worker 1: 5
Worker 0: 6
Worker 1: 7
Worker 0: 8
Worker 1: 9

I believe this is because apply_sharding isn't called anymore.

sehoffmann · Mar 15 '23 16:03