
Bug: Inconsistent Behavior in StreamingDataLoader After Loading States (Specific to CombinedStreamingDataset)

Open bhimrazy opened this issue 1 year ago • 0 comments

🐛 Bug

Bug: Inconsistent Behavior in StreamingDataLoader After Loading States (Specific to CombinedStreamingDataset)

Description:
The StreamingDataLoader behaves inconsistently when its own state is reloaded with load_state_dict() while it wraps a CombinedStreamingDataset. Errors are raised when the state is loaded before any iteration and after a complete first epoch, and the dataloader does not reset correctly after a partial first epoch.

This bug is an extension of #316 for CombinedStreamingDataset.

To Reproduce

Create Optimized Dataset
from litdata import optimize


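def random_data(index):
    # Each sample is simply its own index, so the data is easy to check by eye.
    return index

if __name__ == "__main__":
    # Write two identical 50-sample datasets that are combined in the examples below.
    datasets = ["dataset1", "dataset2"]
    for dataset in datasets:
        optimize(fn=random_data, inputs=list(range(50)), output_dir=dataset, num_workers=4, chunk_bytes="64MB")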

Bugs

  1. IndexError raised when loading dataloader state without prior iteration

    from litdata import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset
    
    if __name__ == "__main__":
        dataset1 = StreamingDataset("dataset1")
        dataset2 = StreamingDataset("dataset2")
        datasets = [dataset1, dataset2]
        combined_dataset = CombinedStreamingDataset(datasets=datasets)
        dataloader = StreamingDataLoader(combined_dataset, batch_size=4, num_workers=4)
    
        dataloader.load_state_dict(dataloader.state_dict())
    

    Output

    Traceback (most recent call last):
      File "/Users/bhimrajyadav/litdata/test_combined_dataset.py", line 10, in <module>
        dataloader.load_state_dict(dataloader.state_dict())
                                   ^^^^^^^^^^^^^^^^^^^^^^^
      File "/Users/bhimrajyadav/litdata/venv/lib/python3.12/site-packages/litdata/streaming/dataloader.py", line 668, in state_dict
        num_samples_yieled = [0 for _ in range(len(list(self._num_samples_yielded_combined.values())[0]))]
                                                   ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^
    IndexError: list index out of range                                        
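
The failing expression in the traceback takes the first element of `self._num_samples_yielded_combined.values()`. Before any iteration that mapping appears to be empty, so `list(...)[0]` has nothing to return. The following is a standalone sketch of that failure using plain Python dictionaries, plus one possible guard. `first_entry_zeros`, `initial_counts`, and the `num_datasets` parameter are hypothetical names for illustration, and the mapping layout (worker index to per-dataset counts) is an assumption; this is not litdata's actual code or fix.

```python
def first_entry_zeros(num_samples_yielded_combined):
    """Mirror the expression from the traceback: zeros shaped like the first entry."""
    return [0 for _ in range(len(list(num_samples_yielded_combined.values())[0]))]


def initial_counts(num_samples_yielded_combined, num_datasets):
    """Hypothetical guard: fall back to one zero per dataset when nothing is tracked yet."""
    if not num_samples_yielded_combined:
        return [0] * num_datasets
    return first_entry_zeros(num_samples_yielded_combined)


# Assumption: before any iteration the mapping is empty, which triggers the error.
try:
    first_entry_zeros({})
except IndexError as err:
    print("IndexError:", err)

# With the guard, an untouched loader over two datasets gets a zeroed state.
print(initial_counts({}, num_datasets=2))

# Once entries exist (assumed layout: worker index -> per-dataset counts),
# the guard defers to the original expression.
print(initial_counts({0: [4, 0], 1: [0, 4]}, num_datasets=2))
```

Whether a zeroed state is the right fallback for the real dataloader is an open question; the sketch only shows that the empty mapping has to be handled before it is indexed.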
    
  2. Loading the dataloader state after the first epoch has completed raises a ValueError on the next iteration (previously an IndexError; see the clearer example in issue #363).

    from litdata import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset
    
    if __name__ == "__main__":
        dataset1 = StreamingDataset("dataset1")
        dataset2 = StreamingDataset("dataset2")
        datasets = [dataset1, dataset2]
        combined_dataset = CombinedStreamingDataset(datasets=datasets)
        dataloader = StreamingDataLoader(combined_dataset, batch_size=4, num_workers=4)
    
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx == 0:
                print("\nEpoch", dataloader.current_epoch)
            print(batch.numpy(), end=" ")
    
        dataloader.load_state_dict(dataloader.state_dict())
    
    
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx == 0:
                print("\nEpoch", dataloader.current_epoch)
            print(batch.numpy(), end=" ")
    

    Output

      File "/Users/bhimrajyadav/litdata/venv/lib/python3.12/site-packages/litdata/streaming/combined.py", line 160, in __iter__
        self._iterator = _CombinedDatasetIterator(
                         ^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/Users/bhimrajyadav/litdata/venv/lib/python3.12/site-packages/litdata/streaming/combined.py", line 208, in __init__
        self._dataset_iters = [iter(dataset) for dataset in datasets]
                               ^^^^^^^^^^^^^
      File "/Users/bhimrajyadav/litdata/venv/lib/python3.12/site-packages/litdata/streaming/dataset.py", line 223, in __iter__
        self._validate_state_dict()
      File "/Users/bhimrajyadav/litdata/venv/lib/python3.12/site-packages/litdata/streaming/dataset.py", line 479, in _validate_state_dict
        raise ValueError(
    ValueError: The provided `num_samples_yielded` state is greater than the dataset length. Found `51` instead of `50`.                            
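
The error message suggests a bounds check between a restored `num_samples_yielded` counter and the dataset length. A standalone sketch of such a check follows, assuming the counter is a plain integer. `validate_num_samples_yielded` is a hypothetical helper written for illustration; litdata's real `_validate_state_dict` may check more than this.

```python
def validate_num_samples_yielded(num_samples_yielded: int, dataset_length: int) -> None:
    """Hypothetical stand-in for the bounds check implied by the error message."""
    if num_samples_yielded > dataset_length:
        raise ValueError(
            "The provided `num_samples_yielded` state is greater than the dataset length. "
            f"Found `{num_samples_yielded}` instead of `{dataset_length}`."
        )


# A counter equal to the dataset length (a fully consumed epoch) passes the check.
validate_num_samples_yielded(50, 50)

# A counter one past the end, as in the report, is rejected.
try:
    validate_num_samples_yielded(51, 50)
except ValueError as err:
    print(err)
```

If the check works as assumed here, the reported `51` would mean the restored state counted one more sample than the 50 each dataset holds, which would point at how the state is recorded at the end of the epoch rather than at the datasets themselves.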
    
  3. After loading the dataloader state with a partially completed first epoch, the dataloader does not reset correctly upon completing the epoch.

    • Additional details will be added.

Environment

  • PyTorch Version (e.g., 1.0): 2.4.0
  • OS (e.g., Linux): Mac OS
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.12.4
  • CUDA/cuDNN version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

Posted by bhimrazy, Aug 14 '24 18:08