
The attention mechanism is not the original attention mechanism in the paper

Open rk2900 opened this issue 6 years ago • 7 comments

In the tutorial, I find that the "attention" mechanism is not a real attention mechanism, since the calculated attention weights have no relationship with the encoder output vectors. In the original attention paper, "Neural Machine Translation By Jointly Learning To Align and Translate", each attention score e_{ij} is calculated from both the previous decoder hidden state s_{i-1} and the encoder output h_j, and each weight alpha_{ij} is obtained by normalizing these scores over all the encoder outputs h_j. All of the discussion below is based on the equations on page 14 of the paper.
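
For reference, the alignment model in the paper is roughly the following (my paraphrase, with W_a, U_a, v_a as learned parameters):

% score: depends on the decoder state AND the j-th encoder output
e_{ij} = a(s_{i-1}, h_j) = v_a^\top \tanh(W_a s_{i-1} + U_a h_j)
% weights: normalized over all encoder positions
\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^{T_x} \exp(e_{ik})}
% context vector fed to the decoder at step i
c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j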

But in your implementation, the attention weights are calculated using only the decoder-side information, as expressed in the code:

# Note: only the decoder's input embedding and previous hidden state are used;
# encoder_outputs play no part in computing the weights.
attn_weights = F.softmax(
            self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)

I think attn_weights should be calculated as in the paper, so I propose another implementation:

# Concatenate the decoder hidden state with every encoder output: (max_len, 2 * hidden_size)
s = torch.cat( (hidden[0].repeat(self.max_length,1), encoder_outputs), dim=1 )
# Additive score for each encoder position: (max_len, 1)
ns = self.attn2(F.tanh(self.attn1(s)))
# Normalize over the encoder positions
attn_weights = F.softmax(ns, dim=0).view(1, -1) # (1 * max_len)
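
Here self.attn1 and self.attn2 are two linear layers implementing the additive score from the paper. A minimal self-contained sketch of the idea (the class and layer names are only illustrative, not code from the tutorial):

import torch
import torch.nn as nn
import torch.nn.functional as F

class AdditiveAttn(nn.Module):
    # Bahdanau-style additive attention (a sketch, not the tutorial's code).
    def __init__(self, hidden_size, max_length):
        super(AdditiveAttn, self).__init__()
        self.max_length = max_length
        self.attn1 = nn.Linear(hidden_size * 2, hidden_size)  # projects [s_{i-1}; h_j]
        self.attn2 = nn.Linear(hidden_size, 1)                 # maps to a scalar score e_{ij}

    def forward(self, hidden, encoder_outputs):
        # hidden: (1, hidden_size) decoder state; encoder_outputs: (max_length, hidden_size)
        s = torch.cat((hidden.repeat(self.max_length, 1), encoder_outputs), dim=1)
        ns = self.attn2(torch.tanh(self.attn1(s)))   # (max_length, 1) scores
        return F.softmax(ns, dim=0).view(1, -1)      # (1, max_length) attention weights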

The original attention implementation in this tutorial actually performs even worse than a simple encoder-decoder, while my implementation performs much better than both.

Your implementation:

2m 1s (- 28m 25s) (5000 6%) 2.8932
3m 56s (- 25m 40s) (10000 13%) 2.3282
5m 51s (- 23m 25s) (15000 20%) 1.9983
7m 47s (- 21m 26s) (20000 26%) 1.8198
9m 42s (- 19m 24s) (25000 33%) 1.6083
11m 36s (- 17m 24s) (30000 40%) 1.4259
13m 33s (- 15m 30s) (35000 46%) 1.2922
15m 29s (- 13m 33s) (40000 53%) 1.1810
17m 28s (- 11m 39s) (45000 60%) 1.0585
19m 22s (- 9m 41s) (50000 66%) 0.9613
21m 15s (- 7m 43s) (55000 73%) 0.8629
23m 12s (- 5m 48s) (60000 80%) 0.7923
25m 6s (- 3m 51s) (65000 86%) 0.7210
27m 9s (- 1m 56s) (70000 93%) 0.6777
29m 9s (- 0m 0s) (75000 100%) 0.6167

My implementation:

2m 23s (- 33m 24s) (5000 6%) 2.9327
4m 22s (- 28m 23s) (10000 13%) 2.3732
6m 25s (- 25m 43s) (15000 20%) 2.0273
8m 34s (- 23m 35s) (20000 26%) 1.7628
10m 46s (- 21m 32s) (25000 33%) 1.5793
13m 16s (- 19m 54s) (30000 40%) 1.3765
15m 48s (- 18m 4s) (35000 46%) 1.2604
17m 52s (- 15m 38s) (40000 53%) 1.0906
20m 15s (- 13m 30s) (45000 60%) 0.9862
22m 20s (- 11m 10s) (50000 66%) 0.8890
24m 35s (- 8m 56s) (55000 73%) 0.7877
26m 33s (- 6m 38s) (60000 80%) 0.7204
28m 48s (- 4m 25s) (65000 86%) 0.6471
31m 13s (- 2m 13s) (70000 93%) 0.5796
33m 33s (- 0m 0s) (75000 100%) 0.5342

The results above were obtained on a single 1080 GPU.

rk2900 avatar Dec 31 '17 08:12 rk2900

Are you comparing to the latest implementation, at https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation.ipynb ?

spro avatar Jan 01 '18 22:01 spro

The implementation in the above link seems to be correct. It would be nice if this tutorial could be updated as it is still using a non-standard implementation of attention.

This caused me significant confusion when I first started playing with encoder-decoders.

adamklec avatar Apr 09 '18 20:04 adamklec

@adamklec, do you know where that non-standard attention is from? Thanks.

zyxue avatar May 28 '18 17:05 zyxue

Seconded. I spent hours trying to make sense of the [tutorial](https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html).

Moreover, I still need help understanding the code in https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation.ipynb. You use self.attn = GeneralAttn(hidden_size) in the BahdanauAttnDecoderRNN class.

GeneralAttn is not defined anywhere. In [issue 23](https://github.com/spro/practical-pytorch/issues/23), you mention that this is supposed to be the Attn class. However, as far as I understand, none of dot, general, or concat corresponds to the attention described in Bahdanau et al.; instead, they are the score functions from Luong et al. I'm not completely sure about this claim, though, so I would love it if someone could clarify it.
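
For reference, my rough understanding of the Luong et al. score functions, with W_a and v_a as learned parameters (a sketch only, not code from this repo):

import torch

def luong_score(h_t, h_s, method, W_a=None, v_a=None):
    # h_t: decoder hidden state (hidden_size,); h_s: one encoder hidden state (hidden_size,).
    if method == 'dot':
        return torch.dot(h_t, h_s)                                       # h_t^T h_s
    elif method == 'general':
        return torch.dot(h_t, W_a @ h_s)                                 # h_t^T W_a h_s, W_a: (hidden, hidden)
    elif method == 'concat':
        return torch.dot(v_a, torch.tanh(W_a @ torch.cat((h_t, h_s))))   # v_a^T tanh(W_a [h_t; h_s]), W_a: (hidden, 2*hidden)

The concat form looks close to Bahdanau's additive attention, but Bahdanau et al. score against the previous decoder state s_{i-1} with separate projection matrices, which is what the BahdanauAttnDecoderRNN name suggests.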

geraltofrivia avatar Jul 01 '18 17:07 geraltofrivia

Hi @geraltofrivia, I found the same issues as you. I don't understand why @spro does not fix this problem, as it is so obviously confusing.

NiceMartin avatar Oct 11 '18 03:10 NiceMartin

This has also been mentioned in another issue. @spro, please fix it soon; this mistake in the tutorial has existed for 2 years.

soloice avatar May 08 '19 06:05 soloice

@soloice it has been four years now... I just spent a whole afternoon learning the wrong version of the code...

Adam-fei avatar Aug 05 '21 11:08 Adam-fei