Add stop sequence to text generation pipeline

KMFODA opened this pull request • 10 comments

What does this PR do?

As per the conversation in https://github.com/huggingface/transformers/issues/17562, this draft PR adds a stop_sequence option to text generation pipelines.
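
A rough sketch of the intended usage, assuming the pipeline-level argument keeps the stop_sequence name proposed here (the merged name and behaviour may differ):

from transformers import pipeline

generator = pipeline("text-generation", model="gpt2")
# Generation should halt once the stop sequence is produced.
output = generator(
    "The best thing about open source is",
    stop_sequence=".",  # argument proposed in this PR
    max_new_tokens=50,
)
print(output[0]["generated_text"])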

Fixes # (issue)

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [ ] Did you read the contributor guideline, Pull Request section?
  • [ ] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
  • [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • [ ] Did you write any new necessary tests?

Who can review?

@Narsil

Models:

All

Library:

  • text generation: @patrickvonplaten
  • pipelines: @LysandreJik

KMFODA avatar Aug 03 '22 08:08 KMFODA

The documentation is not available anymore as the PR was closed or merged.

Hey @Narsil. I've managed to get this working for greedy decoding and multinomial sampling. For beam search, what would be the best approach to deal with a stop_sequence? I've assumed that if a stop_sequence appears in any of the beams, then we stop the generation process.

Should we instead wait until every beam reaches the stop_sequence (or some other stopping criterion) before ending generation?

KMFODA avatar Aug 08 '22 09:08 KMFODA

Should we instead wait until every beam reaches the stop_sequence (or some other stopping criterion) before ending generation?

@KMFODA I think eos_token_id is already handled for beam search, see my comment on the StoppingCriteria.

I'll let others comment on the best way to do this in .generate, but I don't think we need the criterion; just let the regular eos_token_id logic apply (it's handled separately from StoppingCriteria).

Narsil avatar Aug 08 '22 12:08 Narsil
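
A minimal sketch of the existing behaviour Narsil refers to: passing eos_token_id straight to generate already finishes beams that emit it, with no extra StoppingCriteria (the model and stop token below are illustrative):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Use a newline as the stop token; this only maps cleanly when the
# stop sequence tokenizes to a single id.
stop_token_id = tokenizer.encode("\n")[0]

inputs = tokenizer("Once upon a time", return_tensors="pt")
# Beam search marks a beam as finished once it generates eos_token_id.
outputs = model.generate(
    **inputs,
    num_beams=4,
    eos_token_id=stop_token_id,
    max_new_tokens=40,
)
print(tokenizer.decode(outputs[0]))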

For the tests, removing the breakpoint should help. Then, for code quality:

pip install -e .[quality]
make fixup

Should do the trick.

Narsil avatar Aug 08 '22 12:08 Narsil

@Narsil @KMFODA I'm in favor of moving it to a StoppingCriteria, so that all conditions that can terminate generation fall under the same class. That said, it's not a requirement for completing the issue, i.e. adding a stop sequence to the text generation pipeline :P

It is already implemented in the various generation strategies (e.g. here for greedy search). Also, the existing implementation differs from this PR's -- it only checks whether the eos_token appears among the newly generated tokens. This is because models like GPT-2 often set pad_token_id to eos_token_id, and we don't want pad tokens to trigger the stop condition.

gante avatar Aug 08 '22 14:08 gante
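
A hypothetical helper illustrating the check gante describes (not the library's actual code): only tokens generated after the prompt are inspected, so padding where pad_token_id == eos_token_id cannot trigger the stop condition.

import torch

def eos_in_new_tokens(output_ids: torch.Tensor, prompt_length: int, eos_token_id: int) -> torch.Tensor:
    # Slice off the prompt (and any padding it contains) and only look
    # at freshly generated tokens for the eos/stop token.
    new_tokens = output_ids[:, prompt_length:]
    return (new_tokens == eos_token_id).any(dim=-1)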

Thanks @Narsil @gante. Okay, so for the sake of shipping iteratively, I've removed the eos_token_id from the StoppingCriteria and will add it in a separate PR.

I've added a test for the stop_sequence being fed in at the pipeline level. Once @Narsil's comment about whether the stop sequence should be handled in the pipeline or in the generation_kwargs is addressed, I can adjust this test accordingly.

KMFODA avatar Aug 09 '22 06:08 KMFODA

We should implement stop_sequence only once (probably in generate) but we could have 2 tests if you want to test the full pipeline too. (Probably in tests/pipelines/test_pipelines_text_generation.py for instance.)

If we were to move stop_sequence into generate, wouldn't we have to tokenise it first? In that case, what's the reasoning behind feeding it in as a stop_sequence instead of an eos_token_id?

KMFODA avatar Aug 10 '22 07:08 KMFODA

If we were to move stop_sequence into generate, wouldn't we have to tokenise it first? In that case, what's the reasoning behind feeding it in as a stop_sequence instead of an eos_token_id?

You're entirely right; oversight on my part. eos_token_id already does the job. So we just need to implement stop_sequence in the pipeline: tokenize the stop_sequence to produce an eos_token_id and feed that to generate. No additional code in generate should be needed.

Sorry, failed to see that.

Narsil avatar Aug 12 '22 14:08 Narsil
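
A minimal sketch of what that could look like inside the pipeline (the helper name and warning text are illustrative, not the merged implementation):

import warnings

def apply_stop_sequence(tokenizer, generate_kwargs, stop_sequence=None):
    # Tokenize the stop sequence and hand its id to generate() via
    # eos_token_id; generate itself needs no changes.
    if stop_sequence is not None:
        stop_ids = tokenizer.encode(stop_sequence, add_special_tokens=False)
        if len(stop_ids) > 1:
            warnings.warn(
                "stop_sequence tokenizes to more than one token; only the "
                "first token will be used as the stop token."
            )
        generate_kwargs["eos_token_id"] = stop_ids[0]
    return generate_kwargs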

No problem. I've just moved the stop_sequence logic back to the pipeline function and added the tests you requested in the tests/pipelines/test_pipelines_text_generation.py file. This should make the PR ready for review now.

While playing with stop_sequence, though, I found that sometimes adding a specific stop_sequence changes the output so that it avoids mentioning the word entirely. I don't have live examples right now, but I wanted to check whether this is normal behaviour. If not, I can reproduce it on public models and share it in a separate issue.

KMFODA avatar Aug 15 '22 11:08 KMFODA
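
A sketch of what such a pipeline-level test could look like (the model name, stop sequence, and assertion are illustrative, not the test as merged):

from transformers import pipeline

def test_stop_sequence():
    # A tiny model keeps the test fast; any causal LM works in principle.
    generator = pipeline("text-generation", model="sshleifer/tiny-gpt2")
    prompt = "Hello I believe in"
    full = generator(prompt, do_sample=False, max_new_tokens=20)
    stopped = generator(prompt, do_sample=False, max_new_tokens=20, stop_sequence=" fe")
    # With greedy decoding, the stopped output should be a prefix of the
    # unconstrained one (equal if the stop sequence never appears).
    assert full[0]["generated_text"].startswith(stopped[0]["generated_text"])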

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Sep 18 '22 15:09 github-actions[bot]

@KMFODA I think your PR is almost ready to be merged! Would you like to try to fix the final problems and apply the review suggestions? :-)

patrickvonplaten avatar Sep 27 '22 11:09 patrickvonplaten

Hey @patrickvonplaten. My apologies, I was out sick over the past month. I've worked through the suggestions now. Hopefully this is good to merge, but if not, let me know!

KMFODA avatar Sep 28 '22 14:09 KMFODA

I'm happy with the PR, except for the EndOfStringCriteria class -- it is not being used, and it is not a good practice to add unused classes/functions.

@KMFODA can you remove it for now, and perhaps reintroduce it in a follow-up PR (with use cases)? :)

gante avatar Sep 30 '22 09:09 gante

Hi @gante, yes of course. I had removed it locally, but somehow the change didn't make it into one of the commits. Force-pushed it now. Hopefully that looks good :).

KMFODA avatar Sep 30 '22 12:09 KMFODA