transformers
Add stop sequence to text generation pipeline
What does this PR do?
As per the conversation in https://github.com/huggingface/transformers/issues/17562, creating this draft PR to add a `stop_sequence` option to text generation pipelines.
Fixes # (issue)
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the contributor guideline, Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [ ] Did you write any new necessary tests?
Who can review?
@Narsil
Models:
All
Library:
- text generation: @patrickvonplaten
- pipelines: @LysandreJik
The documentation is not available anymore as the PR was closed or merged.
Hey @Narsil. I've managed to get this working for greedy decoding and multinomial sampling. For beam search, what would be the best approach to deal with a `stop_sequence`? I've assumed that if a `stop_sequence` appears in any of the beams, then we stop the generation process.
Should it instead be that we wait until each beam reaches the `stop_sequence`, or any other stopping criterion, before stopping the generation process?
> Should it instead be that we wait until each beam reaches the `stop_sequence`, or any other stopping criterion, before stopping the generation process?
@KMFODA I think `eos_token_id` is already handled for beam search, see my comment on the `StoppingCriteria`.
I will let others comment on the best way to do this in `.generate`, but I think we don't need the criteria, just let the regular `eos_token_id` logic apply (it's handled separately from `StoppingCriteria`).
For the tests, removing the breakpoint should help. Then, for code quality:

```bash
pip install -e .[quality]
make fixup
```

should do the trick.
@Narsil @KMFODA I'm in favor of moving it to a `StoppingCriteria`, so that all conditions that can terminate generation fall under the same class. However, it should be noted that it is not a requirement to complete the issue, i.e. to add a stop sequence to the text generation pipeline :P

It is already implemented in the multiple generation strategies (e.g. here for greedy search). Also, the existing implementation is different from the current PR's -- the existing implementation only checks whether the `eos_token` is present in newly generated tokens. This is because models like GPT-2 often set `pad_token_id` to `eos_token_id`, and we don't want the pad tokens to trigger this condition.
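For illustration, a minimal sketch of what such a `StoppingCriteria` could look like, applying the point above about only checking newly generated tokens (the class name, `stop_ids`, and `prompt_length` are illustrative assumptions, not code from this PR):

```python
import torch
from transformers import StoppingCriteria

class StopSequenceCriteria(StoppingCriteria):
    """Stop generation once the newly generated tokens end with `stop_ids`."""

    def __init__(self, stop_ids, prompt_length):
        self.stop_ids = torch.tensor(stop_ids)
        # Only look at tokens generated after the prompt, so that pad/eos
        # tokens already present in the prompt cannot trigger the stop.
        self.prompt_length = prompt_length

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # Checks the first sequence in the batch, for simplicity.
        generated = input_ids[0, self.prompt_length :]
        if generated.shape[0] < self.stop_ids.shape[0]:
            return False
        tail = generated[-self.stop_ids.shape[0] :]
        return torch.equal(tail, self.stop_ids.to(generated.device))
```

It could then be passed to `generate` via `stopping_criteria=StoppingCriteriaList([...])`.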
Thanks @Narsil @gante. Okay, so for the sake of deploying iteratively, I've removed the `eos_token_id` from the `StoppingCriteria` and will add it in a separate PR.

I've added a test for the `stop_sequence` being fed in at the pipeline level. Once @Narsil's comment about whether the stop sequence should be handled in the `pipeline` or in the `generation_kwargs` is addressed, I can alter this test accordingly.
We should implement `stop_sequence` only once (probably in `generate`), but we could have 2 tests if you want to test the full pipeline too (probably in `tests/pipelines/test_pipelines_text_generation.py`, for instance).
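Something along these lines could serve as the pipeline-level test (the tiny checkpoint, prompt, and stop string below are placeholders, not the PR's actual test values):

```python
from transformers import pipeline

def test_stop_sequence():
    # Tiny placeholder checkpoint to keep the test fast.
    text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2")
    prompt = "Hello I believe in"
    full = text_generator(prompt, max_new_tokens=20, do_sample=False)
    stopped = text_generator(prompt, max_new_tokens=20, do_sample=False, stop_sequence=" fe")
    # With greedy decoding, stopping early can only shorten the output.
    assert len(stopped[0]["generated_text"]) <= len(full[0]["generated_text"])
```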
If we were to move `stop_sequence` to be in `generate`, wouldn't we have to tokenise it first? In that case, what's the reasoning behind feeding it as a `stop_sequence` instead of an `eos_token_id`?
> If we were to move `stop_sequence` to be in `generate`, wouldn't we have to tokenise it first? In that case, what's the reasoning behind feeding it as a `stop_sequence` instead of an `eos_token_id`?
You're entirely right, oversight on my part. `eos_token_id` already does the job. So we just need to implement `stop_sequence` in the pipeline to tokenize the `stop_sequence`, produce the `eos_token_id`, and feed it to `generate`.
So no additional code in `generate` should be needed, actually.
Sorry, failed to see that.
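A minimal sketch of that approach, assuming a loaded tokenizer (the helper name and warning wording are illustrative, not the PR's exact code):

```python
import warnings

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

def apply_stop_sequence(stop_sequence, generate_kwargs):
    # Tokenize the stop sequence inside the pipeline.
    stop_ids = tokenizer.encode(stop_sequence, add_special_tokens=False)
    if len(stop_ids) > 1:
        # Only a single token can be expressed as an eos_token_id, so
        # multi-token stop sequences can't be fully supported this way.
        warnings.warn(
            "The stop_sequence tokenizes to more than one token; "
            "only its first token will stop generation."
        )
    generate_kwargs["eos_token_id"] = stop_ids[0]
    return generate_kwargs
```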
No problem, I've just moved the `stop_sequence` back to the pipeline function and added the tests you requested in the `tests/pipelines/test_pipelines_text_generation.py` file. This should make this PR ready for review now.
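With that in place, usage would look roughly like this (the model and stop string are arbitrary examples):

```python
from transformers import pipeline

generator = pipeline("text-generation", model="gpt2")
# Generation halts once the tokenized stop sequence is produced.
output = generator("Once upon a time,", stop_sequence=".", max_new_tokens=50)
print(output[0]["generated_text"])
```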
When I was playing with the `stop_sequence`, though, I found that sometimes when I add a specific `stop_sequence`, the output changes and avoids mentioning the word entirely. I don't have live examples now, but I just wanted to check whether this is normal behaviour. If not, I can find examples on public models and share them in a different issue.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
@KMFODA I think your PR is almost ready to be merged! Would you like to try to fix the final problems and apply the review suggestions? :-)
Hey @patrickvonplaten. My apologies, I was out sick over the past month. I've worked on the suggestions now. Hopefully this should be good to merge, but if not, let me know!
I'm happy with the PR, except for the `EndOfStringCriteria` class -- it is not being used, and it is not good practice to add unused classes/functions.
@KMFODA can you remove it for now, and perhaps reintroduce it in a follow-up PR (with use cases)? :)
Hi @gante, yes of course. I had removed it locally, but somehow the changes didn't push through with one of the commits. Force-pushed the change now. Hopefully that looks good now :).