Make DistributedSampler stateful
🚀 The feature
Currently, RandomSampler and BatchSampler are patched here https://github.com/pytorch/data/blob/main/torchdata/stateful_dataloader/sampler.py#L134-L135 to make them stateful so that they work out of the box with StatefulDataLoader.
It would be useful to have DistributedSampler (https://github.com/pytorch/pytorch/blob/2176ef7dfaf02dd6dbb8484a50c99d5fadf3ea0b/torch/utils/data/distributed.py#L13) also implement the stateful methods and to patch it in torchdata.
Motivation, pitch
So that users can also use DistributedSampler out of the box with checkpointing capability.
Alternatives
Users would have to implement the stateful interface for DistributedSampler themselves by extending it.
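For reference, the alternative above could be sketched roughly as follows. This is a hypothetical sketch, not the torchdata implementation: the subclass name, the `_yielded` counter, and the exact `state_dict` keys are all assumptions. The idea is to track how many indices the current epoch has handed out, so a restored sampler can resume mid-epoch instead of restarting.

```python
from torch.utils.data.distributed import DistributedSampler

class StatefulDistributedSampler(DistributedSampler):
    """Hypothetical sketch: a DistributedSampler that can checkpoint and
    resume mid-epoch by remembering how many indices it has yielded."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._yielded = 0  # indices handed out in the current epoch

    def __iter__(self):
        # Materialize this rank's indices for the epoch (deterministic given
        # the seed and epoch), then skip any indices already consumed.
        indices = list(super().__iter__())
        for idx in indices[self._yielded:]:
            self._yielded += 1
            yield idx
        self._yielded = 0  # epoch finished; next epoch starts fresh

    def state_dict(self):
        return {"epoch": self.epoch, "yielded": self._yielded}

    def load_state_dict(self, state_dict):
        self.set_epoch(state_dict["epoch"])
        self._yielded = state_dict["yielded"]

# Demo with a single "replica" so no process group is needed:
sampler = StatefulDistributedSampler(
    list(range(10)), num_replicas=1, rank=0, shuffle=False
)
it = iter(sampler)
consumed = [next(it) for _ in range(3)]  # simulate 3 training steps
checkpoint = sampler.state_dict()        # capture state mid-epoch
```

A fresh sampler that loads `checkpoint` would then yield only the remaining indices of that epoch.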
Additional context
No response
This currently isn't broken, right? I.e., fast-forwarding the sampler will work, but may be inefficient. I'm OK either way with doing this before or after the release branch cut.
Hi @gokulavasan @andrewkho ,
I found that the current StatefulDataLoader works well with DistributedSampler without any modifications.
Would you mind explaining why it might be inefficient?
Thanks in advance.
This currently isn't broken, right? I.e., fast-forwarding the sampler will work, but may be inefficient. I'm OK either way with doing this before or after the release branch cut.
Hi @andrewkho, does fast-forwarding here mean that the sampler would iterate from the start up to the checkpointed position? If so, isn't that inefficient? An efficient approach would jump directly to the checkpointed position, right?
Please correct me if my understanding is wrong. Thank you.
Hi @ShoufaChen, you're correct: it should work without modifications, but may be slow for large tables. https://github.com/pytorch/data/blob/main/torchdata/stateful_dataloader/sampler.py#L47 Here is where we've done the conversion for RandomSampler and BatchSampler, as examples.
You can see, for example, the default batch sampler calling next() to naively fast-forward the sampler.
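The naive fast-forward described above amounts to draining the iterator up to the resume point. A minimal standalone sketch (the helper name `naive_fast_forward` is made up for illustration, not a torchdata API):

```python
def naive_fast_forward(sampler, n):
    """Resume iteration at step n by drawing and discarding n samples
    from a fresh iterator. This is O(n) work done before training
    produces a single useful batch."""
    it = iter(sampler)
    for _ in range(n):
        next(it)  # discard an already-consumed sample
    return it      # subsequent next() calls continue from step n

# Example: skip the first 40 of 100 samples, then resume.
it = naive_fast_forward(range(100), 40)
```

This works for any iterable sampler without it having to know anything about checkpointing, which is why it is a reasonable default, but the cost grows with the resume position.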
Here's an example where you can see that increasing the number of samples to iterate through increases the time required to fast-forward; at very large scales (e.g., billions of samples) this slows down to the order of minutes: https://colab.research.google.com/drive/1UlJAMqzaCjtbW4RPaaoHxGd9sjiKFk7O?usp=sharing
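A rough standalone illustration of the same effect (this is a simplified stand-in using a plain Python iterator, not the notebook's code): the time spent fast-forwarding grows linearly with the number of skipped samples.

```python
import time

def time_fast_forward(n):
    """Measure the wall-clock cost of naively skipping n samples."""
    it = iter(range(n))
    start = time.perf_counter()
    for _ in range(n):
        next(it)  # discard one sample
    return time.perf_counter() - start

# Each 10x increase in skipped samples costs roughly 10x the time.
for n in (10_000, 100_000, 1_000_000):
    print(f"skip {n:>9,} samples: {time_fast_forward(n):.4f}s")
```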