
Making Streaming Dataset framework agnostic: Removing PyTorch dependency

Abhijit-2592 opened this issue on Dec 26 '23 · 5 comments

🚀 Feature Request

Hey MosaicML team! Thank you so much for this awesome project! I was wondering if there are any plans to make this framework agnostic: remove the dependency on PyTorch.

Motivation

The general idea of StreamingDataset is very useful, and I believe the broader ML community would be even more excited about it if we decoupled it from PyTorch.

Implementation

Here are my thoughts on how we can go about this:

  • torch.utils.data.Dataset is a simple class with no dependencies on the rest of PyTorch (the same is true of IterableDataset), so it could easily be re-implemented here (a rough sketch is below, after this list).
  • Porting the distributed.py file is a bit more challenging. However, this is where the CuPy project comes to the rescue: the DLPack API gives us seamless, zero-copy interoperability between CuPy, JAX, TensorFlow, and PyTorch tensors, and most of the functions in distributed.py have similar implementations in CuPy's distributed API.
  • StreamingDataLoader could become an optional component, installed only when using the PyTorch backend.
  • So my suggestion is: if we use CuPy instead of PyTorch, we keep the library framework neutral and still get zero-copy interoperability between JAX, TF, and Torch.
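
To make the first bullet concrete, here is a minimal sketch of what torch-free base classes could look like. The class names mirror PyTorch's interface but are illustrative only, not taken from the streaming codebase:

```python
# Minimal, illustrative stand-ins for torch.utils.data.Dataset / IterableDataset
# with no torch import. Subclasses keep exactly the same interface contracts.
from typing import Any, Iterator


class Dataset:
    """Map-style dataset: subclasses implement __getitem__ and __len__."""

    def __getitem__(self, index: int) -> Any:
        raise NotImplementedError

    def __len__(self) -> int:
        raise NotImplementedError


class IterableDataset:
    """Iterable-style dataset: subclasses implement __iter__."""

    def __iter__(self) -> Iterator[Any]:
        raise NotImplementedError
```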

Additional context

If made framework agnostic:

  • It could be used with tf.data pipelines, which work well with both JAX and TensorFlow.
  • It fits naturally into keras.utils.Sequence, so it could also be used with Keras 3, which supports the TF/JAX/PyTorch backends (see the sketch after this list).
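
As a rough sketch of the Keras point: assuming a framework-neutral, map-style StreamingDataset that exposes __len__ and __getitem__, a thin wrapper could look something like this. StreamingSequence and the batching details are illustrative, not part of the library:

```python
import math

import numpy as np
from keras.utils import Sequence  # keras.utils.PyDataset in newer Keras 3 releases


class StreamingSequence(Sequence):
    """Wrap a map-style, framework-neutral dataset for use with Keras model.fit()."""

    def __init__(self, dataset, batch_size: int):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size

    def __len__(self) -> int:
        # Number of batches per epoch.
        return math.ceil(len(self.dataset) / self.batch_size)

    def __getitem__(self, batch_idx: int):
        start = batch_idx * self.batch_size
        stop = min(start + self.batch_size, len(self.dataset))
        samples = [self.dataset[i] for i in range(start, stop)]
        # Assumes each sample is a (features, label) pair of array-likes.
        xs, ys = zip(*samples)
        return np.stack(xs), np.stack(ys)
```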

I'd also be happy to help with this if you think it's a potential future direction!

Abhijit-2592 · Dec 26 '23 03:12

Decoupling from PyTorch would be a hell of a project! We enthusiastically welcome your contributions. Let me list some objections that come to mind offhand -- what do you make of them?

  • StreamingDataset is designed exactly around how the PyTorch DataLoader operates: each rank iterates round-robin over a bunch of worker replicas that are typically forked/spawned upon iter(), the requirement of identical samples per DataLoader, etc. What's the CuPy answer to get_worker_info()?

  • Our killer feature, elastically deterministic mid-epoch checkpointing and resumption, currently depends on either our custom StreamingDataLoader subclass of DataLoader or on tracking time yourself like Composer does (yes, we built it two different ways); it's a core feature too.

  • We use numpy for some things, but no framework-specific tensors and no GPU; we treat the GPU and interconnect as precious resources and keep our hands off them. There is a tiny usage of torch dist for barriers in some critical places, which we set up as gloo and tear down if dist wasn't already in use, IIRC. Theoretically you could swap those out with streaming/base/shared/barrier.py, an inter-process barrier backed by FileLock (fcntl) and SharedMemory, which we currently use for worker sync since workers can't necessarily use dist (a rough sketch of that kind of barrier follows below). It would be nice to remove that last bit of reliance on torch dist, generally speaking.
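
For reference, here is a minimal, single-use sketch of an inter-process barrier built on FileLock plus SharedMemory. It is not the actual streaming/base/shared/barrier.py (which also handles reuse and cleanup); it just illustrates the technique, and it assumes the filelock package is installed and every participant knows the barrier name and process count:

```python
import time
from multiprocessing.shared_memory import SharedMemory

from filelock import FileLock  # pip install filelock


class SketchBarrier:
    """Single-use barrier: wait() blocks until `num_procs` processes have arrived."""

    def __init__(self, name: str, num_procs: int):
        self.num_procs = num_procs
        self.lock = FileLock(f'/tmp/{name}.lock')
        with self.lock:
            try:
                # First process to get here creates the 8-byte shared counter (zeroed).
                self.shm = SharedMemory(name=name, create=True, size=8)
            except FileExistsError:
                # Everyone else attaches to the existing segment.
                self.shm = SharedMemory(name=name)

    def _count(self) -> int:
        return int.from_bytes(self.shm.buf[:8], 'little')

    def wait(self, poll: float = 0.01) -> None:
        # Arrive: bump the shared counter under the file lock.
        with self.lock:
            self.shm.buf[:8] = (self._count() + 1).to_bytes(8, 'little')
        # Depart: spin until every process has arrived. (A real barrier would also
        # reset the counter for reuse and unlink the shared memory on exit.)
        while self._count() < self.num_procs:
            time.sleep(poll)
```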

knighton · Dec 26 '23 05:12

@knighton thanks for your comment and support.

  1. I have a private port of the PyTorch DataLoader that I hacked together for fun to remove the torch dependency and turn into a standalone package (it kind of works, but I haven't tested it fully). I remember porting get_worker_info() as part of that; I'll check whether it still works (a rough sketch of the idea is below, after this list).
  2. I'm not sure yet what to do about StreamingDataLoader. Maybe, as you suggested, we can track time ourselves.
  3. I'll try the barrier.py you pointed to and attempt to port distributed.py, to see if we can remove the reliance on torch dist.
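
For context on (1), this is roughly the shape of a torch-free stand-in for get_worker_info(). The names WorkerInfo and set_worker_info are hypothetical, and the sketch assumes each dataloader worker process calls set_worker_info() right after it is forked/spawned:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class WorkerInfo:
    id: int           # index of this worker within the dataloader
    num_workers: int  # total number of workers the dataloader spawned
    seed: int         # per-worker RNG seed


_worker_info: Optional[WorkerInfo] = None


def set_worker_info(worker_id: int, num_workers: int, seed: int) -> None:
    """Called once inside each worker process, right after fork/spawn."""
    global _worker_info
    _worker_info = WorkerInfo(worker_id, num_workers, seed)


def get_worker_info() -> Optional[WorkerInfo]:
    """Mirrors torch's contract: None in the main process, a WorkerInfo inside a worker."""
    return _worker_info
```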

I'll keep you updated! Thanks!

Abhijit-2592 · Dec 26 '23 06:12

Appreciate the updates.

I would recommend just reading our StreamingDataLoader for (2), as what it's doing/needs to do is very simple.

knighton · Dec 26 '23 06:12

Experimental PR to remove dependency on torch dist:

https://github.com/mosaicml/streaming/pull/552

knighton · Dec 27 '23 03:12

@knighton Wow! That was fast!

Abhijit-2592 · Dec 27 '23 18:12