Support beam search

Open borisdayma opened this issue 10 months ago • 0 comments

Hi,

It would be nice to support beam search.

There is a reference Flax implementation in the WMT example, and an equivalent one in transformers.

I am guessing that we could:

  • initially duplicate the inputs num_beams times
  • at each step:
    • run decode_step
    • select the top beams
    • overwrite the entire past cache according to the selected beams
    • update the cache with the newly selected tokens
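For illustration, here is a minimal JAX sketch of the "duplicate inputs" and "select top beams" parts of this loop. It is not MaxText code. The helper names (`expand_to_beams`, `init_beam_scores`, `select_top_beams`), the `[batch, num_beams, vocab]` log-prob layout and the `NEG_INF` constant are hypothetical. Starting with only beam 0 live is a common trick to stop the duplicated beams from all choosing the same first token.

```python
import jax
import jax.numpy as jnp

# Hypothetical large negative score used to "switch off" beams (assumption).
NEG_INF = -1.0e7


def expand_to_beams(x, num_beams):
    """Repeat each batch row num_beams times: [batch, ...] -> [batch * num_beams, ...]."""
    return jnp.repeat(x, num_beams, axis=0)


def init_beam_scores(batch, num_beams):
    """Only beam 0 starts live, so duplicated beams don't all pick the same first token."""
    first = jnp.array([0.0] + [NEG_INF] * (num_beams - 1))
    return jnp.tile(first, (batch, 1))


def select_top_beams(log_probs, beam_scores):
    """Pick the num_beams best (beam, token) continuations per example.

    log_probs:   [batch, num_beams, vocab] log-probabilities from decode_step.
    beam_scores: [batch, num_beams] accumulated log-probabilities so far.
    Returns (new_scores, parent_beams, tokens), each [batch, num_beams].
    """
    batch, num_beams, vocab = log_probs.shape
    # Score of extending each beam with each token.
    candidates = beam_scores[:, :, None] + log_probs
    flat = candidates.reshape(batch, num_beams * vocab)
    # top_k works on the last axis and returns (values, indices), best first.
    new_scores, flat_idx = jax.lax.top_k(flat, num_beams)
    parent_beams = flat_idx // vocab  # which beam each winner extends
    tokens = flat_idx % vocab  # which token it appends
    return new_scores, parent_beams, tokens
```

The `parent_beams` output is what the cache-overwrite step needs. End-of-sequence handling and length penalties are deliberately left out of this sketch.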

So maybe the only extra step needed here is "overwrite the entire past cache according to the selected beams"? I'm curious whether you have suggestions for the implementation.
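One way the cache overwrite could look is a single gather over the whole cache pytree, sketched below in JAX. This is an assumption-laden sketch, not MaxText's API: `reorder_cache` is a hypothetical helper, and it assumes every non-scalar cache leaf has a leading `batch * num_beams` axis laid out batch-major (as `jnp.repeat` produces). If MaxText's KV-cache layout puts the batch on a different axis, the `axis` argument would need to change.

```python
import jax
import jax.numpy as jnp


def reorder_cache(cache, parent_beams):
    """Overwrite each beam's past cache with the cache of the beam it extends.

    Assumes every non-scalar cache leaf has a leading axis of size
    batch * num_beams laid out batch-major (as produced by jnp.repeat).
    parent_beams: [batch, num_beams] parent indices from the top-k step.
    """
    batch, num_beams = parent_beams.shape
    # Map per-example beam indices to rows of the flattened batch*beam axis.
    offsets = jnp.arange(batch)[:, None] * num_beams
    rows = (offsets + parent_beams).reshape(-1)

    def gather(x):
        # Scalar leaves (e.g. a shared decode position) are not per-beam; keep them.
        return x if x.ndim == 0 else jnp.take(x, rows, axis=0)

    return jax.tree_util.tree_map(gather, cache)
```

After the gather, feeding the selected `tokens` into the next `decode_step` would, in a typical KV-cache decoder, write them into the reordered cache, which covers the "update the cache with the newly selected tokens" step.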

borisdayma · Apr 15 '24 18:04