Making Streaming Dataset framework agnostic: Removing PyTorch dependency
🚀 Feature Request
Hey MosaicML team! Thank you so much for this awesome project! I was wondering if there are any plans to make this framework agnostic by removing the dependency on PyTorch.
Motivation
The general idea of StreamingDataset is very useful, and I believe the broader ML community would be even more excited about it if it were decoupled from PyTorch.
Implementation
Here are my thoughts on how we can go about this:
- `torch.utils.data.Dataset` is a simple class with essentially no dependencies on the rest of PyTorch (the same is true of `IterableDataset`), so it could easily be re-implemented here.
- Things get a bit more challenging when porting the `distributed.py` file, but this is where the CuPy project comes to the rescue. We can have seamless, zero-copy interoperability between CuPy, JAX, TensorFlow, and PyTorch tensors via the DLPack API (a short sketch follows below), and most of the functions in `distributed.py` have similar implementations in CuPy's distributed API.
- As for `StreamingDataLoader`, it could become an optional component that is only installed with the PyTorch backend.
- So my suggestion is that by using CuPy instead of PyTorch, we can keep the library framework neutral and still have zero-copy interoperability between JAX, TF, and Torch.
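For what it's worth, here is a minimal sketch of the zero-copy DLPack exchange I have in mind, assuming recent CuPy and PyTorch versions that support the `__dlpack__` protocol (just an illustration, not part of streaming):

```python
import cupy as cp
import torch

# A CuPy array living on the GPU.
gpu_array = cp.arange(10, dtype=cp.float32)

# Hand the same memory to PyTorch without copying.
torch_tensor = torch.from_dlpack(gpu_array)

# And back again: CuPy can consume any DLPack-capable producer.
round_trip = cp.from_dlpack(torch_tensor)

# All three views share the underlying buffer, so a write is visible everywhere.
torch_tensor[0] = 42.0
assert float(round_trip[0]) == 42.0
```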
Additional context
If made framework agnostic:
- It could be used in `tf.data` pipelines, which work well with both JAX and TensorFlow (a rough sketch follows below).
- It would fit nicely behind `keras.utils.Sequence`, which would also let us use it with Keras 3 and its TF/JAX/PyTorch backends.
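As a rough illustration of the `tf.data` point: a framework-agnostic dataset only needs to be a plain-Python iterable yielding numpy samples to plug into `tf.data.Dataset.from_generator`. The toy generator below just stands in for StreamingDataset:

```python
import numpy as np
import tensorflow as tf

# Stand-in for a framework-agnostic streaming dataset: any plain-Python
# iterable that yields numpy samples would do.
def sample_generator():
    for i in range(100):
        yield np.random.rand(16).astype(np.float32), np.int64(i % 10)

tf_dataset = (
    tf.data.Dataset.from_generator(
        sample_generator,
        output_signature=(
            tf.TensorSpec(shape=(16,), dtype=tf.float32),
            tf.TensorSpec(shape=(), dtype=tf.int64),
        ),
    )
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)

for features, labels in tf_dataset.take(1):
    print(features.shape, labels.shape)  # (32, 16) (32,)
```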
I would also be happy to help with this if you think it is a potential future direction!
Decoupling from PyTorch would be a hell of a project! We enthusiastically welcome your contributions. Let me list some objections that come to mind offhand -- what do you make of them?
- StreamingDataset is designed around exactly how the PyTorch DataLoader operates, with each rank iterating round-robin over a bunch of worker replicas which are typically fork/spawned upon iter, identical samples per DataLoader requirement, etc. What's the CuPy answer to `get_worker_info()`?
- Our killer feature, the elastically deterministic mid-epoch checkpointing and resumption, currently depends on either our custom StreamingDataLoader subclass of DataLoader, or on tracking time yourself like Composer does (yes, we built it two different ways). It's a core thing too.
- We use numpy for some things, but no framework-specific Tensors, and no GPU. We assume GPU and interconnect are precious resources and keep hands off them. There is a tiny usage of torch dist for barriers in some critical places, which we set up as gloo and tear down if not already used, IIRC. Theoretically you could swap them out with `streaming/base/shared/barrier.py`, which is an inter-process barrier backed by FileLock (fcntl) and SharedMemory, and which we currently use for worker sync since workers can't necessarily dist (rough idea sketched below). It would be nice to remove that last bit of reliance on torch dist, generally speaking.
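For context, the rough idea behind that barrier is something like the simplified one-shot sketch below (illustrative only, not the actual `barrier.py`): a counter in SharedMemory, bumped under a FileLock, that every process spins on until all participants have arrived.

```python
import time
from multiprocessing.shared_memory import SharedMemory

from filelock import FileLock


class FileLockBarrier:
    """Simplified one-shot barrier: a FileLock-guarded counter in SharedMemory."""

    def __init__(self, name: str, num_procs: int):
        self.num_procs = num_procs
        self.lock = FileLock(f'/tmp/{name}.lock')
        with self.lock:
            # First process creates and zeroes the shared counter; the rest attach.
            try:
                self.shm = SharedMemory(name=name, create=True, size=4)
                self.shm.buf[:4] = (0).to_bytes(4, 'little')
            except FileExistsError:
                self.shm = SharedMemory(name=name)

    def wait(self) -> None:
        # Atomically bump the arrival counter under the file lock.
        with self.lock:
            count = int.from_bytes(self.shm.buf[:4], 'little') + 1
            self.shm.buf[:4] = count.to_bytes(4, 'little')
        # Spin until every participant has arrived.
        while int.from_bytes(self.shm.buf[:4], 'little') < self.num_procs:
            time.sleep(0.01)
```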
@knighton thanks for your comment and support.
- I have a private port of the PyTorch DataLoader which I hacked together for fun to remove the torch dependency and turned into a standalone package (it sort of works, but I have not tested it fully). As part of that I remember porting `get_worker_info()`; I'll check whether it works (a rough sketch of the shape such a port takes is below).
- I'm not sure yet what to do for `StreamingDataLoader()`. Maybe, as you suggested, we can track time ourselves.
- I'll try the `barrier.py` you suggested and attempt porting `distributed.py` to see if we can remove the reliance on torch dist.
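Purely as an illustration of that shape (the names below are made up for the example, not an existing API), a torch-free `get_worker_info()` can be little more than a per-process registry that the loader populates right after spawning each worker:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class WorkerInfo:
    id: int            # index of this worker within its rank
    num_workers: int   # total workers spawned by this rank's loader
    seed: int          # per-worker seed, e.g. for shuffling


# Module globals are per-process, so each worker process must call
# set_worker_info() once, right after it is forked/spawned.
_worker_info: Optional[WorkerInfo] = None


def set_worker_info(id: int, num_workers: int, seed: int) -> None:
    global _worker_info
    _worker_info = WorkerInfo(id=id, num_workers=num_workers, seed=seed)


def get_worker_info() -> Optional[WorkerInfo]:
    # Mirrors torch.utils.data.get_worker_info(): None in the main process,
    # this worker's metadata inside a worker process.
    return _worker_info
```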
I'll keep you updated! Thanks!
Appreciate the updates.
For (2), I would recommend just reading our StreamingDataLoader, as what it's doing/needs to do is very simple (rough sketch below).
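The gist is roughly the sketch below (made-up names, not our actual loader): count what has been handed out, expose it via `state_dict`, and let the dataset fast-forward from that count on resume.

```python
from typing import Any, Dict, Iterator, List


class ResumableLoader:
    """Sketch of the small amount of state needed for mid-epoch resumption."""

    def __init__(self, dataset, batch_size: int):
        self.dataset = dataset
        self.batch_size = batch_size
        self.samples_yielded = 0

    def __iter__(self) -> Iterator[List[Any]]:
        # A real loader would pass self.samples_yielded to the dataset here so
        # it can skip ahead deterministically after a resume.
        batch: List[Any] = []
        for sample in self.dataset:
            batch.append(sample)
            if len(batch) == self.batch_size:
                self.samples_yielded += len(batch)
                yield batch
                batch = []
        if batch:
            self.samples_yielded += len(batch)
            yield batch

    def state_dict(self) -> Dict[str, int]:
        # Enough to reconstruct the position within the epoch.
        return {'samples_yielded': self.samples_yielded}

    def load_state_dict(self, state: Dict[str, int]) -> None:
        self.samples_yielded = state['samples_yielded']
```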
Experimental PR to remove dependency on torch dist:
https://github.com/mosaicml/streaming/pull/552
@knighton Wow! That was fast!