Generate: Add new decoding strategy "DoLa" in `.generate()`
What does this PR do?
Fixes #29524
We add support for DoLa, a new decoding strategy proposed in a recent ICLR 2024 paper. The main changes are in src/transformers/generation/utils.py and src/transformers/generation/configuration_utils.py.
We also update the documentation and add test code. Run the test with:
CUDA_VISIBLE_DEVICES=0 python examples/pytorch/text-generation/run_generation_dola.py --model_name_or_path huggyllama/llama-7b --model_type llama --dola_layers 'low'
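For readers unfamiliar with the idea, here is a minimal sketch of layer contrast on toy numbers. It is a simplification based on our reading of the paper, not the code in this PR: it contrasts the final layer against a single fixed earlier layer and skips the selection of which earlier layer to use. The names `contrast_layers` and `alpha`, the default `alpha=0.1`, and the toy logits are all assumptions made for illustration.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def contrast_layers(final_logits, early_logits, alpha=0.1):
    """Score each token by log p_final - log p_early (hypothetical helper).

    Tokens whose final-layer probability is below alpha * (max final-layer
    probability) are masked to -inf so that implausible tokens cannot win
    just because the early layer disliked them even more.
    """
    final_lp = log_softmax(final_logits)
    early_lp = log_softmax(early_logits)
    cutoff = max(final_lp) + math.log(alpha)
    return [
        f - e if f >= cutoff else float("-inf")
        for f, e in zip(final_lp, early_lp)
    ]


# Toy 3-token vocabulary (made-up numbers, not model outputs).
final_logits = [2.0, 1.9, -3.0]
early_logits = [2.0, 0.5, -3.0]

scores = contrast_layers(final_logits, early_logits)
greedy_choice = max(range(len(final_logits)), key=final_logits.__getitem__)
contrast_choice = max(range(len(scores)), key=scores.__getitem__)
```

In this toy case plain greedy decoding picks token 0, while the contrast picks token 1, whose probability grew the most between the early and the final layer. Token 2 is masked out by the plausibility cutoff.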
Before submitting
- [x] Did you read the contributor guideline, Pull Request section?
- [x] Was this discussed/approved via a Github issue or the forum? Yes, in #29524
- [x] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [x] Did you write any new necessary tests? Yes, in examples/pytorch/text-generation/run_generation_dola.py
Who can review?
@gante is the main contributor to the .generate() function, which this PR focuses on.
Hi @gante !
Thanks so much for your suggestions! I spent some time adding test cases and fixed the issues you mentioned. All the CI checks pass as well. Could you take a look at my latest commits?
Please let me know if you have any other concerns or suggestions for me to fix! I would be happy to address any of the issues you may have! 🤗
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Hi @gante !
Thanks so much for your great suggestions! I have fixed all the issues you mentioned. Just let me know if you have any other concerns or suggestions! Thanks for requesting a review from the core maintainer! 🤗
Hi @gante !
While waiting for the core maintainer's approval, I found that validation of the parameter ranges in the generation config mainly happens in src/transformers/generation/configuration_utils.py rather than in src/transformers/generation/utils.py. So I moved the repetition penalty warning for DoLa generation to configuration_utils.py, which also means the warning is emitted only once!
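A minimal sketch of what such a config-time check could look like, using only the standard library. This is not the code in the PR: the function name `validate_dola_config`, the `min_penalty` parameter, and its default value are hypothetical placeholders. The point is that the check runs when the config is validated instead of at every decoding step.

```python
import warnings


def validate_dola_config(dola_layers=None, repetition_penalty=1.0, min_penalty=1.2):
    """Hypothetical config-time check for DoLa generation settings.

    `min_penalty` is an assumed placeholder threshold, not a value taken
    from this thread.
    """
    if dola_layers is None:
        # DoLa is not requested, so there is nothing to validate.
        return
    if repetition_penalty is None or repetition_penalty < min_penalty:
        warnings.warn(
            f"`dola_layers` is set, but `repetition_penalty` is "
            f"{repetition_penalty}; a value of at least {min_penalty} is "
            f"suggested to limit repetition.",
            UserWarning,
        )


# Usage: the check is called once when the config is built or validated.
validate_dola_config(dola_layers="low", repetition_penalty=1.2)  # no warning
```

Because the check lives in validation rather than in the decoding loop, it fires once per validation call, however many tokens are generated afterwards.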
However, after I committed the new code, a test case for the XLM model failed, and it seems to have nothing to do with my commit. The failure seems related to #29297
I tried syncing with upstream, but that didn't solve the issue. Do you know what might be causing this test failure? Sorry for bothering you again!
Some tests failed!
============================= FAILURES SHORT STACK =============================
____________________ XLMModelTest.test_batching_equivalence ____________________
tests/test_modeling_common.py:745: in recursive_check
self.assertTrue(
E AssertionError: tensor(False) is not true : Batched and Single row outputs are not equal in XLMForQuestionAnswering for key=end_top_index. Difference=1.
FAILED tests/models/xlm/test_modeling_xlm.py::XLMModelTest::test_batching_equivalence - AssertionError: tensor(False) is not true : Batched and Single row outputs are not equal in XLMForQuestionAnswering for key=end_top_index. Difference=1.
Exited with code exit status 255
The failing test case was resolved after syncing with upstream! Please ignore my previous comment. It's ready to merge now!
Hi @amyeroberts !
This PR is ready to merge after some iterations! Would you be able to review it and give me any suggestions you have? Thanks a lot for the help! 🤗
Hi @amyeroberts !
Thanks so much for all of your great suggestions! They are very helpful and they improved my code and the test cases! I have tried my best to fix all the issues you mentioned above. Let me know if there are still concerns or suggestions so I can address them further! 🤗
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
@voidism are you intending to continue the PR? 🤗 Or do you need a hand?
Hi @gante
Sorry, I was busy with my midterms for the past few weeks 😔 so I haven't worked on this for a while... I will continue fixing the PR this week or next! Thanks for the reminder and sorry for the delay!
@voidism no worries, focus on your midterms 💪 we'll be here when you're ready to continue 🙌