Make DistributedSampler stateful

Open • gokulavasan opened this issue 1 year ago • 1 comment

🚀 The feature

Currently, RandomSampler and BatchSampler are patched (https://github.com/pytorch/data/blob/main/torchdata/stateful_dataloader/sampler.py#L134-L135) to make them stateful so that they work out of the box with StatefulDataLoader.

It would be useful to consider having DistributedSampler (https://github.com/pytorch/pytorch/blob/2176ef7dfaf02dd6dbb8484a50c99d5fadf3ea0b/torch/utils/data/distributed.py#L13) also implement the stateful methods, and patching it in torchdata as well.

Motivation, pitch

So that users can also use DistributedSampler out of the box with checkpointing.

Alternatives

Users would have to implement the stateful interface for DistributedSampler themselves by extending it.

Additional context

No response

gokulavasan, Jun 10 '24 22:06
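A minimal sketch of the "extend it yourself" alternative, assuming the only state needed is the epoch and the number of indices this rank has already handed out. The class name `StatefulDistributedSampler` and the `epoch`/`yielded` state keys are hypothetical and are not torchdata's API; only the `state_dict()`/`load_state_dict()` method names follow the convention of torchdata's Stateful protocol. The sketch relies on DistributedSampler's per-rank order being a deterministic function of `seed`, `epoch`, `num_replicas` and `rank`, so a resume can slice the recomputed index list instead of replaying it. It counts indices as the sampler hands them out; if a loader prefetches ahead of training, that count can run ahead of what was actually consumed, and this sketch does not handle that.

```python
from typing import Any, Dict, Iterator, Optional

from torch.utils.data.distributed import DistributedSampler


class StatefulDistributedSampler(DistributedSampler):
    """Hypothetical DistributedSampler subclass with state_dict()/load_state_dict().

    The per-rank order depends only on (seed, epoch, num_replicas, rank,
    dataset length, shuffle, drop_last), so a checkpoint only needs the
    epoch and how many indices this rank has already handed out.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.yielded = 0                         # indices handed out in the current epoch
        self._resume_from: Optional[int] = None  # offset restored by load_state_dict()

    def __iter__(self) -> Iterator[int]:
        # Recompute this epoch's index order with the parent's deterministic logic.
        indices = list(super().__iter__())
        start = self._resume_from if self._resume_from is not None else 0
        self._resume_from = None
        self.yielded = start
        # Jump straight to the resume offset instead of calling next() `start` times.
        for idx in indices[start:]:
            self.yielded += 1
            yield idx

    def state_dict(self) -> Dict[str, int]:
        return {"epoch": self.epoch, "yielded": self.yielded}

    def load_state_dict(self, state: Dict[str, int]) -> None:
        self.epoch = state["epoch"]
        self.yielded = state["yielded"]
        self._resume_from = state["yielded"]


if __name__ == "__main__":
    data = list(range(10))  # a plain list stands in for a Dataset here
    # Explicit num_replicas/rank avoid needing an initialized process group.
    sampler = StatefulDistributedSampler(data, num_replicas=2, rank=0, shuffle=True, seed=0)
    sampler.set_epoch(3)
    full_epoch = list(sampler)

    sampler.set_epoch(3)
    it = iter(sampler)
    seen = [next(it), next(it)]
    checkpoint = sampler.state_dict()  # {"epoch": 3, "yielded": 2}

    resumed = StatefulDistributedSampler(data, num_replicas=2, rank=0, shuffle=True, seed=0)
    resumed.load_state_dict(checkpoint)
    assert seen + list(resumed) == full_epoch
```

The resumed sampler must be built with the same dataset length, `num_replicas`, `rank`, `seed`, `shuffle` and `drop_last`, and should not be moved to a different epoch with `set_epoch()` before it is iterated, otherwise the saved offset would be applied to a different order.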

This currently isn't broken, right? I.e., fast-forwarding the sampler will work, but may be inefficient. I'm OK with this landing either before or after the release branch cut.

andrewkho, Jun 12 '24 17:06

Hi @gokulavasan @andrewkho,

I found that the current StatefulDataLoader works well with DistributedSampler without any modifications.

Would you mind explaining why it might be inefficient?

Thanks in advance.

ShoufaChen, Jul 07 '24 04:07

Hi @andrewkho, does fast-forwarding here mean that the sampler would iterate from the beginning up to the checkpointed position? If so, is that why it is inefficient? A more efficient approach would be to jump directly to the checkpointed position, right?

Please correct me if my understanding is wrong. Thank you.

ShoufaChen, Jul 08 '24 06:07
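To make the question concrete: the two resume strategies below produce the same remaining indices for a given epoch and differ only in how the first `k` are skipped. `resume_naive` replays the sampler with `next()`; `resume_direct` recomputes the epoch's order and slices it. Both helper names are hypothetical, a plain list stands in for a real dataset, and `num_replicas`/`rank` are passed explicitly so no process group is needed.

```python
from typing import List

from torch.utils.data.distributed import DistributedSampler


def resume_naive(sampler: DistributedSampler, k: int) -> List[int]:
    """Iterate from the head: replay the epoch and discard the first k indices."""
    it = iter(sampler)
    for _ in range(k):  # one Python-level next() per skipped index
        next(it)
    return list(it)


def resume_direct(sampler: DistributedSampler, k: int) -> List[int]:
    """Jump directly: recompute this epoch's order and slice off the first k indices.

    list(sampler) still materializes the whole order, but it does so without a
    separate Python-level next() call for every skipped index.
    """
    return list(sampler)[k:]


# A plain list stands in for a Dataset; explicit num_replicas/rank mean no
# process group has to be initialized for this demo.
sampler = DistributedSampler(list(range(100)), num_replicas=4, rank=1, shuffle=True, seed=0)
sampler.set_epoch(2)
assert resume_naive(sampler, 10) == resume_direct(sampler, 10)
```

Both approaches still have to regenerate the epoch's permutation; the saving from jumping directly comes from not paying per-index overhead for the skipped part.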

Hi @ShoufaChen, you're correct: it should work without modifications but may be slow for large datasets. Here is where we've done the conversion for RandomSampler and BatchSampler, as examples: https://github.com/pytorch/data/blob/main/torchdata/stateful_dataloader/sampler.py#L47

For example, you can see the default batch sampler calling next() to naively fast-forward the sampler.

Here's an example showing that the time required to fast-forward increases with the number of samples to iterate through; at very large scales (e.g., billions of samples) it slows down to the order of minutes: https://colab.research.google.com/drive/1UlJAMqzaCjtbW4RPaaoHxGd9sjiKFk7O?usp=sharing

andrewkho, Jul 08 '24 16:07
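A rough, self-contained way to see the trend described above without the Colab notebook: time a loop that skips `k` indices with one `next()` call each, for increasing `k`. A plain `range` iterator stands in for the sampler (an assumption), so absolute timings will likely differ from a real sampler behind a DataLoader; the point is only that the naive fast-forward cost grows roughly linearly with `k`. The function names are hypothetical.

```python
import time
from typing import Dict, Iterable


def naive_skip_seconds(k: int) -> float:
    """Time skipping k indices by calling next() once per index (naive fast-forward)."""
    it = iter(range(k + 1))  # stand-in for a sampler iterator with at least k + 1 items
    start = time.perf_counter()
    for _ in range(k):
        next(it)
    elapsed = time.perf_counter() - start
    assert next(it) == k  # the iterator is now positioned at index k
    return elapsed


def measure(sizes: Iterable[int]) -> Dict[int, float]:
    """Map each skip count to the wall-clock seconds the naive loop took."""
    return {k: naive_skip_seconds(k) for k in sizes}


if __name__ == "__main__":
    for k, secs in measure([10_000, 100_000, 1_000_000, 10_000_000]).items():
        print(f"skip {k:>12,} indices: {secs:.3f}s")
```

Extrapolating the per-index cost from these runs to billions of indices gives a rough feel for why the naive approach becomes noticeable at that scale.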