optimum-habana
Fix decoder-only generation
What does this PR do?
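Fixes several issues in decoder-only generation. The test run below selects `test_generate_from_inputs_embeds_decoder_only` across the model test suites, and all 8 selected tests pass.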
root@idc705326-7:/optimum-habana# GAUDI2_CI=1 RUN_SLOW=1 python -m pytest tests/transformers/tests/models/ -k test_generate_from_inputs_embeds_decoder_only
============================================================================================================== test session starts ==============================================================================================================
platform linux -- Python 3.10.12, pytest-7.4.4, pluggy-1.4.0
rootdir: /optimum-habana
configfile: setup.cfg
collected 1218 items / 1210 deselected / 8 selected
tests/transformers/tests/models/bert/test_modeling_bert.py . [ 12%]
tests/transformers/tests/models/falcon/test_modeling_falcon.py . [ 25%]
tests/transformers/tests/models/gpt2/test_modeling_gpt2.py . [ 37%]
tests/transformers/tests/models/gpt_neox/test_modeling_gpt_neox.py . [ 50%]
tests/transformers/tests/models/gptj/test_modeling_gptj.py . [ 62%]
tests/transformers/tests/models/llama/test_modeling_llama.py . [ 75%]
tests/transformers/tests/models/roberta/test_modeling_roberta.py . [ 87%]
tests/transformers/tests/models/t5/test_modeling_t5.py . [100%]
=============================================================================================================== warnings summary ================================================================================================================
../usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py:462
/usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py:462: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
_torch_pytree._register_pytree_node(
../usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py:319
../usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py:319
/usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py:319: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
_torch_pytree._register_pytree_node(
tests/transformers/tests/test_modeling_common.py:2044
/optimum-habana/tests/transformers/tests/test_modeling_common.py:2044: PytestUnknownMarkWarning: Unknown pytest.mark.accelerate_tests - is this a typo? You can register custom marks to avoid this warning - for details, see https://docs.pytest.org/en/stable/how-to/mark.html
@mark.accelerate_tests
tests/transformers/tests/test_modeling_common.py:2082
/optimum-habana/tests/transformers/tests/test_modeling_common.py:2082: PytestUnknownMarkWarning: Unknown pytest.mark.accelerate_tests - is this a typo? You can register custom marks to avoid this warning - for details, see https://docs.pytest.org/en/stable/how-to/mark.html
@mark.accelerate_tests
tests/transformers/tests/test_modeling_common.py:2118
/optimum-habana/tests/transformers/tests/test_modeling_common.py:2118: PytestUnknownMarkWarning: Unknown pytest.mark.accelerate_tests - is this a typo? You can register custom marks to avoid this warning - for details, see https://docs.pytest.org/en/stable/how-to/mark.html
@mark.accelerate_tests
tests/transformers/tests/models/gpt2/test_modeling_gpt2.py::GPT2ModelTest::test_generate_from_inputs_embeds_decoder_only
tests/transformers/tests/models/gpt_neox/test_modeling_gpt_neox.py::GPTNeoXModelTest::test_generate_from_inputs_embeds_decoder_only
tests/transformers/tests/models/gptj/test_modeling_gptj.py::GPTJModelTest::test_generate_from_inputs_embeds_decoder_only
tests/transformers/tests/models/llama/test_modeling_llama.py::LlamaModelTest::test_generate_from_inputs_embeds_decoder_only
/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1178: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.
warnings.warn(
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=============================================================================================== 8 passed, 1210 deselected, 10 warnings in 11.21s ================================================================================================
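For reviewers less familiar with this test: as we understand the upstream transformers test `test_generate_from_inputs_embeds_decoder_only`, it checks that a decoder-only model produces the same tokens whether generation starts from `input_ids` or from the embeddings of those ids passed as `inputs_embeds`, and that different embeddings change the output. The sketch below illustrates that invariant with a hypothetical pure-Python toy model (`EMBEDDINGS`, `embed`, `next_token` and `greedy_generate` are made up for illustration). It is not the transformers or optimum-habana implementation and does not reproduce the specific fixes in this PR.

```python
from typing import List, Optional, Sequence, Tuple

Vector = Tuple[float, float]

# Hypothetical toy embedding table: token id -> 2-d vector (illustration only).
EMBEDDINGS: List[Vector] = [(1.0, 0.0), (0.0, 1.0), (1.0, -1.0), (-0.5, 0.5)]


def embed(ids: Sequence[int]) -> List[Vector]:
    """Look up the toy embedding for each token id."""
    return [EMBEDDINGS[i] for i in ids]


def next_token(context: Sequence[Vector]) -> int:
    """Toy 'decoder': greedy argmax over scores computed from the summed context."""
    sx = sum(v[0] for v in context)
    sy = sum(v[1] for v in context)
    rotated = (-sy, sx)  # rotate the summed context by 90 degrees so the output varies
    scores = [
        rotated[0] * e[0] + rotated[1] * e[1] - 0.01 * tok  # small id penalty breaks ties
        for tok, e in enumerate(EMBEDDINGS)
    ]
    return max(range(len(scores)), key=scores.__getitem__)


def greedy_generate(
    input_ids: Optional[Sequence[int]] = None,
    inputs_embeds: Optional[Sequence[Vector]] = None,
    max_new_tokens: int = 5,
) -> List[int]:
    """Greedy decoding that starts from inputs_embeds when given, else from embed(input_ids).

    Returns the prompt ids (if any) followed by the newly generated ids.
    """
    if inputs_embeds is None:
        if input_ids is None:
            raise ValueError("pass input_ids or inputs_embeds")
        inputs_embeds = embed(input_ids)
    context = list(inputs_embeds)
    new_tokens: List[int] = []
    for _ in range(max_new_tokens):
        tok = next_token(context)
        new_tokens.append(tok)
        context.append(EMBEDDINGS[tok])  # feed the chosen token back in as an embedding
    prompt = list(input_ids) if input_ids is not None else []
    return prompt + new_tokens


if __name__ == "__main__":
    prompt = [0, 2]
    from_ids = greedy_generate(input_ids=prompt)
    from_embeds = greedy_generate(input_ids=prompt, inputs_embeds=embed(prompt))
    embeds_only = greedy_generate(inputs_embeds=embed(prompt))
    print(from_ids)  # [0, 2, 1, 1, 1, 1, 3]
    assert from_ids == from_embeds  # same prompt, same tokens
    assert from_ids[len(prompt):] == embeds_only  # embeds alone yield only the new tokens
    # Different embeddings must change the output, i.e. inputs_embeds is actually used.
    assert greedy_generate(input_ids=prompt, inputs_embeds=embed([1, 1])) != from_ids
```

A generation loop that silently re-embedded `input_ids` and ignored `inputs_embeds` would fail the last check, which is one kind of bug a test like this can catch.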
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you make sure to update the documentation with your changes?
- [ ] Did you write any new necessary tests?