optimum-habana
Fix decoder only generation
What does this PR do?
Fixes issues with decoder-only generation. Multiple issues were found and corrected.
root@idc705326-7:/optimum-habana# GAUDI2_CI=1 RUN_SLOW=1 python -m pytest tests/transformers/tests/models/ -k test_generate_from_inputs_embeds_decoder_only
============================================================================================================== test session starts ==============================================================================================================
platform linux -- Python 3.10.12, pytest-7.4.4, pluggy-1.4.0
rootdir: /optimum-habana
configfile: setup.cfg
collected 1218 items / 1210 deselected / 8 selected
tests/transformers/tests/models/bert/test_modeling_bert.py . [ 12%]
tests/transformers/tests/models/falcon/test_modeling_falcon.py . [ 25%]
tests/transformers/tests/models/gpt2/test_modeling_gpt2.py . [ 37%]
tests/transformers/tests/models/gpt_neox/test_modeling_gpt_neox.py . [ 50%]
tests/transformers/tests/models/gptj/test_modeling_gptj.py . [ 62%]
tests/transformers/tests/models/llama/test_modeling_llama.py . [ 75%]
tests/transformers/tests/models/roberta/test_modeling_roberta.py . [ 87%]
tests/transformers/tests/models/t5/test_modeling_t5.py . [100%]
=============================================================================================================== warnings summary ================================================================================================================
../usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py:462
/usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py:462: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
_torch_pytree._register_pytree_node(
../usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py:319
../usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py:319
/usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py:319: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
_torch_pytree._register_pytree_node(
tests/transformers/tests/test_modeling_common.py:2044
/optimum-habana/tests/transformers/tests/test_modeling_common.py:2044: PytestUnknownMarkWarning: Unknown pytest.mark.accelerate_tests - is this a typo? You can register custom marks to avoid this warning - for details, see https://docs.pytest.org/en/stable/how-to/mark.html
@mark.accelerate_tests
tests/transformers/tests/test_modeling_common.py:2082
/optimum-habana/tests/transformers/tests/test_modeling_common.py:2082: PytestUnknownMarkWarning: Unknown pytest.mark.accelerate_tests - is this a typo? You can register custom marks to avoid this warning - for details, see https://docs.pytest.org/en/stable/how-to/mark.html
@mark.accelerate_tests
tests/transformers/tests/test_modeling_common.py:2118
/optimum-habana/tests/transformers/tests/test_modeling_common.py:2118: PytestUnknownMarkWarning: Unknown pytest.mark.accelerate_tests - is this a typo? You can register custom marks to avoid this warning - for details, see https://docs.pytest.org/en/stable/how-to/mark.html
@mark.accelerate_tests
tests/transformers/tests/models/gpt2/test_modeling_gpt2.py::GPT2ModelTest::test_generate_from_inputs_embeds_decoder_only
tests/transformers/tests/models/gpt_neox/test_modeling_gpt_neox.py::GPTNeoXModelTest::test_generate_from_inputs_embeds_decoder_only
tests/transformers/tests/models/gptj/test_modeling_gptj.py::GPTJModelTest::test_generate_from_inputs_embeds_decoder_only
tests/transformers/tests/models/llama/test_modeling_llama.py::LlamaModelTest::test_generate_from_inputs_embeds_decoder_only
/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1178: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.
warnings.warn(
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=============================================================================================== 8 passed, 1210 deselected, 10 warnings in 11.21s ================================================================================================
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you make sure to update the documentation with your changes?
- [ ] Did you write any new necessary tests?
@regisss I think too many reviewers were tagged and this got lost in the weeds. Can I please have this PR reviewed?
Please sync your PR with main/upstream and fix any merge conflicts. Thank you.
@ssarkar2 @regisss Please review this fix for inputs_embeds with static shapes.
@skavulya
- If the PR is ready for review, I can start on that. Please ping me here.
- Please run
GAUDI2_CI=1 RUN_SLOW=true python -m pytest tests/transformers/tests/models/ -s -v
before and after the changes and make sure no new failures are introduced.
- We need to run more CI to make sure everything is working as expected.
Thanks @yafshar. It is ready for review. I'll run the tests to confirm there are no regressions.
@yafshar Here are the test results:
GAUDI2_CI=1 RUN_SLOW=true python -m pytest tests/transformers/tests/models/ -s -v
main: 26 failed, 873 passed, 326 skipped, 1 xpassed, 58 warnings in 1639.18s (0:27:19)
this PR: 21 failed, 878 passed, 326 skipped, 1 xpassed, 58 warnings in 1330.34s (0:22:10)
@yafshar Here are the results of additional tests
make slow_tests_text_generation_example
Main: 1 failed, 25 passed, 4 skipped
This pr: 1 failed, 25 passed, 4 skipped
Perf test to check that text generation with input_ids is not impacted by the fix:
python run_generation.py --model_name_or_path openai-community/gpt2 --use_kv_cache --max_new_tokens 128 --max_input_tokens=128 --batch_size 4 --use_hpu_graph --bf16 --no-ignore_eos
Tokens/sec for 3 runs
Main: 2186.605, 2222.230, 2201.932
This PR: 2222.338, 2240.891, 2237.272
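As a quick sanity check on those throughput numbers, a small hypothetical helper (the run values are copied from the results above) that averages the three runs of each branch and reports the ratio:

```python
# Tokens/sec over 3 runs, copied from the results above
main_runs = [2186.605, 2222.230, 2201.932]
pr_runs = [2222.338, 2240.891, 2237.272]

def mean(xs):
    return sum(xs) / len(xs)

# Ratio > 1.0 means this PR is at least as fast as main
speedup = mean(pr_runs) / mean(main_runs)
print(f"main avg: {mean(main_runs):.1f} tok/s, "
      f"PR avg: {mean(pr_runs):.1f} tok/s, ratio: {speedup:.3f}")
```

The averages confirm the fix does not regress input_ids throughput on this workload.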
@skavulya Please ignore my last comments if you do not like them; that is just personal style (I do not like the extra conditional check). However, please do fix the first comment, the one about `.all`.
Thanks @yafshar. I'll update the code. Your suggestions make it cleaner.
@skavulya Thanks, everything sounds good to me. Would you please add a final test to make sure we do not miss anything: run
GAUDI2_CI=1 RUN_SLOW=true python -m pytest tests/transformers/tests/models/ -s -v
before and after the changes and make sure no new failures are introduced.
Thanks @yafshar. I rebased and re-ran the tests.
Main: 25 failed, 874 passed, 326 skipped, 1 xpassed, 59 warnings in 1351.61s (0:22:31)
This PR: 20 failed, 879 passed, 326 skipped, 1 xpassed, 59 warnings in 1474.58s (0:24:34)
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.