practical-pytorch
practical-pytorch copied to clipboard
Element-wise assignment when computing attention weights may be slow
for b in range(this_batch_size): # Calculate energy for each encoder output for i in range(max_len): attn_energies[b, i] = self.score(hidden[:, b], encoder_outputs[i, b].unsqueeze(0))
A better way to handle it is just to use tensor manipulation.
H = hidden.repeat(max_len,1,1).transpose(0,1) encoder_outputs = encoder_outputs.transpose(0,1) # [B*T*H] attn_energies = self.score(H,encoder_outputs) # compute attention score
along with some modifications to the self.score implementation.
The whole implementation has been posted in my repo.
Thanks!
I encountered the same issue. I see your repo only has an implementation for concat. My code is as follows; it supports all three methods, though not elegantly.
class Attn(nn.Module):
def __init__(self, method, hidden_size):
super(Attn, self).__init__()
self.method = method
self.hidden_size = hidden_size
if self.method == 'general':
self.attn = nn.Linear(self.hidden_size, self.hidden_size)
elif self.method == 'concat':
self.attn = nn.Linear(self.hidden_size * 2, self.hidden_size)
self.v = nn.Parameter(weight_init.xavier_uniform(torch.FloatTensor(1, self.hidden_size)))
def forward(self, hidden, encoder_outputs):
attn_energies = self.batch_score(hidden, encoder_outputs)
return F.softmax(attn_energies).unsqueeze(1)
def batch_score(self, hidden, encoder_outputs):
if self.method == 'dot':
encoder_outputs = encoder_outputs.permute(1, 2, 0)
energy = torch.bmm(hidden.transpose(0, 1), encoder_outputs).squeeze(1)
elif self.method == 'general':
length = encoder_outputs.size(0)
batch_size = encoder_outputs.size(1)
energy = self.attn(encoder_outputs.view(-1, self.hidden_size)).view(length, batch_size, self.hidden_size)
energy = torch.bmm(hidden.transpose(0, 1), energy.permute(1, 2, 0)).squeeze(1)
elif self.method == 'concat':
length = encoder_outputs.size(0)
batch_size = encoder_outputs.size(1)
attn_input = torch.cat((hidden.repeat(length, 1, 1), encoder_outputs), dim=2)
energy = self.attn(attn_input.view(-1, 2 * self.hidden_size)).view(length, batch_size, self.hidden_size)
energy = torch.bmm(self.v.repeat(batch_size, 1, 1), energy.permute(1, 2, 0)).squeeze(1)
return energy
@czs0x55aa Thanks a lot. I am training a text-summarization model with input of size 32x120, and your code helped me get past this problem.
Hi @czs0x55aa ,
Thanks for the faster version, but I get the following error:
Traceback (most recent call last):
File "main.py", line 244, in <module>
trainIters(encoder1, decoder1, training_pairs)
File "main.py", line 142, in trainIters
encoder_optimizer, decoder_optimizer, criterion)
File "main.py", line 69, in train
decoder_output, decoder_context, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_context, decoder_hidden, encoder_outputs)
File "/home/magnet/bsrivast/asr/venv/lib/python3.5/site-packages/torch/nn/modules/module.py", line 477, in __call__
result = self.forward(*input, **kwargs)
File "/home/magnet/bsrivast/asr/src/encoder_decoder.py", line 148, in forward
attn_weights = self.attn(rnn_output.squeeze(0), encoder_outputs)
File "/home/magnet/bsrivast/asr/venv/lib/python3.5/site-packages/torch/nn/modules/module.py", line 477, in __call__
result = self.forward(*input, **kwargs)
File "/home/magnet/bsrivast/asr/src/encoder_decoder.py", line 77, in forward
attn_energies = self.batch_score(hidden, encoder_outputs)
File "/home/magnet/bsrivast/asr/src/encoder_decoder.py", line 107, in batch_score
energy = torch.bmm(hidden.transpose(0, 1), energy.permute(1, 2, 0)).squeeze(1)
RuntimeError: invalid argument 6: expected 3D tensor at /pytorch/aten/src/THC/generic/THCTensorMathBlas.cu:466
I tried to unsqueeze(0) the hidden variable, but that throws a wrong-dimension error. Could you please help?
Thanks -Brij