Using tfa.seq2seq.BeamSearchDecoder with a Transformer architecture
Describe the feature and the current behavior/state. I am new to deep learning and TensorFlow, and I recently created a model similar to the one in the Transformer tutorial. The tutorial uses a greedy decoder, but I would like to use beam search with the model. I came across this tool, but the problem is that this decoder only works with RNN architectures. Can I use this method with the Transformer model too? If so, a tutorial on how to use this tool would be really helpful. If not, I would like to request a similar tool for Transformer models as well.
Relevant information
- Are you willing to contribute it (yes/no): no
- Are you willing to maintain it going forward? (yes/no): no
- Is there a relevant academic paper? (if so, where): Attention Is All You Need (Vaswani et al., 2017)
- Is there already an implementation in another framework? (if so, where): I guess tensor2tensor had this implemented, but I'm not experienced enough to understand the code.
Who will benefit from this feature? Beam search decoding is a hard thing to get right. All researchers who do not want to spend time on the decoding part will surely benefit from this.
The cell argument of the tfa.seq2seq.BeamSearchDecoder constructor should be a layer with the same interface as a tf.keras.layers.AbstractRNNCell. It should be possible to adapt or wrap any decoder (including the Transformer decoder) to implement this interface. Can you look into that?
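To make this concrete, here is a minimal sketch of what such a wrapper could look like. It rests on loud assumptions: TransformerDecoderCell is a hypothetical class, the wrapped decoder stack is assumed to take target embeddings plus the encoder output and return hidden states (the tutorial's Decoder takes token ids and masks instead, so the call would need adapting), and masking, attention caching, and the training flag are left out.

```python
import tensorflow as tf


class TransformerDecoderCell(tf.keras.layers.AbstractRNNCell):
    """Hypothetical wrapper exposing a Transformer decoder through the
    RNN-cell interface that tfa.seq2seq.BeamSearchDecoder expects.

    The state is (history, step): `history` stores the target embeddings
    seen so far (padded to `max_length`) and `step` is the current position,
    so the non-recurrent Transformer can re-attend to the whole prefix at
    every decoding step.
    """

    def __init__(self, decoder_stack, encoder_output, max_length, d_model, **kwargs):
        super().__init__(**kwargs)
        # Assumed signature: decoder_stack(tgt_embeddings, enc_output) -> hidden states.
        self.decoder_stack = decoder_stack
        # Assumed to be already tiled to [batch * beam_width, src_len, d_model].
        self.encoder_output = encoder_output
        self.max_length = max_length
        self.d_model = d_model

    @property
    def state_size(self):
        return (tf.TensorShape([self.max_length, self.d_model]), tf.TensorShape([]))

    @property
    def output_size(self):
        return self.d_model

    def get_initial_state(self, inputs=None, batch_size=None, dtype=tf.float32):
        history = tf.zeros([batch_size, self.max_length, self.d_model], dtype=dtype)
        step = tf.zeros([batch_size], dtype=tf.int32)
        return (history, step)

    def call(self, inputs, states):
        # inputs: embedding of the previously emitted token, [batch * beam, d_model].
        history, step = states
        t = step[0]  # all beams advance in lockstep
        batch = tf.shape(inputs)[0]
        # Write the current embedding into the history at position t.
        indices = tf.stack([tf.range(batch), tf.fill([batch], t)], axis=1)
        history = tf.tensor_scatter_nd_update(history, indices, inputs)
        # Re-run the decoder over the decoded prefix and keep only the newest position.
        hidden = self.decoder_stack(history[:, : t + 1], self.encoder_output)
        return hidden[:, -1], (history, step + 1)
```

The wrapped cell could then be passed as the cell argument of tfa.seq2seq.BeamSearchDecoder together with an output_layer such as tf.keras.layers.Dense(vocab_size), after tiling the encoder output with tfa.seq2seq.tile_batch to match the beam width.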
You could also look at the dynamic_decode function from OpenNMT-tf, which can run beam search and just requires a callable function. It may be easier than implementing a class interface. (Full disclosure: I'm the maintainer of OpenNMT-tf.)
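For reference, a rough sketch of that approach is below. Everything OpenNMT-tf-specific here is an assumption to verify against its documentation (the module path, argument names, and the structure returned by dynamic_decode), and transformer_decoder_logits is a hypothetical helper standing in for your decoder plus final projection.

```python
from opennmt.utils import decoding  # assumed module path in OpenNMT-tf 2.x


def beam_search_decode(transformer_decoder_logits, encoder_output,
                       start_ids, end_token_id, max_length, beam_size=4):
    """Runs beam search over a Transformer via OpenNMT-tf's dynamic_decode.

    `transformer_decoder_logits` is a hypothetical helper assumed to map the
    target ids decoded so far plus the encoder output to next-token logits of
    shape [batch * beam_width, vocab_size].
    """

    def symbols_to_logits_fn(ids, step, state):
        # `ids` holds the tokens decoded so far: [batch * beam_width, step + 1].
        logits = transformer_decoder_logits(ids, state["encoder_output"])
        return logits, state

    # Assumed call: check the exact signature, the BeamSearch constructor, and
    # the returned structure (decoded ids, lengths, log probabilities, ...)
    # against the OpenNMT-tf documentation.
    return decoding.dynamic_decode(
        symbols_to_logits_fn,
        start_ids,                  # [batch] start-of-sentence token ids
        end_id=end_token_id,
        initial_state={"encoder_output": encoder_output},
        decoding_strategy=decoding.BeamSearch(beam_size),
        maximum_iterations=max_length,
    )
```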
Hi Guillaume, what should I use with TensorFlow 1.x? Thanks.
TensorFlow 1.x is end of life.