
How to correctly mix multiple `StreamingDataset` to create new data?

Open philgzl opened this issue 7 months ago • 5 comments

I would like to iterate over two StreamingDataset instances infinitely and yield new data derived from the output of both datasets. I am able to achieve this by wrapping two StreamingDataset instances in an IterableDataset as follows:

from collections.abc import Iterator

import litdata as ld
import torch


class MixedDataset(torch.utils.data.IterableDataset):
    def __init__(self, dset_1: ld.StreamingDataset, dset_2: ld.StreamingDataset) -> None:
        self.dset_1 = dset_1
        self.dset_2 = dset_2

    def __iter__(self) -> Iterator[torch.Tensor]:
        dset_iter_1 = iter(self.dset_1)
        dset_iter_2 = iter(self.dset_2)
        while True:
            try:
                x_1 = next(dset_iter_1)
            except StopIteration:
                # restart the exhausted iterator so iteration never stops
                dset_iter_1 = iter(self.dset_1)
                x_1 = next(dset_iter_1)
            try:
                x_2 = next(dset_iter_2)
            except StopIteration:
                dset_iter_2 = iter(self.dset_2)
                x_2 = next(dset_iter_2)
            # do_stuff combines one item from each dataset (user-defined)
            yield do_stuff(x_1, x_2)

This works, but I cannot use this new dataset with a StreamingDataLoader since it does not inherit from StreamingDataset. I can use a vanilla torch.utils.data.DataLoader, but am I missing out on key features provided by StreamingDataLoader?
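As an aside, the restart-on-StopIteration pattern above can be factored into a small generic helper. This is a plain-Python sketch, not litdata API; `endless` and `mix` are made-up names:

```python
from collections.abc import Iterable, Iterator
from typing import Callable


def endless(iterable: Iterable) -> Iterator:
    # Re-iterate the underlying iterable forever. Unlike itertools.cycle,
    # which caches items from the first pass, this calls iter() again each
    # pass, so a dataset that reshuffles per epoch keeps reshuffling.
    while True:
        yield from iterable


def mix(ds_1: Iterable, ds_2: Iterable, fn: Callable) -> Iterator:
    # Pair one item from each endlessly restarted iterable and combine
    # them with fn; equivalent to the __iter__ body above.
    return map(fn, endless(ds_1), endless(ds_2))
```

With this, `MixedDataset.__iter__` reduces to `return mix(self.dset_1, self.dset_2, do_stuff)`.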

The CombinedStreamingDataset does not fit my use case as it yields items from one or the other dataset, from what I understand.

philgzl avatar Apr 13 '25 20:04 philgzl

Hi @philgzl, did you try with class MixedDataset(StreamingDataset)?

StreamingDataset also inherits from IterableDataset.

deependujha avatar Apr 14 '25 13:04 deependujha

Hey @philgzl, yes, you should use the StreamingDataLoader, as it properly forwards the batch size and num workers down to the dataset.

You can make a PR to convert this one: https://github.com/Lightning-AI/litData/blob/68d23cd8e1d9b4121227ddbe047b8bff27d8d929/src/litdata/streaming/dataloader.py#L584 into a warning and implement the required methods, such as set_epoch, etc.

The best approach would be to use a protocol.
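For illustration, here is a minimal sketch of what such a protocol and a forwarding wrapper might look like. The names and the exact set of hooks are hypothetical; the real set of methods StreamingDataLoader relies on should be taken from the linked code:

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class StatefulDataset(Protocol):
    # Hypothetical protocol: the hooks a StreamingDataLoader-style loader
    # could check for, instead of isinstance(dataset, StreamingDataset).
    def set_epoch(self, epoch: int) -> None: ...
    def state_dict(self) -> dict: ...
    def load_state_dict(self, state: dict) -> None: ...


class MixedDataset:
    # Satisfies the protocol by forwarding each hook to every wrapped
    # dataset; state dicts are keyed by dataset index.
    def __init__(self, *datasets: Any) -> None:
        self._datasets = list(datasets)

    def set_epoch(self, epoch: int) -> None:
        for d in self._datasets:
            d.set_epoch(epoch)

    def state_dict(self) -> dict:
        return {i: d.state_dict() for i, d in enumerate(self._datasets)}

    def load_state_dict(self, state: dict) -> None:
        for i, d in enumerate(self._datasets):
            d.load_state_dict(state[i])
```

A loader could then accept any `StatefulDataset` via `isinstance(dataset, StatefulDataset)` (the protocol is `runtime_checkable`), rather than requiring a concrete base class.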

tchaton avatar Apr 15 '25 09:04 tchaton

@deependujha inheriting from StreamingDataset does indeed suppress the type check error in StreamingDataLoader.__init__, but as pointed out by @tchaton it does not make StreamingDataLoader provide all of its features to the two wrapped StreamingDataset instances. These include:

  • Forwarding down attributes such as num_workers, batch_size or shuffle
  • State dict saving and loading
  • Other stuff?

Would it make sense to implement a new dataset class ParallelStreamingDataset which takes as input N instances of StreamingDataset and yields tuples (x_1, x_2, ..., x_N)? This new class would take care of forwarding attributes and saving/loading state dicts, similarly to CombinedStreamingDataset.
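The iteration logic of such a class could look roughly like this (a plain-Python sketch of the proposed semantics, not litdata code; each exhausted dataset restarts independently, so shorter datasets cycle while longer ones continue):

```python
from collections.abc import Iterable, Iterator


def parallel_iter(*datasets: Iterable) -> Iterator[tuple]:
    # Yield tuples (x_1, ..., x_N), restarting each exhausted dataset
    # independently so iteration continues forever.
    iters = [iter(d) for d in datasets]
    while True:
        item = []
        for i, it in enumerate(iters):
            try:
                item.append(next(it))
            except StopIteration:
                iters[i] = iter(datasets[i])
                item.append(next(iters[i]))
        yield tuple(item)
```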

philgzl avatar Apr 15 '25 11:04 philgzl

Hey @philgzl Feel free to make a PR to add support for it.

BTW, what is your use case?

tchaton avatar Apr 15 '25 21:04 tchaton

Will give it a try.

Use case is dynamic acoustic scene simulation for speech enhancement or source separation. I have datasets of clean speech files, datasets of noise files, and datasets of room impulse responses. I would like to avoid generating a fixed number of scenes and updating them every time I change an acoustic parameter. Creating them on-the-fly is more flexible and allows for a virtually infinite amount of training data.

philgzl avatar Apr 15 '25 22:04 philgzl