
Add early support for `torchdata.stateful_dataloader.StatefulDataLoader` within the `Accelerator`

byi8220 opened this pull request 1 year ago • 1 comment

What does this PR do?

Fixes https://github.com/huggingface/accelerate/issues/2859

This PR does the following:

  1. Adds a new field `use_stateful_dataloader` to `DataLoaderConfiguration`. Setting it makes every DataLoader prepared and returned by the `Accelerator` a `StatefulDataLoader` object from the `torchdata` library
  2. Creates a class `DataLoaderAdapter` which can wrap around and act as either PyTorch's `DataLoader` or a variant of it such as `StatefulDataLoader`
  3. Refactors `DataLoaderShard`, `DataLoaderDispatcher`, and `SkipDataLoader` to inherit from `DataLoaderAdapter` instead of `DataLoader`

Testing

Added new unit tests verifying that `StatefulDataLoader` can be dropped in, and that its state can be saved and restored.
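The round trip being tested can be sketched without torchdata, using a toy loader that mimics the `state_dict`/`load_state_dict` contract of `StatefulDataLoader` (`TinyStatefulLoader` is a hypothetical stand-in, not accelerate code):

```python
# Self-contained sketch of the save/resume round trip the new tests exercise.
class TinyStatefulLoader:
    def __init__(self, data):
        self.data = list(data)
        self._idx = 0  # how many items have been yielded so far

    def __iter__(self):
        while self._idx < len(self.data):
            item = self.data[self._idx]
            self._idx += 1
            yield item

    def state_dict(self):
        return {"idx": self._idx}

    def load_state_dict(self, state):
        self._idx = state["idx"]


loader = TinyStatefulLoader([1, 2, 3, 4])
it = iter(loader)
first_two = [next(it), next(it)]      # consume part of the epoch
snapshot = loader.state_dict()        # checkpoint mid-epoch

resumed = TinyStatefulLoader([1, 2, 3, 4])
resumed.load_state_dict(snapshot)     # restore onto a fresh loader
rest = list(resumed)                  # picks up where we left off
```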

Caveats

  • The torchdata package may have conflicts with accelerate; see https://github.com/huggingface/accelerate/issues/2894
    • However, if torchdata is not installed, all existing tests pass, suggesting this change is not regressive.
  • torchdata.stateful_dataloader.StatefulDataLoader is only available in the beta build of torchdata; this is not a stable feature.
  • Adding another dependency (on a nightly package) means that almost none of the tests added in this PR run under the existing CI images or imports.
  • This has only been tested on my local workstation using a single GPU.
  • The implementation of DataLoaderAdapter is somewhat invasive and uses some questionable reflection.

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [X] Did you read the contributor guideline, Pull Request section?
  • [X] Was this discussed/approved via a GitHub issue (see above)?
  • [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • [X] Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@muellerzr

byi8220 avatar Jun 26 '24 22:06 byi8220

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Can you verify the tests pass on a multi-GPU system?

I have not written any tests specifically for multi-GPU. However, the PR as-is passes on a 2-GPU cloud instance I spun up on GCP for the tests added so far.

Would the tests I've put into the PR so far be enough, or should I add specific multi-GPU tests? (I only have 1 GPU on my physical workstation, so it costs me money to run these tests :sweat:)

Also would be good to add torchdata in the test requirements in the setup.py

I think StatefulDataLoader is still a beta feature and not officially included in any stable release. Their docs state it's only available in nightly builds for now, which are not found on PyPI.

According to https://stackoverflow.com/q/68809295 you really shouldn't be trying to specify extra_index_url in setup.py.

I'm not sure of a solution besides trying to hack around it or waiting for the required features in torchdata to land in a stable release.

byi8220 avatar Jul 15 '24 20:07 byi8220

Hi @byi8220, will be reviewing this in the next few days!

muellerzr avatar Jul 29 '24 16:07 muellerzr

Is there an issue with just duplicating the code into each DataLoader type instead of introducing a new subclass?

There's one obvious way to get this feature working without some magical reflection tricks:

  1. Create equivalent StatefulDataLoaderDispatcher, StatefulDataLoaderShard, StatefulSkipDataLoader classes which are identical except that they inherit from StatefulDataLoader and implement the state_dict functions.
  2. Create helper functions such as create_data_loader_dispatcher() which take all the constructor args and pick whether to create a StatefulDataLoader or a regular DataLoader.
  3. Replace all constructor calls with these helper functions.
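The three steps above can be sketched in plain Python; all class and function names here are illustrative stand-ins, not accelerate's actual implementation:

```python
class DataLoader:                       # stand-in for torch.utils.data.DataLoader
    def __init__(self, dataset):
        self.dataset = dataset

class StatefulDataLoader(DataLoader):   # stand-in for torchdata's class
    def state_dict(self):
        return {"dataset_len": len(self.dataset)}

class DataLoaderDispatcher(DataLoader):
    """Plain variant, inheriting from DataLoader."""

class StatefulDataLoaderDispatcher(StatefulDataLoader):
    """Duplicated variant, identical except for its base class."""

def create_data_loader_dispatcher(dataset, use_stateful_dataloader=False):
    # Callers must go through this helper instead of the class constructors.
    cls = StatefulDataLoaderDispatcher if use_stateful_dataloader else DataLoaderDispatcher
    return cls(dataset)

stateful = create_data_loader_dispatcher([1, 2], use_stateful_dataloader=True)
plain = create_data_loader_dispatcher([1, 2])
```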

Doing that would add a ton of code duplication, but it should be mostly functionally equivalent to what I have now. The only thing you would need to be careful with moving forward is now you must always use the helper functions to create these objects instead of the class constructors.

If you want to avoid evil dynamic reflection, I could refactor this PR to do the above instead without too much effort, but it will make dataloader.py look extremely messy. I'll make a branch to demonstrate those changes, but it feels like littering countless small "do we have a stateful dataloader?" checks around the code.

There are some other ideas I considered while writing this PR, mostly involving composing a DataLoader into these classes, but most of them felt kind of wrong, or like they might break something.

byi8220 avatar Jul 29 '24 18:07 byi8220

Actually, trying to draft the change I just mentioned reminded me why I gave up on duplicating the code in the first place. It's a mess. Issues I needed to work around include:

  1. Since StatefulDataLoader inherits from DataLoader, having StatefulDataLoaderDispatcher inherit from DataLoaderDispatcher creates a diamond problem. This leads to a ton of messy MRO issues that take more mental effort to process than simply not having this inheritance.
  2. By separating the types, there is a bunch of boilerplate handling code that needs to consider the exact type of the DataLoader being passed around.
  3. The existing issues with this feature depending on a nightly, unstable build of torchdata remain (CI gap, dependency-checking boilerplate, awkwardly written tests, and so on).

I'll defend my DataLoaderAdapter<T> : T solution since I find it elegant, but I recognize it's horribly confusing and invokes some pretty cursed stuff, so I'm in full support of getting rid of it for something more overt and readable.
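The "cursed" part can be sketched with stand-in classes (this is a simplified illustration, not the PR's actual code): the adapter picks its concrete base class at construction time, so each instance really is a DataLoader or a StatefulDataLoader as requested.

```python
class DataLoader:                       # stand-in for torch.utils.data.DataLoader
    def __init__(self, dataset):
        self.dataset = list(dataset)

class StatefulDataLoader(DataLoader):   # stand-in for torchdata's class
    def state_dict(self):
        return {"dataset": self.dataset}

class DataLoaderAdapter:
    """Adapter whose concrete base class is chosen per instance."""

    def __new__(cls, dataset, use_stateful_dataloader=False):
        base = StatefulDataLoader if use_stateful_dataloader else DataLoader
        # Build a throwaway subclass mixing the adapter over the chosen base.
        concrete = type(cls.__name__, (cls, base), {})
        return object.__new__(concrete)

    def __init__(self, dataset, use_stateful_dataloader=False):
        super().__init__(dataset)       # dispatches into the chosen base

stateful = DataLoaderAdapter([1, 2, 3], use_stateful_dataloader=True)
plain = DataLoaderAdapter([1, 2, 3])
```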

If you want to see what it would look like, I've made those changes in a separate branch.

Here's the diff between this PR and the above branch: https://github.com/byi8220/accelerate/compare/stateful-dataloader...byi8220:accelerate:stateful-dataloader-2

@muellerzr If the code above looks better, I can merge those into this PR's working branch and move on with that instead.

byi8220 avatar Jul 29 '24 22:07 byi8220

TBH I'm not sure if I like either. Let me work on this today and tomorrow and see where I end up. Whichever solution winds up being similar to yours (unless a third one pops up), we'll go with it.

muellerzr avatar Jul 31 '24 20:07 muellerzr

TBH I'm not sure if I like either.

It's tricky. My understanding of the problem is this: we want to say "If use_stateful_dataloader == True, then create a StatefulDataLoader instead of a DataLoader."

However, given how classes like DataLoaderShard are currently implemented, what we have to do in practice is "If use_stateful_dataloader == True, then make this class inherit from StatefulDataLoader instead of DataLoader." This is a lot more awkward to code.

I have one more idea for a solution: have those classes hold a base_dataloader and manually code the passthrough to its methods and properties. I feel like this is a bit fragile, and it isn't too different from my original solution.
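That composition idea, sketched with illustrative toy classes (not the PR's code), would look roughly like this; note the dunder methods still need explicit forwarding, since Python looks them up on the class rather than through `__getattr__`:

```python
class DataLoaderLike:                   # stand-in for any DataLoader variant
    def __init__(self, data, batch_size=1):
        self.data = list(data)
        self.batch_size = batch_size

    def __iter__(self):
        return iter(self.data)

class ComposedShard:
    """Composition variant: hold a base_dataloader and pass calls through."""

    def __init__(self, base_dataloader):
        self.base_dataloader = base_dataloader

    def __iter__(self):
        # Dunder methods must be forwarded explicitly; __getattr__ alone
        # won't cover them because Python resolves them on the class.
        yield from self.base_dataloader

    def __getattr__(self, name):
        # Everything else falls through to the wrapped loader.
        return getattr(self.base_dataloader, name)

shard = ComposedShard(DataLoaderLike([1, 2, 3], batch_size=2))
```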

byi8220 avatar Jul 31 '24 21:07 byi8220

@byi8220 can you resolve the PR's open comments? Then I think we're okay to merge this.

muellerzr avatar Aug 20 '24 12:08 muellerzr

As a final step, we likely want to update save_state/load_state to resume the dataloaders at this point.

muellerzr avatar Aug 20 '24 12:08 muellerzr

@BenjaminBossan Thanks for the review. Just addressed the comments on the PR.

The end result is still something that I'm afraid will one day cause a hard-to-debug issue, but I can't say what a better solution would be.

The ultimate intent of this code is something like "Sometimes I want a DataLoaderDispatcher that inherits from DataLoader, but other times I want a DataLoaderDispatcher that inherits from StatefulDataLoader."

Imo, the less magical alternative would be to explicitly duplicate each DataLoader derivative that accelerate defines into a stateful version. I.e. manually create the classes StatefulDataLoaderDispatcher, StatefulDataLoaderShard, StatefulSkipDataLoader. I wrote up this alternative in a separate branch (diffed by https://github.com/byi8220/accelerate/compare/stateful-dataloader...byi8220:accelerate:stateful-dataloader-2), but it leads to quite a lot of code duplication and also looks messy.

I have to admit I only skimmed the tests but they look very well done, so together with the existing ones they should hopefully avoid regressions.

I've tested this locally on my 1-GPU home workstation + a 2-GPU cloud instance (which costs me a few dollars every time I want to run the test suite :disappointed: ...). The fact that all tests pass when this feature is unused, regardless of whether the required torchdata version is installed, gives me confidence that it's not causing a breaking regression.

This is my first real PR to accelerate, so I added the sanity and happy-path test cases I could think of given my limited context; I might just be guessing at what's sufficient.

The tests highlighted one small thing, though: to fully stop using a dataloader mid-iteration you have to call dataloader.end(), but this might just be unavoidable. If the use case of StatefulDataLoader is restarting the entire program from a checkpoint, maybe it's not a big issue: https://github.com/huggingface/accelerate/pull/2895/files#diff-68b278b14afa2e1ea337bb5e13d122f6d074c8bf0f0b83bef779eac6f4ba7f9aR724-R726

One thing I would like to see is an addition to the docs to explain what stateful data loaders are, why users may want to use them, and how they can use them.

Imo this might be better in a separate PR, once the code is checked in?

byi8220 avatar Aug 20 '24 17:08 byi8220

The tests highlighted one small thing, though: to fully stop using a dataloader mid-iteration you have to call dataloader.end(), but this might just be unavoidable. If the use case of StatefulDataLoader is restarting the entire program from a checkpoint, maybe it's not a big issue

I believe this has been a known "issue" in accelerate (I've seen it pop up in other issues sparingly). Agreed that it's less of an issue here, since this is pretty much just called once at the start of training. As long as we track the state properly (which your tests check!), it's a different bug/issue to solve.

Imo this might be better in a separate PR, once the code is checked in?

We tend to like feature-complete PRs that include doc updates. It's less likely the docs will be forgotten about, and it's done all at once so users who want the bleeding edge can read immediately :)

muellerzr avatar Aug 20 '24 18:08 muellerzr

I believe this has been a known "issue" in accelerate (I've seen it pop up in other issues sparingly).

Well, I have no idea how to solve such a problem in Python. In the C++ world this is what destructors and RAII are for, I guess.
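Python's usual stand-in for RAII is the context-manager protocol, which runs cleanup deterministically even when the loop raises. A sketch with a hypothetical loader exposing `end()` (illustrative names, not accelerate's API):

```python
from contextlib import contextmanager

class LoaderWithEnd:                    # hypothetical loader exposing end()
    def __init__(self):
        self.active = True

    def end(self):
        self.active = False

@contextmanager
def managed(loader):
    try:
        yield loader
    finally:
        loader.end()                    # runs even if the body raises

# Normal exit: end() is called when the block finishes.
with managed(LoaderWithEnd()) as dl:
    pass                                # training loop body would go here
```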

We tend to like feature-complete PRs that include doc updates. It's less likely the docs will be forgotten about, and it's done all at once so users who want the bleeding edge can read immediately :)

Sure, added a footnote in the docs about this feature.

Also, since this feature is now stable in torchdata, I added a requirement for torchdata>=0.8.0 in setup.py.

byi8220 avatar Aug 20 '24 18:08 byi8220

I see. If Zach is fine with the proposed solution, then we're good.

sgtm

There is the __del__ magic method in Python, but let's not touch it.

I see. Destructors in Python don't seem very reliable, but my knowledge of the Python memory model isn't great.

byi8220 avatar Aug 21 '24 12:08 byi8220

Thanks!

I didn't find anything big, but a few minor things that could be improved.

I fixed the nits above, but I also made one, maybe important, change, done in https://github.com/huggingface/accelerate/pull/2895/commits/74e2f53d841f5701d476f3fe5f6df8f97ad82e5c

Basically, I realized just now that I've been delegating the work of iteration to the superclass instead of to the backing dataloader. That felt wrong, so I made the commit above.

To confirm, replacing super() with self.base_dataloader here is the sensible thing to do, right? I want to fully delegate everything I can to the base_dataloader, and the only reason it worked as written before is __getattr__ spaghetti?
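The difference can be sketched with toy classes (illustrative, not the PR's code): super() walks the wrapper's own inheritance chain and so iterates the wrapper itself, while self.base_dataloader reaches the wrapped object.

```python
class DataLoader:                       # stand-in superclass
    def __init__(self, data=()):
        self.data = list(data)

    def __iter__(self):
        return iter(self.data)

class Adapter(DataLoader):
    def __init__(self, base_dataloader):
        super().__init__()              # wrapper's own data stays empty
        self.base_dataloader = base_dataloader

    def iter_via_super(self):
        # Iterates the wrapper itself, which holds no data of its own.
        return list(super().__iter__())

    def iter_via_base(self):
        # Delegates to the backing dataloader, the intended behavior.
        return list(iter(self.base_dataloader))

adapter = Adapter(DataLoader([1, 2, 3]))
```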

byi8220 avatar Aug 21 '24 15:08 byi8220

As a next step, I'll work on getting this working with Accelerator.save_state/Accelerator.load_state today.

muellerzr avatar Aug 22 '24 12:08 muellerzr