pytorch2keras
AttributeError: Can't gather from tf tensor.
Hi! Thanks for implementing this much-needed conversion 👍. I'm stuck on the error `AttributeError: Can't gather from tf tensor.` when trying to export my PyTorch model to Keras with:
`k_model = pytorch_to_keras(model, input_var, (1,None), verbose=True, name_policy='short', change_ordering=True)`
PyTorch model:
```python
class SentenceEmbeddingsModel(nn.Module):
    """Self-attentive sentence classifier over (optionally pretrained) word embeddings.

    Pipeline of ``forward``:

    1. Embed the token ids.
    2. Zero out padding positions (token id 0) and apply dropout.
    3. Compute ``r_a`` attention distributions over the sequence with a
       two-layer tanh scorer (``ws1``, ``ws2``).
    4. Take the ``r_a`` attention-weighted sums of the embeddings.
    5. Flatten them, project back to ``embedding_dim`` with ``self.linear``.
    6. Classify with a small MLP followed by a sigmoid.

    NOTE(review): the output width comes from a module-level ``num_classes``,
    which must exist before the model is instantiated. The ``device`` default
    is taken from a module-level ``device`` when the class is defined.

    Args:
        vocab_size: Number of rows of the randomly initialised embedding
            table. Only used when ``word_vectors`` is None.
        embedding_dim: Width of each embedding vector. It sizes ``ws1`` and
            ``self.linear``, so it must equal the second dimension of
            ``word_vectors`` when those are given.
        max_length: Accepted for compatibility; not used by this class.
        word_vectors: Pretrained embedding matrix, converted with
            ``torch.FloatTensor`` and loaded frozen. If None, a randomly
            initialised, trainable ``(vocab_size, embedding_dim)`` embedding
            is used instead (the original code crashed in that case).
        device: Stored as ``self.device``. ``forward`` no longer uses it; the
            padding mask follows the device of ``inputs``.
        C: Stored as ``self.C``; not read by ``forward``.
        d_a: Hidden width of the attention scorer.
        r_a: Number of attention rows (heads) per sentence.
        hidden_size: Stored as ``self.rnn_hidden_size``; not read by ``forward``.
    """

    def __init__(self,
                 vocab_size, embedding_dim,
                 max_length=40,
                 word_vectors=None,
                 device=device,
                 C=0.001, d_a=10, r_a=4,
                 hidden_size=100):
        super(SentenceEmbeddingsModel, self).__init__()
        self.d_a = d_a
        self.C = C
        self.r_a = r_a
        self.rnn_hidden_size = hidden_size
        self.device = device

        if word_vectors is None:
            # No pretrained vectors: keep a trainable random table; freezing
            # random vectors would make the embedding useless.
            self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        else:
            # from_pretrained is a classmethod; freeze=True (its default)
            # already sets requires_grad=False on the weight.
            self.embeddings = nn.Embedding.from_pretrained(
                torch.FloatTensor(word_vectors), freeze=True)

        ws_d = embedding_dim
        # Attention scorer weights, with a leading broadcast dimension of 1 so
        # that matmul broadcasts them over the batch.
        # nn.Parameter already has requires_grad=True.
        self.ws1 = nn.Parameter(torch.empty(1, self.d_a, ws_d, dtype=torch.float32))
        nn.init.xavier_uniform_(self.ws1)
        self.ws2 = nn.Parameter(torch.empty(1, self.r_a, self.d_a, dtype=torch.float32))
        nn.init.xavier_uniform_(self.ws2)

        self.dropout1 = nn.Dropout(0.1)
        self.dense = nn.Sequential(
            nn.Linear(ws_d, 20, bias=True),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(20, num_classes, bias=True),
        )
        # Projects the flattened (r_a * ws_d) sentence matrix back to ws_d.
        self.linear = nn.Linear(ws_d * self.r_a, ws_d)

    def forward(self, inputs):
        """Return sigmoid class scores for a batch of token-id sequences.

        ``inputs`` must be 2-D (the mask is unsqueezed to a trailing third
        axis and multiplied with the embeddings). Presumably the layout is
        (batch, seq_len) with 0 as padding id; confirm against the caller.
        """
        e = self.embeddings(inputs)
        # The mask is derived from `inputs`, so it is already on the same
        # device as `e`. Moving it to self.device broke CPU/GPU mismatches.
        mask = (inputs != 0).unsqueeze(-1).to(e.dtype)
        z = self.dropout1(e * mask)

        a1 = torch.tanh(self.ws1.matmul(z.transpose(1, 2)))
        attention = F.softmax(self.ws2.matmul(a1), dim=2)  # n_batch - r_a - max_length
        m = attention.matmul(z)  # n_batch - r_a - ws_d

        # Flatten the r_a x ws_d sentence matrix per example. This equals the
        # former `m.view(z.shape[0], -1, 1)[:, :, 0]` but avoids tracing a
        # shape lookup and an index-gather on non-constant tensors.
        # NOTE(review): those Gather ops are a likely trigger of the
        # converter's "Can't gather from tf tensor" error; confirm
        # against the pytorch2keras/onnx2keras version in use.
        flatten = torch.flatten(m, start_dim=1)
        m = self.linear(flatten)
        return torch.sigmoid(self.dense(m))
```
Is anyone able to help me?
Thanks
Any update on this error?
Hello,
Is there any update on this issue?
I'm facing this issue too. Any updates?