Relative position attention
I implemented 2D relative attention in T5 as a bias term that is added before the softmax. I would like to do the same for ReformerLM, and later also for an encoder-decoder Reformer with LSH attention. In the case of T5 it helped a lot that the creator pointed me to the right place in the code. Could someone recommend the best place to implement this? My first guess would be here: https://github.com/google/trax/blob/9483a5a1221ea61d36ce2c013cb08eb029d8f843/trax/layers/research/efficient_attention.py#L251-L253 Is there any other place I should also add this, or anything I should be aware of?
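For reference, here is a minimal NumPy sketch of the "bias added before softmax" idea; all names and shapes are illustrative, not trax's API:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_with_bias(q, k, v, bias):
    """Single-head attention with an additive position bias.

    q, k, v: [seq_len, d_head] arrays; bias: [seq_len, seq_len] of
    per-(query, key) scalars looked up from the relative positions.
    """
    dots = q @ k.T / np.sqrt(q.shape[-1])  # scaled query-key dot products
    dots = dots + bias                     # bias added BEFORE the softmax
    return softmax(dots) @ v

# Tiny usage example with random inputs.
rng = np.random.default_rng(0)
L, d = 4, 8
q, k, v = (rng.normal(size=(L, d)) for _ in range(3))
bias = rng.normal(size=(L, L))
out = attention_with_bias(q, k, v, bias)  # shape [4, 8]
```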
That sounds very good! Another possible place would be here, in the main attention file: https://github.com/google/trax/blob/master/trax/layers/attention.py#L205
Could you first quickly explain your method? I'm not 100% sure I understand it, or whether all the arguments you need are available there.
For the basic relative attention scenario, I add one of N (for T5, N=32) different learned scalars to each query-key dot product, based on the relative distance between the corresponding tokens in the sequence. When I also take the tokens' positions on the page into account (I am working with PDF files), I add two additional scalars for the relative X and Y positions. So my input will also contain two additional tensors, with the same shape as the tokenized text input, containing floats that represent the coordinates.
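Roughly, a bias matrix like the one in the sketch below, which then gets added to the dot products as above. The bucketing function and the quantization of the continuous coordinates are placeholders I made up for illustration (not the exact T5 formula), and real code would keep one table of scalars per head:

```python
import numpy as np

def bucket_relative_distance(rel, num_buckets=32, max_distance=128):
    # Placeholder bucketing: exact buckets for small distances, log-spaced
    # buckets beyond, with half of the buckets reserved for negative offsets.
    # (T5 uses a similar idea; the exact formula here is only illustrative.)
    half = num_buckets // 2
    sign = (rel < 0).astype(np.int32) * half
    rel = np.abs(rel)
    exact = half // 2
    is_small = rel < exact
    log_bucket = exact + (
        np.log(np.maximum(rel, 1) / exact)
        / np.log(max_distance / exact)
        * (half - exact)
    ).astype(np.int32)
    log_bucket = np.minimum(log_bucket, half - 1)
    return sign + np.where(is_small, rel, log_bucket)

def relative_bias(positions, xs, ys, seq_table, x_table, y_table):
    """Builds the [L, L] additive bias from 1D positions and page coordinates.

    positions: [L] integer token indices; xs, ys: [L] float page coordinates.
    seq_table, x_table, y_table: [num_buckets] learned scalars (per head in a
    real model); here they are just plain arrays.
    """
    rel_pos = positions[None, :] - positions[:, None]  # [L, L] relative distance
    rel_x = xs[None, :] - xs[:, None]                  # [L, L] relative X
    rel_y = ys[None, :] - ys[:, None]                  # [L, L] relative Y
    bias = seq_table[bucket_relative_distance(rel_pos)]
    # Coordinates are continuous, so they are quantized before bucketing --
    # an assumption; the right quantization depends on the coordinate scale.
    bias = bias + x_table[bucket_relative_distance(np.round(rel_x).astype(np.int32))]
    bias = bias + y_table[bucket_relative_distance(np.round(rel_y).astype(np.int32))]
    return bias  # added to the query-key dots before the softmax

# Tiny usage example.
rng = np.random.default_rng(0)
L, N = 6, 32
positions = np.arange(L)
xs, ys = rng.uniform(0, 600, size=L), rng.uniform(0, 800, size=L)
tables = [rng.normal(size=N) * 0.1 for _ in range(3)]
bias = relative_bias(positions, xs, ys, *tables)  # shape [6, 6]
```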