Make sure to disable gradients for integer tensors
What does this PR do?
When doing 8-bit LoRA with DeepSpeed ZeRO-3, `Int8Params` should have `requires_grad=False`; otherwise the following error occurs:
```
[rank0]: File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/bitsandbytes/nn/modules.py", line 574, in __new__
[rank0]:   obj = torch.Tensor._make_subclass(cls, data, requires_grad)
[rank0]:         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: Only Tensors of floating point and complex dtype can require gradients
```
This used to work several months ago.
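For context on the error: PyTorch autograd only supports floating point and complex dtypes, so `torch.Tensor._make_subclass` (which `Int8Params.__new__` calls, per the traceback above) fails when handed int8 data with `requires_grad=True`. A minimal sketch of the constraint using plain `torch.nn.Parameter`, which goes through the same `_make_subclass` path (not the exact patch in this PR):

```python
import torch

int8_weight = torch.zeros(4, 4, dtype=torch.int8)

# Raises: "Only Tensors of floating point and complex dtype can require
# gradients", because autograd cannot track integer tensors.
try:
    torch.nn.Parameter(int8_weight, requires_grad=True)
except RuntimeError as e:
    print(e)

# Integer parameters must therefore be constructed with requires_grad=False.
frozen = torch.nn.Parameter(int8_weight, requires_grad=False)
assert not frozen.requires_grad
```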
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the contributor guideline, Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [ ] Did you write any new necessary tests?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
cc @ArthurZucker for final :)
@ArthurZucker @muellerzr I rebased this against the latest transformers and tested it on my end.