litdata
Cycle option for StreamingDataLoader
🚀 Feature
A function or an argument in StreamingDataLoader to cycle the passed-in StreamingDataset.
Motivation
Many CV training scenarios involve training models for multiple epochs while controlling the exact number of training steps, independent of the underlying dataset size. E.g., given a CombinedStreamingDataset of some length, restart its iteration when it is exhausted.
Pitch
I'm not quite sure how this should be done - maybe in the `__iter__` method of StreamingDataLoader, we could catch the final iteration and restart it?
You could check PyTorch Lightning Cycle Loaders: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.utilities.combined_loader.html
Or create your own wrapper that iterates for a given number of steps.
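A minimal sketch of such a wrapper (a hypothetical helper, not part of litdata or PyTorch Lightning): it restarts the loader whenever it is exhausted and stops after a fixed number of steps.

```python
from itertools import islice


def cycle_steps(loader, num_steps):
    """Yield exactly `num_steps` batches from `loader`, restarting it
    whenever it is exhausted. Hypothetical helper, not a litdata API."""
    def _infinite():
        while True:
            yield from loader  # relies on the loader being re-iterable
    yield from islice(_infinite(), num_steps)
```

Any re-iterable works the same way for illustration: `list(cycle_steps([1, 2, 3], 7))` gives `[1, 2, 3, 1, 2, 3, 1]`. With a StreamingDataLoader, each restart would start a new epoch under the hood.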
Hi @Aceticia, ParallelStreamingDataset #576 has just been merged into litdata. Feel free to give it a try! You can find more details in the README under the sections Parallel Streaming and Cycle Datasets.
As mentioned by @lantiga in #576, the ability to cycle datasets using ParallelStreamingDataset is a nice option as it is, but this should probably be upstreamed to StreamingDataset in the future.
Idk if this issue should stay open so we don't forget.
Hi @philgzl, are you suggesting something like:
sd = ld.StreamingDataset("..", cycle=True)
so that when the iterator raises StopIteration, we don't increase the epoch count and just restart the iteration?
Mmh I was thinking of a similar solution to what was implemented in ParallelStreamingDataset:
dset = ld.StreamingDataset("..", length=100)
Iterating over the dataset once then yields 100 samples. If the dataset has fewer than 100 samples, we cycle and shuffle internally. If we iterate over the dataset a second time, we resume from where we left off without re-shuffling, and yield 100 samples again.
This way we can disentangle the epoch length (as in the number of items yielded by iter) from the actual number of samples in the dataset. I believe this is what OP meant.
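To illustrate the proposed semantics, here is a toy sketch in plain Python (not the litdata implementation; the class name is made up and shuffling is omitted for brevity):

```python
class CycledView:
    """Toy illustration of the proposed `length` semantics: each pass over
    the view yields `length` items, cycling through the underlying samples
    and resuming across passes without resetting. Not a litdata API."""

    def __init__(self, samples, length):
        self.samples = samples
        self.length = length
        self._pos = 0  # cursor into `samples`, persists across epochs

    def __iter__(self):
        for _ in range(self.length):
            yield self.samples[self._pos]
            self._pos = (self._pos + 1) % len(self.samples)
```

E.g., with 3 samples and `length=5`, the first pass yields 5 items and the second pass picks up where the first left off instead of restarting from the beginning.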
I realize now maybe what you meant with cycle=True is the same as length=float("inf").
Thanks for the clarification.
Similar to ParallelStreamingDataset, pass an int or "inf". Sounds good to me.
Yes and then I guess this feature should be removed from ParallelStreamingDataset since we would just pass StreamingDataset instances which were already configured to cycle.