How to correctly mix multiple `StreamingDataset` to create new data?
I would like to iterate over two StreamingDataset instances infinitely and yield new data derived from the output of both. I can achieve this by wrapping them in an IterableDataset as follows:
```python
from collections.abc import Iterator

import litdata as ld
import torch
from torch.utils.data import IterableDataset


class MixedDataset(IterableDataset):
    def __init__(self, dset_1: ld.StreamingDataset, dset_2: ld.StreamingDataset) -> None:
        self.dset_1 = dset_1
        self.dset_2 = dset_2

    def __iter__(self) -> Iterator[torch.Tensor]:
        dset_iter_1 = iter(self.dset_1)
        dset_iter_2 = iter(self.dset_2)
        while True:
            try:
                x_1 = next(dset_iter_1)
            except StopIteration:
                # Restart the first dataset once it is exhausted.
                dset_iter_1 = iter(self.dset_1)
                x_1 = next(dset_iter_1)
            try:
                x_2 = next(dset_iter_2)
            except StopIteration:
                # Restart the second dataset once it is exhausted.
                dset_iter_2 = iter(self.dset_2)
                x_2 = next(dset_iter_2)
            # do_stuff combines the two items into a new training example.
            yield do_stuff(x_1, x_2)
```
This works, but I cannot use this new dataset with a StreamingDataLoader since it does not inherit from StreamingDataset. I can use a vanilla torch.utils.data.DataLoader, but am I missing out on key features provided by StreamingDataLoader?
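For context, this is roughly how I consume it at the moment (a minimal sketch; the batch size is arbitrary):

```python
from torch.utils.data import DataLoader

# A plain DataLoader accepts any IterableDataset, but unlike
# StreamingDataLoader it does not forward batch_size/num_workers
# down to the wrapped StreamingDatasets.
loader = DataLoader(MixedDataset(dset_1, dset_2), batch_size=8)
for batch in loader:
    ...  # training step
```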
From what I understand, CombinedStreamingDataset does not fit my use case, as it yields items from one dataset or the other rather than from both at once.
Hi @philgzl, did you try class MixedDataset(StreamingDataset)?
StreamingDataset also inherits from IterableDataset.
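Something like this, for instance (an untested sketch; note that skipping super().__init__() is a workaround, since StreamingDataset.__init__ normally expects an input_dir):

```python
import litdata as ld


class MixedDataset(ld.StreamingDataset):
    # Inherit only to satisfy the isinstance check in
    # StreamingDataLoader.__init__; the wrapper logic stays unchanged.
    def __init__(self, dset_1: ld.StreamingDataset, dset_2: ld.StreamingDataset) -> None:
        # Deliberately not calling super().__init__() here.
        self.dset_1 = dset_1
        self.dset_2 = dset_2

    # __iter__ stays exactly as in the IterableDataset version above.
```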
Hey @philgzl Yes, you should use the StreamingDataLoader, as it properly forwards the batch size and number of workers down to the dataset.
You can make a PR to convert this check: https://github.com/Lightning-AI/litData/blob/68d23cd8e1d9b4121227ddbe047b8bff27d8d929/src/litdata/streaming/dataloader.py#L584 into a warning and implement the required methods, such as set_epoch, etc.
The best approach would be to use a protocol.
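Something along these lines (a rough sketch; the exact set of required methods is open and the names here are only illustrative):

```python
from typing import Any, Iterator, Protocol, runtime_checkable


@runtime_checkable
class StreamingDatasetLike(Protocol):
    """Duck-typed contract StreamingDataLoader could check instead of
    requiring StreamingDataset inheritance."""

    def __iter__(self) -> Iterator[Any]: ...
    def set_epoch(self, current_epoch: int) -> None: ...
```

StreamingDataLoader could then do isinstance(dataset, StreamingDatasetLike) and emit a warning instead of raising.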
@deependujha Inheriting from StreamingDataset does indeed suppress the type-check error in StreamingDataLoader.__init__, but as pointed out by @tchaton it does not make StreamingDataLoader provide all of its functionality to the two wrapped StreamingDataset instances. This includes:
- Forwarding down attributes such as `num_workers`, `batch_size`, or `shuffle`
- State dict saving and loading
- Other stuff?
Would it make sense to implement a new dataset class ParallelStreamingDataset, which takes N StreamingDataset instances as input and yields tuples (x_1, x_2, ..., x_N)? This new class would take care of forwarding attributes and saving/loading state dicts, similarly to CombinedStreamingDataset.
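To make the idea concrete, here is a rough sketch of what I have in mind (class name and API are hypothetical; attribute forwarding and the state_dict/load_state_dict plumbing would mirror CombinedStreamingDataset and are omitted):

```python
from typing import Any, Iterator

import litdata as ld


class ParallelStreamingDataset:
    """Iterate N StreamingDatasets in lockstep, yielding N-tuples and
    restarting each dataset independently when it is exhausted."""

    def __init__(self, datasets: list[ld.StreamingDataset]) -> None:
        self._datasets = datasets

    def __iter__(self) -> Iterator[tuple[Any, ...]]:
        iterators = [iter(d) for d in self._datasets]
        while True:
            items = []
            for i, it in enumerate(iterators):
                try:
                    items.append(next(it))
                except StopIteration:
                    # Cycle the exhausted dataset so iteration is infinite.
                    iterators[i] = iter(self._datasets[i])
                    items.append(next(iterators[i]))
            yield tuple(items)
```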
Hey @philgzl Feel free to make a PR to add support for it.
BTW, what is your use case?
Will give it a try.
Use case is dynamic acoustic scene simulation for speech enhancement or source separation. I have datasets of clean speech files, datasets of noise files, and datasets of room impulse responses. I would like to avoid generating a fixed number of scenes and updating them every time I change an acoustic parameter. Creating them on-the-fly is more flexible and allows for a virtually infinite amount of training data.