cond_rnn
PyTorch version available?
Hi,
cond_rnn is great work and I want to use it in my own work. Unfortunately, we use PyTorch more often. Could you also provide a PyTorch version, or give me some hints on how to implement it?
Best Regards, kingaza
Hi @kingaza, great idea. I will try to see how to do it in PyTorch. At the moment, I have no idea how.
Thanks. Actually, I have no idea how to implement it in PyTorch either; that is why I submitted this issue. :) I will also think it over.
Following up on the availability of a PyTorch version as well.
If you want to do it yourself, it's very easy. My BiRNN code looks like this:
import torch
import torch.nn as nn

class myBiRNN(nn.Module):
    def __init__(self, input_size, input_size_aux, hidden_size, RNN_type="LSTM"):  # further args elided
        super(myBiRNN, self).__init__()
        self.hidden_size = hidden_size
        self.RNN_type = RNN_type
        # MLP mapping the auxiliary inputs to the initial hidden state
        self.mlp1 = nn.Linear(input_size_aux, hidden_size)
        if self.RNN_type == 'LSTM':
            RNN_model = nn.LSTM
            # an LSTM also needs an initial cell state
            self.mlp2 = nn.Linear(input_size_aux, hidden_size)
        elif self.RNN_type == 'GRU':
            RNN_model = nn.GRU
        self.rnn1 = RNN_model(input_size, hidden_size, batch_first=True)
        self.rnn2 = RNN_model(hidden_size, hidden_size, batch_first=True)

    def forward(self, inputs, inputs_aux):
        # inputs are shaped (batch, Nseq, nx), inputs_aux are shaped (batch, nx_aux)
        h0 = self.mlp1(inputs_aux)
        h0 = torch.tanh(h0)
        if self.RNN_type == "LSTM":
            c0 = self.mlp2(inputs_aux)
            c0 = torch.tanh(c0)
            initial_states = (h0.view(1, -1, self.hidden_size),
                              c0.view(1, -1, self.hidden_size))
            out, (h, c) = self.rnn1(inputs, initial_states)
        elif self.RNN_type == "GRU":
            initial_states = h0.view(1, -1, self.hidden_size)
            out, h = self.rnn1(inputs, initial_states)
        # Bidirectional RNN, so the second RNN needs the reversed sequence,
        # and its outputs are also reversed
        out = torch.flip(out, [1])
        ...
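For anyone trying this out, a quick usage sketch (the shapes and values here are made up for illustration; note that forward as posted ends in an elision, so you would need to complete it to return out):

# Hypothetical smoke test: 8 sequences of length 20 with 5 features each,
# conditioned on a 3-dimensional auxiliary vector per sequence
model = myBiRNN(input_size=5, input_size_aux=3, hidden_size=32, RNN_type="LSTM")
inputs = torch.randn(8, 20, 5)     # (batch, Nseq, nx)
inputs_aux = torch.randn(8, 3)     # (batch, nx_aux)
out = model(inputs, inputs_aux)    # shape depends on how forward is completed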
Thanks @peterukk! Looks like we just need to initialize the hidden state with the latent representation of the aux variables.
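For anyone who just wants that core idea without the bidirectional machinery, here is a minimal self-contained sketch (my own naming, not cond_rnn's API): the aux vector is projected and tanh-squashed into the initial hidden and cell states, exactly as in the snippet above.

import torch
import torch.nn as nn

class CondLSTM(nn.Module):
    # Minimal conditional LSTM: the aux inputs determine the initial (h0, c0)
    def __init__(self, input_size, input_size_aux, hidden_size):
        super().__init__()
        self.aux_to_h = nn.Linear(input_size_aux, hidden_size)
        self.aux_to_c = nn.Linear(input_size_aux, hidden_size)
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)

    def forward(self, x, x_aux):
        # x: (batch, seq, input_size); x_aux: (batch, input_size_aux)
        h0 = torch.tanh(self.aux_to_h(x_aux)).unsqueeze(0)  # (1, batch, hidden)
        c0 = torch.tanh(self.aux_to_c(x_aux)).unsqueeze(0)  # (1, batch, hidden)
        out, _ = self.lstm(x, (h0, c0))
        return out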
One question about the backward LSTM: do you think it's a good idea to initialize its states the same way as for the forward one?
I would do whatever makes sense. If the auxiliary information is related to the start of the sequence, I would only initialize the forward one. If there is no clear relation to one particular end of the sequence, I would use it to initialize both.
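To make the two options concrete, here is a hypothetical continuation of the forward() sketch above for the LSTM case (none of this is from the original post):

# Option 1: reuse the aux-derived states for the backward RNN
# (no clear relation to either end of the sequence) ...
out2, _ = self.rnn2(out, initial_states)
# ... or option 2: start it from a zero state instead
# (aux info tied to the start of the sequence):
# out2, _ = self.rnn2(out)
out2 = torch.flip(out2, [1])   # flip back so time runs forward again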