pytorch-seq2seq
Getting different shapes for Q, K in MultiHeadAttentionLayer (pytorch-seq2seq) when using DataParallel
In `class MultiHeadAttentionLayer(nn.Module)` I am getting the values below when printing the shapes of Q, K and the mask, and the batch size differs between prints. I am running the script through DataParallel (a minimal sketch of this setup follows the output below):

```python
gpus = hps.gpus
self.model = nn.DataParallel(self.model, device_ids=gpus, dim=0)
```
```
Shape of Q is torch.Size([231, 512, 8, 32])
Shape of K is torch.Size([231, 512, 8, 32])
Shape of mask is torch.Size([231, 1, 1, 512])
Shape of Q is torch.Size([230, 512, 8, 32])
Shape of K is torch.Size([230, 512, 8, 32])
Shape of mask is torch.Size([230, 1, 1, 512])
```
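For reference, here is a minimal, self-contained sketch of that setup. `DummyModel`, the device ids `[0, 1]`, and the batch of 461 are assumptions for illustration; only the `hps.gpus` / `nn.DataParallel(..., dim=0)` wrapping is taken from my actual script.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the real seq2seq model; it only reports the batch size it receives.
class DummyModel(nn.Module):
    def forward(self, src):
        print("per-replica batch size:", src.shape[0])
        return src

gpus = [0, 1]  # stand-in for hps.gpus; needs at least two visible GPUs
model = nn.DataParallel(DummyModel().cuda(), device_ids=gpus, dim=0)  # dim=0 scatters the batch axis

src = torch.randint(0, 100, (461, 512)).cuda()  # assumed full batch of 461, seq len 512
out = model(src)  # the 461 rows are split across the replicas (e.g. 231 and 230 on two GPUs)
```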
The prints are placed just before `energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale`. The shapes are printed as expected when I run the same code in Google Colab.
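To make the print placement concrete, below is a cut-down, runnable sketch of the layer up to the energy line. `hid_dim=256` and `n_heads=8` are inferred from the printed head count and head dimension; everything after the energy computation is omitted.

```python
import torch
import torch.nn as nn

class MultiHeadAttentionLayerSketch(nn.Module):
    # Cut-down sketch of MultiHeadAttentionLayer, kept only up to the energy
    # computation so the placement of the prints is visible.
    def __init__(self, hid_dim=256, n_heads=8):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.scale = self.head_dim ** 0.5

    def forward(self, query, key, mask=None):
        # Under DataParallel this is the per-GPU chunk size, not the dataloader batch size.
        batch_size = query.shape[0]
        Q = self.fc_q(query).view(batch_size, -1, self.n_heads, self.head_dim)
        K = self.fc_k(key).view(batch_size, -1, self.n_heads, self.head_dim)

        # The prints in question, just before the energy line.
        print("Shape of Q is", Q.shape)
        print("Shape of K is", K.shape)
        if mask is not None:
            print("Shape of mask is", mask.shape)

        Q = Q.permute(0, 2, 1, 3)  # [batch, n_heads, seq_len, head_dim]
        K = K.permute(0, 2, 1, 3)
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        return energy

layer = MultiHeadAttentionLayerSketch()
x = torch.randn(230, 512, 256)      # [per-replica batch, seq_len, hid_dim]
mask = torch.ones(230, 1, 1, 512)
energy = layer(x, x, mask)          # prints the same kind of shapes as above
```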