Add early support for `torchdata.stateful_dataloader.StatefulDataLoader` within the `Accelerator`
What does this PR do?
Fixes https://github.com/huggingface/accelerate/issues/2859
This PR does the following:
- Added a new field `use_stateful_dataloader` to `DataLoaderConfiguration`. Passing this into the config makes it so that all `DataLoader`s prepared and returned by the `Accelerator` are `StatefulDataLoader` objects from the torchdata library (see the usage sketch after this list)
- Created a class `DataLoaderAdapter` which can wrap around and act as either PyTorch's `DataLoader` or variants of it such as `StatefulDataLoader`
- Refactored `DataLoaderShard`, `DataLoaderDispatcher`, and `SkipDataLoader` to inherit from `DataLoaderAdapter` instead of `DataLoader`
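Roughly, opting in looks like this (a minimal usage sketch based on the description above; `my_dataset` is a stand-in):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration

# Stand-in dataset for illustration.
my_dataset = TensorDataset(torch.arange(100).float())

# Opt in: every dataloader prepared by this Accelerator becomes a
# torchdata StatefulDataLoader under the hood.
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
accelerator = Accelerator(dataloader_config=dataloader_config)

train_dataloader = accelerator.prepare(DataLoader(my_dataset, batch_size=8))

# The prepared dataloader now exposes StatefulDataLoader's API.
state = train_dataloader.state_dict()
train_dataloader.load_state_dict(state)
```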
Testing
Added new unit tests verifying that `StatefulDataLoader` can be dropped in, and that its state can be saved and loaded.
Caveats
- The `torchdata` package may have conflicts with `accelerate`, see https://github.com/huggingface/accelerate/issues/2894
  - However, if `torchdata` is not installed, all existing tests pass, suggesting this change is not regressive.
- `torchdata.stateful_dataloader.StatefulDataLoader` is only available in the beta build of `torchdata`, so this is not a stable feature.
  - Adding another dependency (on a nightly package) means that almost none of the tests added in this PR run under the existing images or imports.
- This has only been tested on my local workstation using a single GPU.
- The implementation of `DataLoaderAdapter` is somewhat invasive and uses some questionable reflection.
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [X] Did you read the contributor guideline, Pull Request section?
- [X] Was this discussed/approved via a GitHub issue (see above)?
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [X] Did you write any new necessary tests?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@muellerzr
Can you verify the tests pass on a multi-GPU system?
I have not written any tests specifically for multi-GPU. However, the PR as-is passes on a 2-GPU cloud instance I spun up on GCP for the tests added so far.
Would the tests I've put into the PR so far be enough, or should I add specific multi-GPU tests? (I only have 1 GPU on my physical workstation so it costs me money to run these tests :sweat:)
Also, it would be good to add torchdata to the test requirements in setup.py.
I think StatefulDataLoader is still a beta feature and not officially included in any stable release. Their docs state it's only available in nightly builds for now, which are not found on PyPI.
According to https://stackoverflow.com/q/68809295 you really shouldn't be trying to specify extra_index_url in setup.py.
I'm not sure of a solution besides trying to hack around it or waiting for the required features in torchdata to be released as stable.
Hi @byi8220, will be reviewing this in the next few days!
Is there an issue with just duplicating the code into each DataLoader type instead of introducing a new subclass?
There's one obvious way to get this feature working without some magical reflection tricks:
- Create equivalent `StatefulDataLoaderDispatcher`, `StatefulDataLoaderShard`, and `StatefulSkipDataLoader` classes which are identical except that they inherit from `StatefulDataLoader` and implement the `state_dict` functions (sketched after this list)
- Create helper functions such as `create_data_loader_dispatcher()` which take in all the constructor args and pick whether to create a `StatefulDataLoader` or a regular `DataLoader`
- Replace all constructor calls with these helper functions.
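For illustration, the helper-function approach might look roughly like this (a hedged sketch; the class bodies are stand-ins, not real accelerate code):

```python
from torch.utils.data import DataLoader
from torchdata.stateful_dataloader import StatefulDataLoader

class DataLoaderDispatcher(DataLoader):
    """Stand-in for accelerate's existing dispatcher."""

class StatefulDataLoaderDispatcher(StatefulDataLoader):
    """Hypothetical duplicate that inherits from StatefulDataLoader and
    therefore picks up state_dict()/load_state_dict() for free."""

def create_data_loader_dispatcher(dataset, use_stateful_dataloader=False, **kwargs):
    # Choose the concrete class once, at construction time, so callers
    # never touch the class constructors directly.
    cls = StatefulDataLoaderDispatcher if use_stateful_dataloader else DataLoaderDispatcher
    return cls(dataset, **kwargs)
```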
Doing that would add a ton of code duplication, but it should be mostly functionally equivalent to what I have now. The only thing you'd need to be careful about moving forward is that you must always use the helper functions to create these objects instead of the class constructors.
If you want to avoid evil dynamic reflection, I could refactor this PR to do the above instead without too much effort. It will make dataloader.py look extremely messy, though. I'll make a branch to demonstrate those changes, but it feels like I'm littering countless small "do we have a stateful dataloader?" checks around the code.
There's some other ideas I thought about writing this PR, mostly involving composing DataLoader into these classes, but most of them felt kinda wrong or that they might break something.
Actually, trying to draft up the change I just mentioned reminded me why I gave up on duplicating the code in the first place. It's a mess. Issues I needed to workaround include:
- Since `StatefulDataLoader` inherits from `DataLoader`, having `StatefulDataLoaderDispatcher` inherit from `DataLoaderDispatcher` creates a diamond problem (see the sketch after this list). This leads to a ton of messy MRO issues that take more mental effort to reason about than simply not having the inheritance at all.
- By separating the types, there is a bunch of boilerplate handling code that needs to consider the exact details of the DataLoader being passed around.
- The existing issues with this feature depending on a nightly, unstable build of `torchdata` remain (CI gap, needing dependency-checking boilerplate, weirdly written tests, and so on).
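To make the first point concrete, here is a minimal illustration of the diamond (class bodies are stand-ins):

```python
from torch.utils.data import DataLoader
from torchdata.stateful_dataloader import StatefulDataLoader  # itself a DataLoader subclass

class DataLoaderDispatcher(DataLoader):
    pass

# Both bases ultimately derive from DataLoader, so the MRO forms a diamond:
# StatefulDataLoaderDispatcher -> DataLoaderDispatcher -> StatefulDataLoader
# -> DataLoader -> object. Python can linearize it, but reasoning about
# which __init__/__iter__ actually runs gets messy fast.
class StatefulDataLoaderDispatcher(DataLoaderDispatcher, StatefulDataLoader):
    pass

print(StatefulDataLoaderDispatcher.__mro__)
```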
I'll defend my `DataLoaderAdapter<T> : T` solution since I find it elegant, but I recognize it's horribly confusing and invokes some pretty cursed stuff, so I'm in full support of replacing it with something more overt and readable.
If you want to see what it would look like, I've made those changes in a separate branch.
Here's the diff between this PR and the above branch: https://github.com/byi8220/accelerate/compare/stateful-dataloader...byi8220:accelerate:stateful-dataloader-2
@muellerzr If the code above looks better, I can merge those into this PR's working branch and move on with that instead.
TBH I'm not sure if I like either. Let me work on this today and tomorrow and see where I end up. Whichever solution winds up being similar to yours (unless a third one pops up), we'll go with it.
> TBH I'm not sure if I like either.
It's tricky. My understanding of the problem is this: we want to say, "If `use_stateful_dataloader == True`, then create a `StatefulDataLoader` instead of a `DataLoader`."
However, based on how classes like `DataLoaderShard` are currently implemented, what we have to do in practice is, "If `use_stateful_dataloader == True`, then make this class inherit from `StatefulDataLoader` instead of `DataLoader`." This is a lot more awkward to code.
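In code, that requirement amounts to choosing the base class at runtime, something like this hedged sketch (illustrative, not the PR's actual implementation):

```python
from torch.utils.data import DataLoader

def make_dataloader_shard_class(use_stateful_dataloader: bool):
    if use_stateful_dataloader:
        from torchdata.stateful_dataloader import StatefulDataLoader
        base = StatefulDataLoader
    else:
        base = DataLoader

    # The accelerate-specific logic is written once, but the base class
    # it sits on is only decided here, at call time.
    class DataLoaderShard(base):
        pass

    return DataLoaderShard
```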
I have one more idea for a solution: have those classes create a `base_dataloader` member and manually code the passthrough of methods and properties, roughly as sketched below. I feel like this is a bit fragile, and it isn't too different from my original solution.
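A minimal sketch of that composition idea, assuming the field is called `base_dataloader` (names illustrative):

```python
from torch.utils.data import DataLoader

class DataLoaderAdapter:
    def __init__(self, dataset, use_stateful_dataloader=False, **kwargs):
        if use_stateful_dataloader:
            from torchdata.stateful_dataloader import StatefulDataLoader
            self.base_dataloader = StatefulDataLoader(dataset, **kwargs)
        else:
            self.base_dataloader = DataLoader(dataset, **kwargs)

    def __iter__(self):
        return iter(self.base_dataloader)

    def __len__(self):
        return len(self.base_dataloader)

    def __getattr__(self, name):
        # Only called when normal lookup fails; forward everything else
        # (state_dict, batch_size, ...) to the backing dataloader. This
        # passthrough is the fragile part.
        if name == "base_dataloader":
            raise AttributeError(name)
        return getattr(self.base_dataloader, name)
```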
@byi8220 can you resolve the PR comments, and then I think we're okay to merge this.
As a final step, we likely want to update save_state/load_state to resume the dataloaders at this point.
@BenjaminBossan Thanks for the review. Just addressed the comments on the PR.
The end result is still something that I'm afraid will one day cause a hard-to-debug issue, but I can't say what exactly would be a better solution.
The ultimate intent of this code is something like "Sometimes I want a DataLoaderDispatcher that inherits from DataLoader, but other times I want a DataLoaderDispatcher that inherits from StatefulDataLoader."
Imo, the less magical alternative would be to explicitly duplicate each DataLoader derivative that accelerate defines into a stateful version. I.e. manually create the classes StatefulDataLoaderDispatcher, StatefulDataLoaderShard, StatefulSkipDataLoader. I wrote up this alternative in a separate branch (diffed by https://github.com/byi8220/accelerate/compare/stateful-dataloader...byi8220:accelerate:stateful-dataloader-2), but it leads to quite a lot of code duplication and also looks messy.
I have to admit I only skimmed the tests but they look very well done, so together with the existing ones they should hopefully avoid regressions.
I've tested this locally on my 1-GPU home workstation + a 2-GPU cloud instance (which costs me a few dollars every time I want to run the test suite :disappointed: ...). The fact that all tests pass for me when not using this feature, regardless of whether the required torchdata version is installed, gives confidence that it's not causing a breaking regression.
This is my first real PR into accelerate, so I added the sanity and happy-path test cases I could think of based on my limited context; I might have just been guessing at what's sufficient.
The tests highlighted one small thing though: to fully stop using a dataloader in the middle of an epoch you have to call `dataloader.end()`, but this might just be unavoidable. If the use case of `StatefulDataLoader` is to restart the entire program from a checkpoint, maybe it's not a big issue: https://github.com/huggingface/accelerate/pull/2895/files#diff-68b278b14afa2e1ea337bb5e13d122f6d074c8bf0f0b83bef779eac6f4ba7f9aR724-R726
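For example (a hedged illustration; `dataloader` and `resume_step` are stand-ins):

```python
# Breaking out of iteration early leaves the dataloader mid-epoch, so
# to fully stop using it, end() has to be called manually.
for step, batch in enumerate(dataloader):
    if step == resume_step:
        break
dataloader.end()
```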
One thing I would like to see is an addition to the docs to explain what stateful data loaders are, why users may want to use them, and how they can use them.
Imo this might be better in a separate PR, once the code is checked in?
> The tests highlighted one small thing though: to fully stop using a dataloader in the middle of an epoch you have to call `dataloader.end()`, but this might just be unavoidable. If the use case of `StatefulDataLoader` is to restart the entire program from a checkpoint, maybe it's not a big issue
I believe this has been a known "issue" in accelerate (I've seen it pop up in other issues sparingly). Agreed that it's less of an issue here, since this is pretty much just called once at the start of training. As long as we capture the state properly (which your tests check!), it's a different bug/issue to solve.
> Imo this might be better in a separate PR, once the code is checked in?
We tend to like feature-complete PRs that include doc updates. It's less likely to be forgotten about, and it's done all at once so users who want the bleeding edge can read immediately :)
> I believe this has been a known "issue" in accelerate (I've seen it pop up in other issues sparingly).
Well, I have no idea how to solve such a problem in python. In the C++ world this is what destructors and RAII are for, I guess.
> We tend to like feature-complete PRs that include doc updates. It's less likely to be forgotten about, and it's done all at once so users who want the bleeding edge can read immediately :)
Sure, added a footnote in the docs about this feature.
Also, since this feature is now stable in torchdata, I added a requirement for `torchdata>=0.8.0` in setup.py.
I see. If Zach is fine with the proposed solution, then we're good.
sgtm
There is the `__del__` magic method in Python, but let's not touch it.
I see. Destructors in Python don't seem very reliable, but my knowledge of the Python memory model isn't great.
Thanks!
I didn't find anything big, but a few minor things that could be improved.
I fixed the nits above, but I also made one, maybe important, change, done in https://github.com/huggingface/accelerate/pull/2895/commits/74e2f53d841f5701d476f3fe5f6df8f97ad82e5c
Basically, I literally realized just now that I have been delegating the work of iteration to the superclass instead of the backing dataloader. That felt wrong, so I made the commit above.
To confirm, replacing `super()` with `self.base_dataloader` here is the sensible thing to do, right? I want to fully delegate everything I can to the `base_dataloader`, and the only reason it worked as written before is because of `__getattr__` spaghetti?
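In other words, the commit changes the iteration delegation roughly like this (a sketch, not the literal diff):

```python
class DataLoaderAdapter:
    def __init__(self, base_dataloader):
        self.base_dataloader = base_dataloader

    def __iter__(self):
        # Before: `yield from super().__iter__()`, which only worked via
        # the __getattr__ forwarding. Now the backing dataloader is
        # iterated explicitly.
        yield from self.base_dataloader
```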
As a next step, I'll work on getting this working with `Accelerator.save_state`/`Accelerator.load_state` today.