
Generation: stop at `eos` for assisted decoding

zucchini-nlp opened this pull request 1 year ago • 4 comments

What does this PR do?

Addresses the issue in this comment.

The problem was that assisted generation is flaky when it comes to EOS tokens. If the assistant generates the target model's EOS token in the middle of the candidate sequence and it gets accepted by the target, we don't stop generating, because the stopping criteria check only the last token. This happens mostly with Prompt Lookup decoding, because it simply copy-pastes candidates from the prompt history.

One possible solution is to check for EOS in all of the newly generated input ids; another is to crop the candidate sequence at the EOS token. I went with the second approach.
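
For illustration, a minimal sketch of the cropping idea (the helper name and tensor layout are assumptions, not the exact PR code):

```python
import torch

def crop_candidates_at_eos(candidate_ids: torch.Tensor, eos_token_id: int) -> torch.Tensor:
    # candidate_ids: 1D tensor of newly proposed candidate token ids
    eos_positions = (candidate_ids == eos_token_id).nonzero(as_tuple=True)[0]
    if eos_positions.numel() == 0:
        return candidate_ids
    # keep the EOS itself so the stopping criteria can fire on the last token
    return candidate_ids[: int(eos_positions[0]) + 1]
```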

zucchini-nlp avatar Jun 07 '24 05:06 zucchini-nlp

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Jul 07 '24 08:07 github-actions[bot]

The latest release doesn't have a fix for this; it's still an issue.

Ednaordinary avatar Jul 07 '24 08:07 Ednaordinary

@gante okay, moved the logic into PLD. The main concern was that the assistant model's eos token can differ from the target model's, so I added a line in AssistedGenerator to set the eos token.

Added a test for PLD, but not for the assistant model, because we can't control what the assistant model generates. Also tested with the case from the linked issue; it works on my end.
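
For reference, one way to exercise the PLD path from user code (the checkpoint name is just an example; `prompt_lookup_num_tokens` enables prompt lookup candidates in `generate()`):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"  # example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("The quick brown fox jumps over the lazy dog.", return_tensors="pt")
out = model.generate(
    **inputs,
    max_new_tokens=64,
    prompt_lookup_num_tokens=10,          # enables prompt lookup decoding
    eos_token_id=tokenizer.eos_token_id,  # generation should stop here even if
                                          # EOS shows up mid-candidate
)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```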

zucchini-nlp avatar Jul 10 '24 06:07 zucchini-nlp

Please note this will error out if the eos_token_id is a list (like in some llama3 configurations). [screenshot of the error]

Ednaordinary avatar Jul 23 '24 23:07 Ednaordinary

@Ednaordinary are you using the generate() method? I don't see an error on the current branch, because internally we convert all special tokens to tensors, so self.eos_id cannot be a list.

I guess you're on main and applied the current changes on top; in that case it will fail because I haven't rebased yet. There were some changes to special-token handling recently, afaik. I will make sure it works before merging.
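
To illustrate why a tensor of EOS ids isn't a problem (an illustrative sketch only, not the library's internal code), membership can be checked against however many EOS ids there are:

```python
import torch

eos_token_id = [128001, 128008, 128009]         # example list of several EOS ids
eos_tensor = torch.tensor(eos_token_id)

candidate_ids = torch.tensor([15, 42, 128009, 7])
is_eos = torch.isin(candidate_ids, eos_tensor)  # tensor([False, False, True, False])
first_eos = is_eos.nonzero(as_tuple=True)[0]
print(first_eos)                                # tensor([2])
```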

zucchini-nlp avatar Jul 24 '24 05:07 zucchini-nlp

In some configurations running llama, the eos_token_id is set to a list. [screenshot] This works fine for sampling, but breaks once prompt lookup decoding is added. It's a simple fix, though: add [0] to the end of tokenizer.encode("<|eot_id|>"). It works fine after applying this PR on top of current main. I mainly raise this concern because llama 3.1 has three stop tokens: [screenshot]
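
Spelling out that workaround (checkpoint name is just an example; passing `add_special_tokens=False` avoids picking up a prepended BOS):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")  # example
eot_id = tokenizer.encode("<|eot_id|>", add_special_tokens=False)[0]  # single int instead of a list

# out = model.generate(**inputs, eos_token_id=eot_id, prompt_lookup_num_tokens=10)
```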

Ednaordinary avatar Jul 24 '24 06:07 Ednaordinary

@Ednaordinary I see, you're using llama-3.1, which forces you to use the latest version. As I said, the problem was that this branch was months behind main. I've now rebased, and everything should be fine.

zucchini-nlp avatar Jul 24 '24 07:07 zucchini-nlp