maxtext
Support beam search
Hi,
It would be nice to support beam search.
There is a reference Flax implementation in the WMT example, and an equivalent one in transformers.
I am guessing that we could:
- duplicate inputs per num_beams initially
- at each step we do:
  - decode_step
  - select top beams
  - overwrite entire past cache per selected beams
  - update cache with new selected tokens
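To make the "select top beams" step concrete, here is a rough JAX sketch. It is not MaxText code: `select_top_beams`, its argument names, and the `[batch, num_beams, vocab]` layout are all assumptions made for illustration. It adds each beam's running score to the log-probabilities of the current step and keeps the `num_beams` best (beam, token) pairs.

```python
import jax
import jax.numpy as jnp


def select_top_beams(beam_scores, step_log_probs, num_beams):
    """Pick the best continuations for one decoding step (hypothetical helper).

    beam_scores:    [batch, num_beams] running log-prob of each beam.
    step_log_probs: [batch, num_beams, vocab] log-probs from the decode step.
    """
    batch, _, vocab = step_log_probs.shape
    # Score of every possible (beam, token) continuation.
    candidates = beam_scores[:, :, None] + step_log_probs
    # Flatten beams and vocab so top_k ranks all continuations together.
    flat = candidates.reshape(batch, num_beams * vocab)
    new_scores, flat_idx = jax.lax.top_k(flat, num_beams)
    # Recover which beam each winner came from and which token it appends.
    parent_beams = flat_idx // vocab
    new_tokens = flat_idx % vocab
    return new_scores, parent_beams, new_tokens
```

Since the inputs are duplicated per beam, all beams are identical at the first step. A common way to avoid picking the same token `num_beams` times is to start with a score of 0 for the first beam and a very large negative score for the others.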
So maybe the only extra step needed here is "overwrite entire past cache per selected beams"? I'm curious whether you have suggestions for the implementation.
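For the cache overwrite, one possible approach is a gather along the beam axis. The sketch below assumes the cache is a pytree whose leaves are shaped `[batch, num_beams, ...]` and that `parent_beams` is a `[batch, num_beams]` integer array saying which old beam each new beam continues. Both the name `gather_beams` and that layout are hypothetical; the real MaxText cache may be laid out differently.

```python
import jax
import jax.numpy as jnp


def gather_beams(cache, parent_beams):
    """Reorder every cache leaf so slot i holds the state of its parent beam.

    cache:        pytree of arrays shaped [batch, num_beams, ...] (assumed layout).
    parent_beams: [batch, num_beams] integer indices into the beam axis.
    """
    def gather_leaf(x):
        # For each batch element, pick rows of the beam axis by parent index.
        return jax.vmap(lambda x_b, idx_b: x_b[idx_b])(x, parent_beams)

    return jax.tree_util.tree_map(gather_leaf, cache)
```

After this gather, the cache for each surviving beam matches the history it was extended from, so the usual cache update with the newly selected tokens can run unchanged.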