
Fix decoder only generation

Open · tjs-intel opened this pull request 9 months ago • 1 comment

What does this PR do?

Fixes decoder-only generation: multiple issues were found and corrected.
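To illustrate the scenario this PR targets, here is a minimal, hypothetical sketch of decoder-only generation driven by input embeddings rather than token ids. The `toy_model` and helper names are illustrative only, not the actual optimum-habana or transformers code; the point is that when generation starts from `inputs_embeds`, the loop must track generated ids itself, since there are no prompt `input_ids` to extend.

```python
# Hypothetical sketch of generation from inputs_embeds (toy code, not the
# real optimum-habana implementation).
def toy_model(embeds):
    # Stand-in for a transformer forward pass: "predicts" the sequence
    # length as the next token id, just to make the loop deterministic.
    return len(embeds)

def generate_from_embeds(inputs_embeds, max_new_tokens=3):
    """Greedy decode loop starting from embeddings: generated ids are
    accumulated separately because there are no prompt input_ids."""
    generated = []
    embeds = list(inputs_embeds)
    for _ in range(max_new_tokens):
        next_id = toy_model(embeds)
        generated.append(next_id)
        embeds.append([float(next_id)])  # "embed" the new token (toy)
    return generated
```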

root@idc705326-7:/optimum-habana# GAUDI2_CI=1  RUN_SLOW=1 python -m pytest tests/transformers/tests/models/ -k test_generate_from_inputs_embeds_decoder_only
============================================================================================================== test session starts ==============================================================================================================
platform linux -- Python 3.10.12, pytest-7.4.4, pluggy-1.4.0
rootdir: /optimum-habana
configfile: setup.cfg
collected 1218 items / 1210 deselected / 8 selected

tests/transformers/tests/models/bert/test_modeling_bert.py .                                                                                                                                                                              [ 12%]
tests/transformers/tests/models/falcon/test_modeling_falcon.py .                                                                                                                                                                          [ 25%]
tests/transformers/tests/models/gpt2/test_modeling_gpt2.py .                                                                                                                                                                              [ 37%]
tests/transformers/tests/models/gpt_neox/test_modeling_gpt_neox.py .                                                                                                                                                                      [ 50%]
tests/transformers/tests/models/gptj/test_modeling_gptj.py .                                                                                                                                                                              [ 62%]
tests/transformers/tests/models/llama/test_modeling_llama.py .                                                                                                                                                                            [ 75%]
tests/transformers/tests/models/roberta/test_modeling_roberta.py .                                                                                                                                                                        [ 87%]
tests/transformers/tests/models/t5/test_modeling_t5.py .                                                                                                                                                                                  [100%]

=============================================================================================================== warnings summary ================================================================================================================
../usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py:462
  /usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py:462: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
    _torch_pytree._register_pytree_node(

../usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py:319
../usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py:319
  /usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py:319: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
    _torch_pytree._register_pytree_node(

tests/transformers/tests/test_modeling_common.py:2044
  /optimum-habana/tests/transformers/tests/test_modeling_common.py:2044: PytestUnknownMarkWarning: Unknown pytest.mark.accelerate_tests - is this a typo?  You can register custom marks to avoid this warning - for details, see https://docs.pytest.org/en/stable/how-to/mark.html
    @mark.accelerate_tests

tests/transformers/tests/test_modeling_common.py:2082
  /optimum-habana/tests/transformers/tests/test_modeling_common.py:2082: PytestUnknownMarkWarning: Unknown pytest.mark.accelerate_tests - is this a typo?  You can register custom marks to avoid this warning - for details, see https://docs.pytest.org/en/stable/how-to/mark.html
    @mark.accelerate_tests

tests/transformers/tests/test_modeling_common.py:2118
  /optimum-habana/tests/transformers/tests/test_modeling_common.py:2118: PytestUnknownMarkWarning: Unknown pytest.mark.accelerate_tests - is this a typo?  You can register custom marks to avoid this warning - for details, see https://docs.pytest.org/en/stable/how-to/mark.html
    @mark.accelerate_tests

tests/transformers/tests/models/gpt2/test_modeling_gpt2.py::GPT2ModelTest::test_generate_from_inputs_embeds_decoder_only
tests/transformers/tests/models/gpt_neox/test_modeling_gpt_neox.py::GPTNeoXModelTest::test_generate_from_inputs_embeds_decoder_only
tests/transformers/tests/models/gptj/test_modeling_gptj.py::GPTJModelTest::test_generate_from_inputs_embeds_decoder_only
tests/transformers/tests/models/llama/test_modeling_llama.py::LlamaModelTest::test_generate_from_inputs_embeds_decoder_only
  /usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1178: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=============================================================================================== 8 passed, 1210 deselected, 10 warnings in 11.21s ================================================================================================

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [ ] Did you make sure to update the documentation with your changes?
  • [ ] Did you write any new necessary tests?

tjs-intel avatar May 03 '24 18:05 tjs-intel

@regisss I think too many reviewers were tagged and this got lost in the weeds. Can I please have this PR reviewed?

tjs-intel avatar May 28 '24 16:05 tjs-intel

Please sync your PR with main/upstream and fix any merge conflicts. Thank you.

emascarenhas avatar Sep 03 '24 14:09 emascarenhas

@ssarkar2 @regisss Please review this fix for inputs_embeds with static shapes.
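For context on "static shapes": HPU graphs on Gaudi are compiled for fixed tensor shapes, so variable-length `inputs_embeds` are typically right-padded to a fixed length with a matching attention mask. The sketch below shows the padding idea only; the function name and layout are assumptions, not the actual optimum-habana helpers.

```python
# Hypothetical sketch of padding inputs_embeds to a static shape
# (illustrative names, not the actual optimum-habana code).
def pad_to_static_shape(inputs_embeds, max_len, pad_value=0.0):
    """Right-pad a batch of embedding sequences to a fixed length and
    build the matching attention mask (1 = real token, 0 = padding)."""
    hidden = len(inputs_embeds[0][0])
    padded, mask = [], []
    for seq in inputs_embeds:
        n_pad = max_len - len(seq)
        padded.append(seq + [[pad_value] * hidden] * n_pad)
        mask.append([1] * len(seq) + [0] * n_pad)
    return padded, mask
```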

skavulya avatar Sep 03 '24 21:09 skavulya

@skavulya

  • If the PR is ready for review, I can start on that. Please ping me here.
  • Please run GAUDI2_CI=1 RUN_SLOW=true python -m pytest tests/transformers/tests/models/ -s -v before and after your changes and make sure no new failures are introduced.
  • We need to run more CI to make sure everything is working as expected.

yafshar avatar Sep 05 '24 14:09 yafshar

Thanks @yafshar. It is ready for review. I'll run the tests to confirm no regressions

skavulya avatar Sep 05 '24 15:09 skavulya

@yafshar Here are the test results for GAUDI2_CI=1 RUN_SLOW=true python -m pytest tests/transformers/tests/models/ -s -v:

Main: 26 failed, 873 passed, 326 skipped, 1 xpassed, 58 warnings in 1639.18s (0:27:19)
This PR: 21 failed, 878 passed, 326 skipped, 1 xpassed, 58 warnings in 1330.34s (0:22:10)

skavulya avatar Sep 05 '24 17:09 skavulya

@yafshar Here are the results of additional tests:

make slow_tests_text_generation_example
Main: 1 failed, 25 passed, 4 skipped
This PR: 1 failed, 25 passed, 4 skipped

Perf test to check that text generation with input_ids is not impacted by the fix:

python run_generation.py --model_name_or_path openai-community/gpt2 --use_kv_cache --max_new_tokens 128 --max_input_tokens=128 --batch_size 4 --use_hpu_graph --bf16 --no-ignore_eos

Tokens/sec for 3 runs:
Main: 2186.605, 2222.230, 2201.932
This PR: 2222.338, 2240.891, 2237.272
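Throughput comparisons like the tokens/sec figures above boil down to total generated tokens divided by wall time. A generic sketch of that measurement (not the internals of run_generation.py; the function name is made up):

```python
import time

# Hypothetical throughput helper, not taken from run_generation.py.
def measure_tokens_per_sec(generate_fn, batch_size, max_new_tokens):
    """Run one generation call and report tokens/sec as
    (batch_size * max_new_tokens) / elapsed wall time."""
    start = time.perf_counter()
    generate_fn()
    elapsed = time.perf_counter() - start
    return (batch_size * max_new_tokens) / elapsed
```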

skavulya avatar Sep 05 '24 21:09 skavulya

@skavulya please ignore my last comment if you do not like it; it is a personal style preference, and I do not like the extra conditional check. But please do fix the first comment, the one about .all.

yafshar avatar Sep 05 '24 22:09 yafshar

Thanks @yafshar. I'll update the code. Your suggestions make it cleaner.

skavulya avatar Sep 06 '24 16:09 skavulya

@skavulya thanks, everything sounds good to me. Would you please add a final check to make sure we do not miss anything: run GAUDI2_CI=1 RUN_SLOW=true python -m pytest tests/transformers/tests/models/ -s -v before and after your changes and make sure no new failures are introduced.

yafshar avatar Sep 09 '24 21:09 yafshar

Thanks @yafshar. I rebased and re-ran the tests.

Main: 25 failed, 874 passed, 326 skipped, 1 xpassed, 59 warnings in 1351.61s (0:22:31)
This PR: 20 failed, 879 passed, 326 skipped, 1 xpassed, 59 warnings in 1474.58s (0:24:34)

skavulya avatar Sep 10 '24 20:09 skavulya

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.