Fix dataset generation: deterministic per-index seeding and collate-compatible image format
What does this PR do?
Issues Addressed
- Fixes #573: Two critical bugs in the Getting Started example:
  - Duplicate data when using `optimize` with `num_workers > 1` due to unseeded randomness.
  - `default_collate` error caused by returning PIL Images, which are incompatible with PyTorch batching.
Root Causes
- Duplicate Data: Workers shared the global `numpy` random state, leading to identical random values across processes.
- Collate Error: PIL Images cannot be batched by PyTorch's `default_collate`.
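The duplicate-data failure mode can be reproduced in isolation: if every worker draws from a generator with the same effective seed (as happens when forked workers inherit identical global `numpy` state), they all produce the same values, whereas per-index seeding yields distinct but fully reproducible streams. A minimal sketch (the `draw` helper simulates one worker's sample generation and is illustrative, not litdata code):

```python
import numpy as np

def draw(seed):
    # Each simulated "worker" draws 8 random values from its own generator.
    rng = np.random.default_rng(seed=seed)
    return rng.integers(0, 100, size=8).tolist()

# Bug: all workers effectively share one seed -> identical data.
shared = [draw(0) for _ in range(3)]
assert shared[0] == shared[1] == shared[2]

# Fix: seed per index -> distinct streams, still deterministic per index.
per_index = [draw(index) for index in range(3)]
assert per_index[0] != per_index[1]
assert per_index[1] == draw(1)  # reproducible regardless of which process runs it
```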
Changes
- Deterministic Data Generation:
  - Seed `numpy`'s RNG uniquely per `index` using `np.random.default_rng(seed=index)`.
  - Replace `np.random.randint` with the seeded generator's `rng.integers(...)`.
- Collate Compatibility:
- Return images as numpy arrays instead of PIL Images.
- Documentation Updates:
- Updated the Getting Started example to reflect both fixes.
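Put together, the fixed sample generation can be sketched as a `__getitem__` like the one below (class name, image shape, and field names are illustrative; the actual Getting Started example may differ):

```python
import numpy as np

class RandomImageDataset:
    """Synthetic dataset: each index deterministically generates one sample."""

    def __len__(self):
        return 1000

    def __getitem__(self, index):
        # Seed the generator per index so every worker produces the same
        # (unique) sample for a given index, regardless of process.
        rng = np.random.default_rng(seed=index)
        # Return a numpy array (HWC uint8) instead of a PIL Image so that
        # PyTorch's default_collate can batch it.
        image = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)
        return {"image": image, "class": int(rng.integers(0, 10))}
```

Because the randomness is derived only from `index`, repeated lookups of the same index return identical samples even across worker processes.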
Result
- No duplicate data across workers.
- `StreamingDataLoader` now works out-of-the-box with the example.
- Improved efficiency (no runtime PIL-to-tensor conversions).
PR review
Anyone in the community is free to review the PR once the tests have passed.
Did you have fun?
Make sure you had fun coding 🙃
Codecov Report
All modified and coverable lines are covered by tests :white_check_mark:
Project coverage is 79%. Comparing base (`e789fb6`) to head (`4bb0eef`). Report is 17 commits behind head on main.
Additional details and impacted files
```
@@ Coverage Diff @@
##           main    #574   +/- ##
==================================
  Coverage    79%     79%
==================================
  Files        40      40
  Lines      6098    6098
==================================
  Hits       4818    4818
  Misses     1280    1280
```
Hi @dipeshbabu,
The provided sample is just a minimal demo — in real-world use cases, you’d typically work with actual .jpg or .png images, which get optimized and loaded as tensors directly.
If you’re saving PIL images directly, you can handle them either by subclassing the streaming dataset to apply transforms or by passing a custom collate_fn like:
```python
def collate_fn(batch):
    return {
        "image": [sample["image"] for sample in batch],
        "class": [sample["class"] for sample in batch],
    }

train_dataloader = ld.StreamingDataLoader(train_dataset, collate_fn=collate_fn)
```
If you're interested in making further contributions to litdata, we'd be happy to discuss and collaborate on our Discord — join us in the #litdata channel! Or feel free to discuss directly in the issues.
Hi @dipeshbabu
IMO, it makes more sense to add the collate_fn code in getting_started/stream.py, so the example is complete and doesn't raise an error immediately.
Hey @dipeshbabu, just following up — if you're interested, adding the collate_fn directly in getting_started/stream.py and any other related places would be a great addition to round out the fix. Would be awesome to have that as part of your first contribution! 🙌
Closing this since there’s been no response from the author.
Let's include the collate_fn then.