
element-wise assignment in attention weight computing might be slow


    for b in range(this_batch_size):
        # Calculate energy for each encoder output
        for i in range(max_len):
            attn_energies[b, i] = self.score(hidden[:, b], encoder_outputs[i, b].unsqueeze(0))

A better way to handle it is to use tensor manipulation:

    H = hidden.repeat(max_len, 1, 1).transpose(0, 1)
    encoder_outputs = encoder_outputs.transpose(0, 1)  # [B, T, H]
    attn_energies = self.score(H, encoder_outputs)  # compute attention score

along with some modifications in the self.score implementation. The whole implementation has been posted in my repo. Thanks!
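To illustrate why the batched version is equivalent, here is a minimal sketch (the sizes and the dot-product score are assumptions for the example, not the repo's exact code) comparing the per-element loop against a single `torch.bmm` over the whole batch:

```python
import torch

# Hypothetical shapes: hidden [1, B, H], encoder_outputs [T, B, H]
B, T, H = 4, 7, 16
hidden = torch.randn(1, B, H)
encoder_outputs = torch.randn(T, B, H)

# Loop version (slow): score each (batch, time) pair individually.
loop_energies = torch.zeros(B, T)
for b in range(B):
    for i in range(T):
        loop_energies[b, i] = (hidden[0, b] * encoder_outputs[i, b]).sum()

# Batched version: one bmm for everything.
# [B, 1, H] @ [B, H, T] -> [B, 1, T] -> [B, T]
batched = torch.bmm(hidden.transpose(0, 1),
                    encoder_outputs.permute(1, 2, 0)).squeeze(1)

assert torch.allclose(loop_energies, batched, atol=1e-5)
```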

AuCson avatar Sep 04 '17 15:09 AuCson

I encountered the same issue. I see your repo only has the implementation for concat. My code is as follows; it supports all three methods, though not elegantly.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as weight_init

class Attn(nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()

        self.method = method  # 'dot', 'general', or 'concat'
        self.hidden_size = hidden_size

        if self.method == 'general':
            self.attn = nn.Linear(self.hidden_size, self.hidden_size)
        elif self.method == 'concat':
            self.attn = nn.Linear(self.hidden_size * 2, self.hidden_size)
            self.v = nn.Parameter(weight_init.xavier_uniform_(torch.FloatTensor(1, self.hidden_size)))

    def forward(self, hidden, encoder_outputs):
        # hidden: [1, B, H]; encoder_outputs: [T, B, H]
        attn_energies = self.batch_score(hidden, encoder_outputs)  # [B, T]
        return F.softmax(attn_energies, dim=1).unsqueeze(1)  # [B, 1, T]

    def batch_score(self, hidden, encoder_outputs):
        if self.method == 'dot':
            encoder_outputs = encoder_outputs.permute(1, 2, 0)  # [B, H, T]
            energy = torch.bmm(hidden.transpose(0, 1), encoder_outputs).squeeze(1)  # [B, T]
        elif self.method == 'general':
            length = encoder_outputs.size(0)
            batch_size = encoder_outputs.size(1)
            energy = self.attn(encoder_outputs.view(-1, self.hidden_size)).view(length, batch_size, self.hidden_size)
            energy = torch.bmm(hidden.transpose(0, 1), energy.permute(1, 2, 0)).squeeze(1)  # [B, T]
        elif self.method == 'concat':
            length = encoder_outputs.size(0)
            batch_size = encoder_outputs.size(1)
            attn_input = torch.cat((hidden.repeat(length, 1, 1), encoder_outputs), dim=2)  # [T, B, 2H]
            energy = self.attn(attn_input.view(-1, 2 * self.hidden_size)).view(length, batch_size, self.hidden_size)
            energy = torch.bmm(self.v.repeat(batch_size, 1, 1), energy.permute(1, 2, 0)).squeeze(1)  # [B, T]
        return energy
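To make the shapes the class expects concrete, here is a sketch that walks through each scoring method with plain tensor ops (the sizes are made up for illustration, and the linear layers here are stand-ins for the module's own):

```python
import torch
import torch.nn as nn

# Shape walk-through for batch_score; B, T, H are made-up sizes.
B, T, H = 4, 7, 16
hidden = torch.randn(1, B, H)            # decoder hidden state, [1, B, H]
encoder_outputs = torch.randn(T, B, H)   # encoder states, [T, B, H]

# 'dot': [B, 1, H] @ [B, H, T] -> [B, 1, T] -> [B, T]
dot = torch.bmm(hidden.transpose(0, 1), encoder_outputs.permute(1, 2, 0)).squeeze(1)

# 'general': project the encoder states first, then the same bmm
attn_general = nn.Linear(H, H)
proj = attn_general(encoder_outputs.view(-1, H)).view(T, B, H)
general = torch.bmm(hidden.transpose(0, 1), proj.permute(1, 2, 0)).squeeze(1)

# 'concat': concatenate, project back to H, then contract with v
attn_concat = nn.Linear(2 * H, H)
v = torch.randn(1, H)
cat = torch.cat((hidden.repeat(T, 1, 1), encoder_outputs), dim=2)      # [T, B, 2H]
energy = attn_concat(cat.view(-1, 2 * H)).view(T, B, H)                # [T, B, H]
concat = torch.bmm(v.repeat(B, 1, 1), energy.permute(1, 2, 0)).squeeze(1)

for e in (dot, general, concat):
    assert e.shape == (B, T)   # one energy per (batch, encoder step)
```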

czs0x55aa avatar Sep 08 '17 03:09 czs0x55aa

@czs0x55aa Thanks a lot, I am training a textsum model with input of size 32×120, and your code helped me get past this problem.

sys1874 avatar Mar 14 '18 09:03 sys1874

Hi @czs0x55aa ,

Thanks for the faster version, but I get the following error:

Traceback (most recent call last):
  File "main.py", line 244, in <module>
    trainIters(encoder1, decoder1, training_pairs)
  File "main.py", line 142, in trainIters
    encoder_optimizer, decoder_optimizer, criterion)
  File "main.py", line 69, in train
    decoder_output, decoder_context, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_context, decoder_hidden, encoder_outputs)
  File "/home/magnet/bsrivast/asr/venv/lib/python3.5/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/magnet/bsrivast/asr/src/encoder_decoder.py", line 148, in forward
    attn_weights = self.attn(rnn_output.squeeze(0), encoder_outputs)
  File "/home/magnet/bsrivast/asr/venv/lib/python3.5/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/magnet/bsrivast/asr/src/encoder_decoder.py", line 77, in forward
    attn_energies = self.batch_score(hidden, encoder_outputs)
  File "/home/magnet/bsrivast/asr/src/encoder_decoder.py", line 107, in batch_score
    energy = torch.bmm(hidden.transpose(0, 1), energy.permute(1, 2, 0)).squeeze(1)
RuntimeError: invalid argument 6: expected 3D tensor at /pytorch/aten/src/THC/generic/THCTensorMathBlas.cu:466

I tried to unsqueeze(0) the hidden variable, but it throws a wrong-dimension error. Could you please help?

Thanks -Brij

brijmohan avatar Nov 13 '18 15:11 brijmohan
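Editorial note on the traceback above: `torch.bmm` requires 3D inputs, and the call `self.attn(rnn_output.squeeze(0), encoder_outputs)` in the traceback passes a 2D hidden tensor ([B, H] after `squeeze(0)`), so `hidden.transpose(0, 1)` inside `batch_score` is 2D where bmm expects [B, 1, H]. A minimal sketch of the mismatch (sizes are illustrative):

```python
import torch

B, T, H = 4, 7, 16
rnn_output = torch.randn(1, B, H)   # decoder GRU output, [1, B, H]
energy = torch.randn(T, B, H)

# batch_score expects hidden as [1, B, H]: transpose(0, 1) then yields
# the 3D [B, 1, H] tensor that torch.bmm requires.
ok = torch.bmm(rnn_output.transpose(0, 1), energy.permute(1, 2, 0)).squeeze(1)
assert ok.shape == (B, T)

# Squeezing the leading dim first (as in the traceback's
# rnn_output.squeeze(0)) makes hidden 2D, and bmm rejects it.
try:
    torch.bmm(rnn_output.squeeze(0).transpose(0, 1), energy.permute(1, 2, 0))
    raise AssertionError("expected bmm to reject a 2D input")
except RuntimeError:
    pass  # "expected 3D tensor", matching the reported error
```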