
Cycle option for StreamingDataLoader

Aceticia opened this issue 8 months ago · 8 comments

🚀 Feature

A function or an argument in StreamingDataLoader to cycle the passed-in StreamingDataset.

Motivation

Many training scenarios in CV involve training models for multiple epochs while controlling the exact number of training steps, independent of the underlying dataset size. E.g., given a CombinedStreamingDataset of some length, restart its iteration when it is exhausted.

Pitch

I'm not quite sure how this should be done. Maybe in the __iter__ method of StreamingDataLoader we could catch the final iteration and restart it?
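
For illustration, a minimal sketch of that idea as an external wrapper (the CyclingDataLoader class and its num_steps argument are hypothetical, not litdata API; assumes a non-empty loader):

class CyclingDataLoader:
    # Sketch: restart the wrapped loader whenever it is exhausted,
    # yielding exactly num_steps batches per pass.
    def __init__(self, loader, num_steps):
        self.loader = loader
        self.num_steps = num_steps

    def __iter__(self):
        it = iter(self.loader)
        for _ in range(self.num_steps):
            try:
                yield next(it)
            except StopIteration:
                # Loader exhausted: restart instead of ending the epoch.
                it = iter(self.loader)
                yield next(it)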

Aceticia · Mar 24 '25

You could check PyTorch Lightning's CombinedLoader, which supports cycling: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.utilities.combined_loader.html

Or create your own wrapper that iterates for a given number of steps.
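
The wrapper option can be as simple as a generator plus itertools.islice (a sketch; train_loader stands in for any iterable DataLoader):

import itertools

def cycle(loader):
    # Restart the loader each time it is exhausted.
    while True:
        yield from loader

# Train for an exact number of steps, independent of dataset size.
for batch in itertools.islice(cycle(train_loader), 10_000):
    ...  # one training step per batch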

tchaton · Mar 26 '25

Hi @Aceticia, ParallelStreamingDataset #576 has just been merged into litdata. Feel free to give it a try! You can find more details in the README under the sections Parallel Streaming and Cycle Datasets.
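
For example (a sketch based on those README sections; the dataset paths are placeholders and the exact signature may differ from the released version):

import litdata as ld

dset1 = ld.StreamingDataset("s3://bucket/dataset1")
dset2 = ld.StreamingDataset("s3://bucket/dataset2")

# One "epoch" yields a fixed number of samples, cycling the
# underlying datasets as needed.
parallel = ld.ParallelStreamingDataset([dset1, dset2], length=100)
loader = ld.StreamingDataLoader(parallel)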

bhimrazy · May 26 '25

As mentioned by @lantiga in #576, the ability to cycle datasets using ParallelStreamingDataset is a nice option as it is, but this should probably be upstreamed to StreamingDataset in the future.

Idk if this issue should stay open so we don't forget.

philgzl · May 26 '25

Hi @philgzl, are you suggesting something like:

sd = ld.StreamingDataset("..", cycle=True)

so that when the iterator raises StopIteration, we don't increase the epoch count and just restart iteration?

deependujha · May 26 '25

Mmh, I was thinking of a solution similar to what was implemented in ParallelStreamingDataset:

dset = ld.StreamingDataset("..", length=100)

Iterating over the dataset once then yields 100 samples. If the dataset has fewer than 100 samples, we cycle and shuffle internally. If we iterate over the dataset a second time, we resume from where we left off without re-shuffling and yield 100 samples again.

This way we can disentangle the epoch length (i.e., the number of items yielded per iteration) from the actual number of samples in the dataset. I believe this is what the OP meant.
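
As a generic sketch of those semantics (a hypothetical FixedLengthIterable, not how litdata implements it):

class FixedLengthIterable:
    # Sketch: decouple epoch length from dataset size. Each pass
    # yields exactly `length` samples; the internal iterator persists
    # across passes so a new pass resumes where the last one stopped.
    def __init__(self, dataset, length):
        self.dataset = dataset
        self.length = length
        self._it = None

    def __iter__(self):
        if self._it is None:
            self._it = iter(self.dataset)
        for _ in range(self.length):
            try:
                yield next(self._it)
            except StopIteration:
                # Underlying dataset exhausted mid-pass: cycle
                # (re-shuffling would happen here).
                self._it = iter(self.dataset)
                yield next(self._it)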

philgzl · May 26 '25

I realize now that maybe what you meant with cycle=True is the same as length=float("inf").

philgzl · May 26 '25

Thanks for the clarification.

Similar to ParallelStreamingDataset: pass an int or "inf". Sounds good to me.

deependujha · May 26 '25

Yes, and then I guess this feature should be removed from ParallelStreamingDataset, since we would just pass in StreamingDataset instances that were already configured to cycle.

philgzl · May 26 '25