
Fix dataset generation: deterministic per-index seeding and collate-compatible image format

dipeshbabu opened this pull request 8 months ago • 3 comments

What does this PR do?

Issues Addressed

  1. Fixes #573: Two critical bugs in the Getting Started example:
    • Duplicate data when using optimize with num_workers > 1 due to unseeded randomness.
    • default_collate error caused by returning PIL Images incompatible with PyTorch batching.

Root Causes

  • Duplicate Data: Workers shared the global numpy random state, so forked worker processes produced identical random values (a sketch of the problematic pattern follows this list).
  • Collate Error: PIL Images cannot be batched by PyTorch’s default_collate.
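
For illustration, the problematic pattern looks roughly like this (the function name, image shape, and values are assumptions modeled on the Getting Started example, not the exact code):

import numpy as np
from PIL import Image

def random_images(index):
    # Every worker draws from the unseeded global RNG, so forked workers can
    # emit identical "random" samples; the PIL Image also breaks default_collate.
    data = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
    return {"index": index, "image": Image.fromarray(data), "class": np.random.randint(10)}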

Changes

  1. Deterministic Data Generation:
    • Seed numpy’s RNG uniquely per index using np.random.default_rng(seed=index).
    • Replace np.random.randint with the seeded generator’s rng.integers(...).
  2. Collate Compatibility:
    • Return images as numpy arrays instead of PIL Images (both changes are sketched after this list).
  3. Documentation:
    • Update the Getting Started example to reflect both fixes.
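
A minimal sketch of the fixed generator, under the same assumptions as above (function name and shape are illustrative):

import numpy as np

def random_images(index):
    rng = np.random.default_rng(seed=index)  # deterministic and unique per index
    image = rng.integers(0, 256, (32, 32, 3), dtype=np.uint8)  # numpy array, not a PIL Image
    label = int(rng.integers(0, 10))
    return {"index": index, "image": image, "class": label}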

Result

  • No duplicate data across workers.
  • StreamingDataLoader now works out-of-the-box with the example (a minimal loading sketch follows this list).
  • Improved efficiency (no runtime PIL-to-tensor conversions).
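
As a quick sanity check, the optimized dataset can be streamed back with the standard loader (the output directory name is assumed):

import litdata as ld

train_dataset = ld.StreamingDataset("my_optimized_dataset", shuffle=True)
train_dataloader = ld.StreamingDataLoader(train_dataset)

for batch in train_dataloader:
    images, classes = batch["image"], batch["class"]  # batched by default_collate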

PR review

Anyone in the community is free to review the PR once the tests have passed.

Did you have fun?

Make sure you had fun coding 🙃

dipeshbabu avatar Apr 28 '25 19:04 dipeshbabu

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 79%. Comparing base (e789fb6) to head (4bb0eef). Report is 17 commits behind head on main.

Additional details and impacted files
@@         Coverage Diff         @@
##           main   #574   +/-   ##
===================================
  Coverage    79%    79%           
===================================
  Files        40     40           
  Lines      6098   6098           
===================================
  Hits       4818   4818           
  Misses     1280   1280           

codecov[bot] avatar Apr 28 '25 20:04 codecov[bot]

Hi @dipeshbabu, the provided sample is just a minimal demo; in real-world use cases, you’d typically work with actual .jpg or .png images, which get optimized and loaded as tensors directly.

If you’re saving PIL images directly, you can handle them either by subclassing the streaming dataset to apply transforms or by passing a custom collate_fn like:

import litdata as ld

def collate_fn(batch):
    # Keep the PIL images as plain Python lists instead of letting
    # default_collate try (and fail) to stack them into tensors.
    return {
        "image": [sample["image"] for sample in batch],
        "class": [sample["class"] for sample in batch],
    }

# train_dataset is the StreamingDataset created earlier in the example.
train_dataloader = ld.StreamingDataLoader(train_dataset, collate_fn=collate_fn)
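
For the subclassing route mentioned above, a rough sketch could look like the following (the class name, directory, and the torchvision transform are placeholders, not part of the PR):

import litdata as ld
from torchvision import transforms

class TransformedDataset(ld.StreamingDataset):
    def __getitem__(self, index):
        sample = super().__getitem__(index)
        # Convert the stored PIL image into a tensor so default_collate can batch it.
        sample["image"] = transforms.ToTensor()(sample["image"])
        return sample

train_dataset = TransformedDataset("my_optimized_dataset")
train_dataloader = ld.StreamingDataLoader(train_dataset)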

If you’re interested in making further contributions to litdata, we’d be happy to discuss and collaborate on our Discord (join us in the #litdata channel), or feel free to discuss directly in the issues.

bhimrazy avatar May 02 '25 09:05 bhimrazy

Hi @dipeshbabu, IMO it makes more sense to add the collate_fn code in getting_started/stream.py so that the example is complete and doesn’t raise an error immediately.

deependujha avatar May 14 '25 14:05 deependujha

Hey @dipeshbabu, just following up — if you're interested, adding the collate_fn directly in getting_started/stream.py and any other related places would be a great addition to round out the fix. Would be awesome to have that as part of your first contribution! 🙌

bhimrazy avatar May 23 '25 06:05 bhimrazy

Closing this since there’s been no response from the author.

deependujha avatar May 31 '25 09:05 deependujha

Let's include the collate_fn then.

bhimrazy avatar May 31 '25 10:05 bhimrazy