
Make accessing WorkerInfo from within a DataPipe more convenient

Open sehoffmann opened this issue 2 years ago • 8 comments

🚀 The feature

import torchdata.datapipes as dp
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
from torchdata.dataloader2 import MultiProcessingReadingService, DataLoader2

# Workaround: stash the WorkerInfo handed to worker_init_fn in a module-level global
my_worker_info = None

def abc(x):
    return x * my_worker_info.worker_id

def worker_init(dp, worker_info):
    global my_worker_info
    my_worker_info = worker_info
    return dp


pipe = dp.iter.IterableWrapper(range(10))
# Note: map() runs before sharding_filter(), so each worker maps the full range
# and then keeps only its own shard, hence the alternating output below
pipe = pipe.map(abc)
pipe = pipe.sharding_filter(SHARDING_PRIORITIES.MULTIPROCESSING)

rs = MultiProcessingReadingService(num_workers=2, worker_init_fn=worker_init)
dl = DataLoader2(pipe, reading_service=rs)

for x in dl:
    print(x)

Output:

0
1
0
3
0
5
0
7
0
9

To my knowledge, this is the only way to access the WorkerInfo from within a DataPipe when using DataLoader2. Global state is obviously awkward and becomes a problem for larger codebases that aren't toy examples. It would be good if there were a more convenient way (and one that is also uniform with respect to DataLoader), akin to get_worker_info.

Traversing the graph and calling set_worker_info if available would be a good option for this IMO.
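For reference, the classic DataLoader pattern this alludes to: torch.utils.data.get_worker_info() can be called from anywhere inside a worker process without any user-managed global state. A minimal sketch for comparison (this uses the existing DataLoader API, not torchdata):

import torch
from torch.utils.data import DataLoader, get_worker_info

class WorkerScaledDataset(torch.utils.data.Dataset):
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        info = get_worker_info()  # None when loading happens in the main process
        worker_id = info.id if info is not None else 0
        return idx * worker_id

if __name__ == "__main__":
    for x in DataLoader(WorkerScaledDataset(), batch_size=None, num_workers=2):
        print(x)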

Motivation, pitch

I want to easily access the current WorkerInfo from my datapipe.

Alternatives

No response

Additional context

No response

sehoffmann avatar Mar 15 '23 14:03 sehoffmann

So, I guess you want a DataPipe that behaves differently based on WorkerInfo. I think adding get_worker_info is a good feature request.

However, adding set_worker_info to each DataPipe might be too much, as not all DataPipes would need it and it requires a registry on DataPipe.

ejguan avatar Mar 17 '23 14:03 ejguan

Hey @ejguan

I think I would be fine with either. get_worker_info(), however, would rely on global state(?) and would produce issues when multiple independent datapipes are iterated in parallel (I know, a bit hypothetical, just saying).

> However, adding set_worker_info to each DataPipe might be too much, as not all DataPipes would need it and it requires a registry on DataPipe.

No, I don't think so. It would be in line with how sharding and shuffling work at the moment. I.e. one just needs to do something similar to:

# graph traversal helpers from torch.utils.data
from torch.utils.data.graph import traverse_dps
from torch.utils.data.graph_settings import get_all_graph_pipes

def apply_worker_info(datapipe, worker_info):
    graph = traverse_dps(datapipe)
    all_pipes = get_all_graph_pipes(graph)
    for pipe in all_pipes:
        if hasattr(pipe, 'set_worker_info'):
            pipe.set_worker_info(worker_info)
    return datapipe
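For illustration, a custom DataPipe that opts in could look roughly like this (the class name and the set_worker_info hook are hypothetical, not an existing torchdata API):

from torchdata.datapipes.iter import IterDataPipe

class WorkerAwareMapper(IterDataPipe):
    def __init__(self, source_dp):
        self.source_dp = source_dp
        self.worker_info = None

    def set_worker_info(self, worker_info):
        # called once by apply_worker_info() above, before iteration starts
        self.worker_info = worker_info

    def __iter__(self):
        worker_id = self.worker_info.worker_id if self.worker_info else 0
        for x in self.source_dp:
            yield x * worker_id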

sehoffmann avatar Mar 17 '23 14:03 sehoffmann

> get_worker_info(), however, would rely on global state(?) and would produce issues when multiple independent datapipes are iterated in parallel (I know, a bit hypothetical, just saying).

Since they run in separate subprocesses, it should be fine.

> No, I don't think so. It would be in line with how sharding and shuffling work at the moment. I.e. one just needs to do something similar to:

I see what you mean. So, you want some custom DataPipe to accept it. I have a slight concern about how to provide this information to the DataPipe in the dispatching process or in the main process.

ejguan avatar Mar 17 '23 14:03 ejguan

Isn't the worker information only relevant when using the MPRS, the DistributedReadingService, or both? I don't see how it is technically any different from e.g. sharding information.

Also, one thing to keep in mind with all these interfaces (including sharding and shuffling) is that people also need to be able to set them easily in their own ReadingServices. For instance, I'm rolling my own HorovodReadingService.

On a side note: is there interest in a PR for the HorovodReadingService?

sehoffmann avatar Mar 17 '23 15:03 sehoffmann

> Isn't the worker information only relevant when using the MPRS, the DistributedReadingService, or both? I don't see how it is technically any different from e.g. sharding information.

The dispatching process is tied to MPRS as well. And, as we discussed, there might be a partial DataPipe remaining in the main process when MPRS gets involved. So, in those cases, we have to ask users/developers to handle the case where WorkerInfo is not provided.

> On a side note: is there interest in a PR for the HorovodReadingService?

It would be good if you could share more context, e.g. as an RFC issue.

ejguan avatar Mar 17 '23 15:03 ejguan

@ejguan A specific use case that I would like to handle with this:

pipe = pipe.repeat(N_workers).sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)

Here, I would like to introduce a custom operation that doesn't know the number of workers a priori. A set_worker_info (or global get_worker_info) feature should also take a sharding priority as an argument so that we can specify what kind of worker info we are interested in, i.e. distributed (mpi_rank, mpi_size) vs. multiprocessing (process_rank, process_count).
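To make that concrete, a rough sketch of what such an operation might look like if the proposed hook existed (RepeatPerWorker and set_worker_info are hypothetical names, not existing torchdata APIs):

from torchdata.datapipes.iter import IterDataPipe

class RepeatPerWorker(IterDataPipe):
    def __init__(self, source_dp):
        self.source_dp = source_dp
        self.num_workers = 1  # would be filled in by set_worker_info

    def set_worker_info(self, worker_info):
        # e.g. only react to the MULTIPROCESSING sharding priority here
        self.num_workers = worker_info.num_workers

    def __iter__(self):
        # emit each element once per worker, to be consumed by
        # sharding_round_robin_dispatch downstream
        for x in self.source_dp:
            for _ in range(self.num_workers):
                yield x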

sehoffmann avatar Mar 21 '23 19:03 sehoffmann

Technically speaking, you can add an Adapter object to DataLoader2 to achieve in-place graph modification, because you should be able to know the number of workers and the distributed ranks at initialization time of DataLoader2.

If you want to access the information for a specific MP worker, you probably need a get_worker_info function.
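A minimal sketch of the Adapter approach mentioned above, assuming the Adapter base class from torchdata.dataloader2.adapter and DataLoader2's datapipe_adapter_fn argument; the WorkerConfigAdapter name and the set_worker_config hook are hypothetical:

from torch.utils.data.graph import traverse_dps
from torch.utils.data.graph_settings import get_all_graph_pipes
from torchdata.dataloader2.adapter import Adapter

class WorkerConfigAdapter(Adapter):
    # Pushes statically known values (worker count, distributed rank) into
    # every DataPipe that exposes a set_worker_config hook.
    def __init__(self, num_workers, dist_rank=0, dist_world_size=1):
        self.num_workers = num_workers
        self.dist_rank = dist_rank
        self.dist_world_size = dist_world_size

    def __call__(self, datapipe):
        for pipe in get_all_graph_pipes(traverse_dps(datapipe)):
            if hasattr(pipe, "set_worker_config"):
                pipe.set_worker_config(self.num_workers, self.dist_rank, self.dist_world_size)
        return datapipe

# usage: dl = DataLoader2(pipe, datapipe_adapter_fn=WorkerConfigAdapter(num_workers=2), reading_service=rs)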

ejguan avatar Mar 21 '23 19:03 ejguan

Yes, for now I can work around this. I just wrote this as an example of a real use case and its specific requirements, and thought that it might be helpful for you when deciding on a design.

sehoffmann avatar Mar 21 '23 21:03 sehoffmann