
Warning Message When Using StreamingDataset with DDP

Open · taemincho opened this issue on Jun 14 '24 · 2 comments

🐛 Bug

When using litdata's StreamingDataset to read data directly from AWS S3 with Distributed Data Parallel (DDP), the following warning is displayed:

lib/python3.10/site-packages/lightning/pytorch/utilities/data.py:122 Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.
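The warning is a general caution about `IterableDataset`. Each DataLoader worker runs its own copy of `__iter__`, so unless the dataset splits its stream per worker, every worker replays the same samples and one epoch would yield `num_workers × __len__` items instead of `__len__`. The plain-Python simulation below sketches that arithmetic under two assumptions: worker ids are passed in explicitly (in a real `torch.utils.data.IterableDataset` they would come from `torch.utils.data.get_worker_info()` inside `__iter__`), and the function names are hypothetical, not part of litdata or Lightning.

```python
from typing import Callable, Iterator, List

# Signature shared by both simulated strategies: (num_samples, worker_id, num_workers)
StreamFn = Callable[[int, int, int], Iterator[int]]


def naive_worker_stream(num_samples: int, worker_id: int, num_workers: int) -> Iterator[int]:
    """Ignores its worker id, so every worker replays the whole dataset."""
    yield from range(num_samples)


def sharded_worker_stream(num_samples: int, worker_id: int, num_workers: int) -> Iterator[int]:
    """Takes every num_workers-th sample starting at worker_id, so workers never overlap."""
    yield from range(worker_id, num_samples, num_workers)


def collect(stream_fn: StreamFn, num_samples: int, num_workers: int) -> List[int]:
    """Gather everything the simulated workers emit in one epoch (interleaving order is not modeled)."""
    out: List[int] = []
    for worker_id in range(num_workers):
        out.extend(stream_fn(num_samples, worker_id, num_workers))
    return out


if __name__ == "__main__":
    num_samples, num_workers = 10, 4  # illustrative sizes; __len__ would report 10
    print(len(collect(naive_worker_stream, num_samples, num_workers)))    # 40
    print(len(collect(sharded_worker_stream, num_samples, num_workers)))  # 10
```

Striding by worker id is only one way to split the stream; contiguous chunks per worker work as well. As its wording suggests, the warning flags the possibility of duplicates rather than reporting duplicates it has detected.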

To Reproduce

Steps to reproduce the behavior:

  1. Create a litdata.StreamingDataset.
  2. Create a DataLoader using litdata.StreamingDataLoader or torch.utils.data.DataLoader.
  3. Set batch_size > 1.
  4. Train using DDP.

Code sample

Datamodule

import lightning.pytorch as pl
from litdata import StreamingDataset, StreamingDataLoader

def collate_fn(samples):
    # Placeholder for project-specific batch modifications
    return samples

class MyDataModule(pl.LightningDataModule):
    def __init__(self, data_path, **kwargs):
        super().__init__()
        self.data_path = data_path

    def setup(self, stage):
        # Stream samples directly from S3; only S3 paths are handled in this example
        if "s3://" in self.data_path:
            self.dataset = StreamingDataset(self.data_path, shuffle=True)

    def train_dataloader(self):
        # num_workers > 1 is the multi-process setting the warning refers to
        return StreamingDataLoader(
            self.dataset,
            batch_size=16,
            shuffle=True,
            num_workers=4,
            collate_fn=collate_fn,
            drop_last=True,
        )

Training

datamodule = MyDataModule("s3://my_bucket")

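# With two visible GPUs, the default strategy="auto" launches DDP
trainer = pl.Trainer(
    logger=False,
    max_epochs=100000,
    precision="16-mixed",
)

# `model` is a LightningModule defined elsewhere (not shown in this report)
trainer.fit(model, datamodule=datamodule, ckpt_path="last")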

Expected behavior

No warning message should be displayed during training.

Environment

  • PyTorch Version (e.g., 1.0): 2.3
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.10.12
  • CUDA/cuDNN version: 12.4
  • GPU models and configuration: two NVIDIA RTX 4090 GPUs
  • Any other relevant information: Lightning 2.3.0

Additional context

taemincho · Jun 14 '24 15:06

Hi! Thanks for your contribution, great first issue!

github-actions[bot] · Jun 14 '24 15:06

Yes, this warning is expected. All good @taemincho

tchaton · Jun 14 '24 15:06
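If the expected warning is only noise in the logs, one option (not something suggested in this thread) is to filter it with Python's standard `warnings` module before `trainer.fit(...)` runs. This is a minimal sketch that assumes the message text stays as quoted in the report; the helper name `silence_iterable_len_warning` is hypothetical.

```python
import warnings

# Start of the Lightning message quoted in this issue. `filterwarnings` treats
# `message` as a regex matched case-insensitively against the start of the text.
_ITERABLE_LEN_WARNING = r"Your `IterableDataset` has `__len__` defined"


def silence_iterable_len_warning() -> None:
    """Hide only this specific warning; every other warning still surfaces."""
    warnings.filterwarnings("ignore", message=_ITERABLE_LEN_WARNING)
```

Matching on the message rather than on a warning category keeps unrelated warnings visible. Call the helper near the top of the training script, before the Trainer is created.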