
Add support for contrastive search

skavulya opened this pull request 9 months ago • 11 comments

What does this PR do?

Adds support for contrastive search for static and dynamic input shapes, as well as low-memory configs. Also adds support for GPT2DoubleHeadsModel.
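
For context, contrastive search in transformers is selected through generate() by passing both top_k and penalty_alpha. A minimal usage sketch with plain transformers (not the Gaudi-specific pipeline; the prompt and parameter values are just examples):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Here is my prompt", return_tensors="pt")
# penalty_alpha > 0 together with top_k > 1 triggers the contrastive search path
outputs = model.generate(**inputs, max_new_tokens=100, top_k=4, penalty_alpha=0.6)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))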

Fixes the tests below:

GAUDI2_CI=1 RUN_SLOW=1 python -m pytest tests/transformers/tests/models/t5/test_modeling_t5.py::T5ModelIntegrationTests::test_contrastive_search_t5 -s -v

GAUDI2_CI=1 RUN_SLOW=1 python -m pytest tests/transformers/tests/models/gpt2/test_modeling_gpt2.py -s -v -k test_contrastive_search_gpt2

GAUDI2_CI=1 RUN_SLOW=1 python -m pytest tests/transformers/tests/models/gpt2/test_modeling_gpt2.py -s -v -k test_batch_generation_2heads

GAUDI2_CI=1 RUN_SLOW=1 python -m pytest tests/transformers/tests/models/gpt2/test_modeling_gpt2.py -s -v -k test_contrastive_generate_low_memory

GAUDI2_CI=1 RUN_SLOW=1 python -m pytest tests/transformers/tests/models/gpt2/test_modeling_gpt2.py -s -v -k test_contrastive_generate_dynamic_shapes

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [ ] Did you make sure to update the documentation with your changes?
  • [ ] Did you write any new necessary tests?

skavulya avatar May 02 '24 17:05 skavulya

Can you add an example using run_generation.py with the same input/output config, one with greedy and one with contrastive search, and post the perf and the sentences generated?

Just want to document the perf difference between greedy and contrastive, and make sure the sentence quality is different from greedy.

@ssarkar2 I'll add the contrastive search arguments to run_generation.py. Would you like me to add the examples of the generated text to the README or paste the output here in the comments?

skavulya avatar May 08 '24 16:05 skavulya

I think you just posted a sample output in the discussion here; that should be good enough for documenting it.

ssarkar2 avatar May 14 '24 05:05 ssarkar2

@ssarkar2 @libinta @jiminha @regisss I have updated the code to use transformers v4.40. Please review and let me know if there are additional changes you would like me to make.

skavulya avatar Jun 04 '24 21:06 skavulya

@skavulya can you rebase?

libinta avatar Jun 25 '24 01:06 libinta

@libinta @ssarkar2 I rebased the PR. I also updated contrastive search to work with bucketing.

skavulya avatar Jul 01 '24 19:07 skavulya

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@skavulya When running `make style`, ruff found 2 errors:

optimum/habana/transformers/generation/utils.py:1958:24: F821 Undefined name `prev_idx`
     |
1956 |                 if model_kwargs.get("token_idx_cpu") <= (model_kwargs["kv_cache_len"] // bucket_size) * bucket_size:
1957 |                     idx = (model_kwargs.get("token_idx_cpu") - 1) // bucket_size
1958 |                     if prev_idx != idx:
     |                        ^^^^^^^^ F821
1959 |                         model_kwargs["cache_idx"] = (idx + 1) * bucket_size
1960 |                         prev_idx = idx
     |

optimum/habana/transformers/generation/utils.py:1960:25: F841 Local variable `prev_idx` is assigned to but never used
     |
1958 |                     if prev_idx != idx:
1959 |                         model_kwargs["cache_idx"] = (idx + 1) * bucket_size
1960 |                         prev_idx = idx
     |                         ^^^^^^^^ F841
1961 |                 else:
1962 |                     model_kwargs["cache_idx"] = model_kwargs["kv_cache_len"]

I'm not sure exactly how to fix them; what is the purpose of `prev_idx`?

regisss avatar Jul 11 '24 17:07 regisss

Thanks @regisss, I fixed the issue with `prev_idx`. It is used for bucketing: it tracks the previously used bucket index so the KV-cache window only grows when decoding crosses a bucket boundary.
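
For reference, one plausible shape of the fix, sketched from the bucketing logic in the ruff output above (the sentinel value and surrounding loop are illustrative, not the PR's exact code):

# Initialize once before the decoding loop so the name is always defined.
prev_idx = -1  # hypothetical sentinel: no bucket selected yet
while not generation_done:  # stand-in for the actual token-generation loop
    if model_kwargs.get("token_idx_cpu") <= (model_kwargs["kv_cache_len"] // bucket_size) * bucket_size:
        idx = (model_kwargs.get("token_idx_cpu") - 1) // bucket_size
        if prev_idx != idx:
            # Only grow the KV-cache window when decoding crosses a bucket boundary
            model_kwargs["cache_idx"] = (idx + 1) * bucket_size
            prev_idx = idx
    else:
        model_kwargs["cache_idx"] = model_kwargs["kv_cache_len"]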

skavulya avatar Jul 16 '24 00:07 skavulya

Thanks @regisss, I added the test for contrastive search.

@regisss @ssarkar2 When adding the test, I noticed that the throughput is much lower than A100 for gpt2 with top_k=4 and penalty_alpha=0.5: 154.84 tokens/sec on A100 vs. 98.33 tokens/sec on Gaudi2. I have narrowed the bottleneck down to the following code and need guidance on how to optimize these sections:

https://github.com/skavulya/optimum-habana/blob/48f925061fff14ac3b6b2707fe9c782122f7e000/optimum/habana/transformers/generation/utils.py#L1852-L1855

next_decoder_hidden_states = ()
for layer in full_hidden_states:
    # Regroup the flattened (batch_size * top_k) candidate dimension, then keep
    # only the hidden states of the selected candidate for each batch element.
    layer = torch.stack(torch.split(layer, top_k))[range(batch_size), selected_idx, :]
    next_decoder_hidden_states += (layer,)

There is a similar section of code here: https://github.com/skavulya/optimum-habana/blob/48f925061fff14ac3b6b2707fe9c782122f7e000/optimum/habana/transformers/generation/utils.py#L1877-L1925

skavulya avatar Jul 22 '24 18:07 skavulya

@skavulya can you please fix code style?

yeonsily avatar Jul 25 '24 18:07 yeonsily

@yeonsily I fixed the code style.

@regisss @ssarkar2 I improved the performance by creating the range indices directly on HPU. gpt2 throughput improved from 98 tokens/sec to 173 tokens/sec for the following command:

python run_generation.py --model_name_or_path gpt2 --use_hpu_graphs --use_kv_cache --max_new_tokens 100 --top_k 4 --penalty_alpha 0.5 --prompt "Here is my prompt"
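
For reference, a minimal sketch of that kind of change, based on the indexing pattern from the earlier snippet (the helper name and toy shapes are illustrative, not the PR's exact code):

import torch

def select_candidate_hidden_states(layer, selected_idx, top_k, batch_size):
    # Build the batch indices as a tensor on the same device as the activations
    # (HPU on Gaudi) instead of a host-side Python range, so the advanced
    # indexing stays on-device and avoids a host round-trip per decoding step.
    batch_idx = torch.arange(batch_size, device=layer.device)
    return torch.stack(torch.split(layer, top_k))[batch_idx, selected_idx, :]

# Toy usage on CPU; on Gaudi the tensors would live on the HPU device.
layer = torch.randn(2 * 4, 5, 8)      # (batch_size * top_k, seq_len, hidden)
selected_idx = torch.tensor([1, 3])   # chosen candidate per batch element
print(select_candidate_hidden_states(layer, selected_idx, top_k=4, batch_size=2).shape)
# torch.Size([2, 5, 8])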

skavulya avatar Jul 26 '24 18:07 skavulya

@regisss I saw the PR for the upgrade to transformers v4.43 (https://github.com/huggingface/optimum-habana/pull/1163). Would you like me to upgrade this PR too?

skavulya avatar Aug 02 '24 17:08 skavulya

That would be great if you have time to do it! Otherwise I'll try to do it in the next few days.

regisss avatar Aug 02 '24 21:08 regisss

Thanks @regisss!

skavulya avatar Aug 07 '24 19:08 skavulya