fix QAT version dependency
Context
What is the purpose of this PR? Is it to
- [ ] add a new feature
- [x] fix a bug
- [ ] update tests and/or documentation
- [ ] other (please add here)
We updated torchtune to use torchao 0.4.0, which breaks unless the user has PyTorch 2.4.0. In our scripts, we were using import guards:
https://github.com/felipemello1/torchtune/blob/04ccbf2601653e0e2cceb75e59394df5517d26e3/torchtune/utils/quantization.py#L12
However, "TORCH_VERSION_AFTER_2_4" did not actually include 2.4. This was fixed in torchao here: https://github.com/pytorch/ao/pull/684, but the fix won't be available to us until their next release.
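To make the off-by-one concrete, here is a minimal sketch of a version guard. It is hypothetical and not torchao's implementation: `parse_version`, `version_after` and `version_at_least` are made-up names, and the parser keeps only leading numeric components (it drops `+cu121`-style local suffixes and dev/rc tags, which real tools such as `packaging.version` treat differently). A strict `>` comparison excludes 2.4.0 itself, while `>=` includes it.

```python
from __future__ import annotations

import re


def parse_version(v: str) -> tuple[int, ...]:
    """Keep the leading numeric release components: "2.4.0+cu121" -> (2, 4, 0).

    Simplification (assumption): dev/rc tags are dropped, so "2.4.0.dev20240801"
    parses the same as the final "2.4.0" release.
    """
    parts = []
    for piece in v.split("+", 1)[0].split("."):
        m = re.match(r"\d+", piece)
        if m is None:  # e.g. "dev20240801": stop at the first non-numeric piece
            break
        parts.append(int(m.group(0)))
        if m.end() != len(piece):  # e.g. "0rc1": keep the 0, then stop at the tag
            break
    return tuple(parts)


def _padded(a: str, b: str) -> tuple[tuple[int, ...], tuple[int, ...]]:
    """Pad both versions with zeros so (2, 4) compares equal to (2, 4, 0)."""
    x, y = parse_version(a), parse_version(b)
    n = max(len(x), len(y))
    return x + (0,) * (n - len(x)), y + (0,) * (n - len(y))


def version_after(current: str, target: str) -> bool:
    """Strict guard: returns False when current == target (the off-by-one)."""
    x, y = _padded(current, target)
    return x > y


def version_at_least(current: str, target: str) -> bool:
    """Inclusive guard: returns True when current == target."""
    x, y = _padded(current, target)
    return x >= y


print(version_after("2.4.0", "2.4"))     # False: 2.4.0 itself is excluded
print(version_at_least("2.4.0", "2.4"))  # True
```

An "after 2.4" guard meant to enable code paths on PyTorch 2.4.0 therefore needs the inclusive form.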
After updating TorchAO and the import guards, another error was raised:
[rank0]: File "/home/felipemello/.conda/envs/test_ao/lib/python3.10/site-packages/torchao/quantization/prototype/qat/utils.py", line 42, in forward
[rank0]: assert input.dtype == torch.float32
This is because torchao's QAT code now requires the model to be in float32. More context here: https://github.com/pytorch/ao/blob/0b66ff01ab6ba4094823b8cb134ab5b5a744d73a/torchao/quantization/prototype/qat/utils.py#L39
Changing the QAT recipe configs to use dtype = fp32 solved it.
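While the float32 assertion is in place, a recipe could fail fast with a clearer message than the `assert` deep inside torchao. This is a hypothetical sketch: `check_qat_dtype` is not a torchtune function, and it assumes the recipe config has been loaded into a plain dict whose `dtype` key holds a string such as `"fp32"` or `"bf16"`.

```python
from typing import Any, Mapping

# Per the traceback above, torchao 0.4.0's QAT path asserts float32 inputs.
# Assumption: the config spells float32 as "fp32".
_QAT_REQUIRED_DTYPE = "fp32"


def check_qat_dtype(cfg: Mapping[str, Any]) -> None:
    """Raise early if a QAT config requests a dtype other than fp32 (hypothetical helper)."""
    dtype = cfg.get("dtype")
    if dtype != _QAT_REQUIRED_DTYPE:
        raise ValueError(
            f"QAT with torchao 0.4.0 requires dtype={_QAT_REQUIRED_DTYPE!r}, got {dtype!r}; "
            "set `dtype: fp32` in the QAT config."
        )


check_qat_dtype({"dtype": "fp32"})  # passes silently
try:
    check_qat_dtype({"dtype": "bf16"})
except ValueError as err:
    print(err)
```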
Changelog
- Update to torchao==0.4.0
- Remove the pin on NumPy (this is unrelated to this PR, but it was something we needed to do anyway, so it made sense to test everything together: https://github.com/pytorch/torchtune/issues/1344)
- Temporarily change the import guards. They MUST be updated with the next torchao release. (Should I add an assertion that checks torchao.__version__ <= 0.4.0?)
- QAT configs use fp32
Test plan
I was able to run the command below, but I did not compare the results against the previous version.
tune run --nproc_per_node 8 qat_distributed --config llama3/8B_qat_full
:link: Helpful Links
:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1333
:white_check_mark: No Failures
As of commit 04ccbf2601653e0e2cceb75e59394df5517d26e3 with merge base 6a7951f1cdd0b56a9746ef5935106989415f50e3:
:green_heart: Looks good so far! There are no failures yet. :green_heart:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
cc @msaroufim
@felipemello1 Would you mind just also double checking our version guards for AO? We'll need to be extra careful around this now that we're relaxing this pin.
CI will only catch our tests, which are a subset of how our library is used.
> Would you mind just also double checking our version guards for AO?
Not sure what you mean. Could you share an example or a link showing what you would like me to do?
This change allows stable builds of torchtune to break in the future. If there is a stable torchtune package that works fine with torchao, and torchao then releases a new stable package with BC-breaking changes, our existing stable packages would try to install the new torchao package and break. We need to keep torchao pinned and then use a tool like Dependabot to keep the pinned version up to date.
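The difference between a pin and a lower bound can be shown with a toy specifier check. This is a simplified, hypothetical sketch, not pip's resolver or the `packaging.specifiers` API: `satisfies` supports only a single `==` or `>=` specifier, compares major/minor/patch, and `0.5.0` is used purely as an illustrative future release.

```python
from __future__ import annotations

import re


def _release(v: str) -> tuple[int, ...]:
    """Major/minor/patch only, missing parts as 0: "0.4" -> (0, 4, 0)."""
    m = re.match(r"(\d+)(?:\.(\d+))?(?:\.(\d+))?", v)
    if m is None:
        raise ValueError(f"unrecognized version: {v!r}")
    return tuple(int(g or 0) for g in m.groups())


def satisfies(version: str, spec: str) -> bool:
    """Toy check for a single "==X" or ">=X" specifier (hypothetical helper)."""
    for op in ("==", ">="):
        if spec.startswith(op):
            current, target = _release(version), _release(spec[len(op):])
            return current == target if op == "==" else current >= target
    raise ValueError(f"unsupported specifier: {spec!r}")


# A future, possibly BC-breaking torchao release (illustrative version number).
future = "0.5.0"
print(satisfies(future, ">=0.4.0"))  # True: an unpinned lower bound lets it in
print(satisfies(future, "==0.4.0"))  # False: the pin keeps the tested version
```

With the exact pin, an already-published torchtune release keeps resolving to the torchao version it was tested against; bumping the pin becomes a deliberate change that a bot like Dependabot can propose.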
For CI, we should decide whether we want to pin to the latest version of PyTorch and possibly have separate tests for PyTorch nightlies with unpinned PyTorch libraries. @ebsmothers
> If there is a stable package for torchtune that works fine with torchao, and then torchao releases a new stable package with bc breaking changes, our existing stable packages
TLDR: just update to 0.4
When users pip install torchtune, the AO version should be pinned so the official release packages are guaranteed to work. If users then choose to upgrade AO, there is no guarantee things will work (same as with PyTorch), but we'll try our best not to break things for no good reason.
In torchtune CI, you should always be testing against all your latest stable dependencies and all your latest nightly dependencies. BC issues should never be caught at release time but at nightly CI time; that way, upgrading to a stable release can be a safe activity. Personally, I wouldn't wait more than a few days after an official AO release to upgrade.
@felipemello1 How urgent is upgrading torchao? I will submit a fix in torchao itself to remove that assertion, but I don't think we want to change the default precision in the QAT recipes. We have another release planned in early September, but if that's too late for you, maybe we can do a 0.4.1 release with the fix?
https://github.com/pytorch/ao/pull/692
> We have another release planned in early September, but if that's too late for you maybe we can do a 0.4.1 release with the fix?
Thanks for the fix @andrewor14!
If making a release is not a huge effort, it would solve multiple problems: our regression error, the import version guards, and the dtype. It would be convenient. However, I don't think we have a huge number of users using QAT, so waiting for September wouldn't be terrible.
In summary, if making the release is easy, that would be very neat. But if it's going to take you a considerable amount of time, we can wait 2 weeks.
Is this on hold until the next torchao release, then? And if so, are we just going to bump to 0.5.0? If so, let's make sure our nightly CI is green before that release.
> Is this on hold until next torchao release then?
That's my understanding.