
pytorch version available?

Open kingaza opened this issue 1 year ago • 7 comments

hi

Your cond_rnn is great work, and I would like to use it in my own project. Unfortunately, we mostly use PyTorch. Could you also provide a PyTorch version, or give me a hint on how to implement it?

Best Regards, kingaza

kingaza avatar Aug 24 '23 10:08 kingaza

@kingaza Hi, great idea. I will try to see how to do it in PyTorch. At the moment, I have no idea how to do it.

philipperemy avatar Aug 29 '23 02:08 philipperemy

Thanks. Actually, I have no idea how to implement it in PyTorch either; that is why I submitted this issue. :) I will also think it over.

kingaza avatar Aug 29 '23 06:08 kingaza

Following up on the availability of a PyTorch version as well.

XiaohuiZhang1996 avatar Jul 23 '24 22:07 XiaohuiZhang1996

If you want to do it yourself, it's very easy. My BiRNN code looks like this:

import torch
import torch.nn as nn


class myBiRNN(nn.Module):
    def __init__(self, input_size, input_size_aux, hidden_size, RNN_type="LSTM", ...):
        super(myBiRNN, self).__init__()
        self.hidden_size = hidden_size
        self.RNN_type = RNN_type
        # MLP mapping the auxiliary (conditioning) inputs to the initial hidden state
        self.mlp1 = nn.Linear(input_size_aux, hidden_size)

        if self.RNN_type == 'LSTM':
            RNN_model = nn.LSTM
            # An LSTM also needs an initial cell state, so use a second MLP for it
            self.mlp2 = nn.Linear(input_size_aux, hidden_size)
        elif self.RNN_type == 'GRU':
            RNN_model = nn.GRU

        self.rnn1 = RNN_model(input_size, hidden_size, batch_first=True)   # forward direction
        self.rnn2 = RNN_model(hidden_size, hidden_size, batch_first=True)  # backward direction

    def forward(self, inputs, inputs_aux):
        # inputs are shaped (batch, Nseq, nx), inputs_aux are shaped (batch, nx_aux)

        h0 = self.mlp1(inputs_aux)
        h0 = torch.tanh(h0)

        if self.RNN_type == "LSTM":
            c0 = self.mlp2(inputs_aux)
            c0 = torch.tanh(c0)
            # States are shaped (num_layers * num_directions, batch, hidden_size) = (1, batch, hidden_size)
            initial_states = (h0.view(1, -1, self.hidden_size), c0.view(1, -1, self.hidden_size))
            out, (h, c) = self.rnn1(inputs, initial_states)
        elif self.RNN_type == "GRU":
            initial_states = h0.view(1, -1, self.hidden_size)
            out, h = self.rnn1(inputs, initial_states)

        # Bidirectional RNN: the second RNN takes the reversed sequence, and its outputs are reversed back
        out = torch.flip(out, [1])
        ...

peterukk avatar Jul 24 '24 11:07 peterukk
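For context, a minimal usage sketch of the class above, assuming the elided parts of the constructor and forward pass are filled in and that it returns the sequence output; the tensor sizes below are purely illustrative:

import torch

# Hypothetical sizes, just for illustration
batch, Nseq, nx, nx_aux, hidden = 32, 10, 8, 4, 64

model = myBiRNN(input_size=nx, input_size_aux=nx_aux, hidden_size=hidden, RNN_type="LSTM")

inputs = torch.randn(batch, Nseq, nx)      # time-series inputs
inputs_aux = torch.randn(batch, nx_aux)    # per-sample conditioning variables

out = model(inputs, inputs_aux)            # forward pass conditioned on inputs_aux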

If you want to do it yourself, it's very easy. My BiRNN code looks like this: [code quoted above]

Thanks @peterukk! Looks like we just need to initialize the hidden state with a latent representation of the aux variables.

XiaohuiZhang1996 avatar Jul 24 '24 14:07 XiaohuiZhang1996
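For reference, a minimal single-direction sketch of that idea (a GRU conditioned by mapping the aux variables to its initial hidden state); the class and variable names are illustrative and not part of cond_rnn:

import torch
import torch.nn as nn

class CondGRU(nn.Module):
    """Illustrative sketch: a GRU whose initial hidden state comes from aux inputs."""
    def __init__(self, input_size, input_size_aux, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.cond = nn.Linear(input_size_aux, hidden_size)   # aux -> initial hidden state
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)

    def forward(self, x, x_aux):
        # x: (batch, seq_len, input_size), x_aux: (batch, input_size_aux)
        h0 = torch.tanh(self.cond(x_aux)).unsqueeze(0)       # (1, batch, hidden_size)
        out, h = self.gru(x, h0)
        return out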

If you want to do it yourself, it's very easy. My BiRNN code looks like this: [code quoted above]

One question about the backward LSTM: do you think it's a good idea to initialize its states in the same way as the forward one?

XiaohuiZhang1996 avatar Jul 25 '24 16:07 XiaohuiZhang1996

I would do whatever "makes sense". If the auxiliary information is related to the start of the sequence, I would only initialize the forward RNN. If there is no clear relation to one particular end of the sequence, I would use it to initialize both.

peterukk avatar Jul 26 '24 16:07 peterukk
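A sketch of the "initialize both" option, using PyTorch's built-in bidirectional LSTM rather than two separate RNNs: both directions receive an initial state derived from the aux variables. The names and the choice of one projection covering both directions are assumptions for illustration:

import torch
import torch.nn as nn

class CondBiLSTM(nn.Module):
    """Illustrative sketch: condition both directions of a bidirectional LSTM on aux inputs."""
    def __init__(self, input_size, input_size_aux, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.h0_proj = nn.Linear(input_size_aux, 2 * hidden_size)   # one slice per direction
        self.c0_proj = nn.Linear(input_size_aux, 2 * hidden_size)
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True, bidirectional=True)

    def forward(self, x, x_aux):
        # x: (batch, seq_len, input_size), x_aux: (batch, input_size_aux)
        batch = x.size(0)
        # States must be shaped (num_layers * num_directions, batch, hidden_size) = (2, batch, hidden_size)
        h0 = torch.tanh(self.h0_proj(x_aux)).view(batch, 2, self.hidden_size).transpose(0, 1).contiguous()
        c0 = torch.tanh(self.c0_proj(x_aux)).view(batch, 2, self.hidden_size).transpose(0, 1).contiguous()
        out, (h, c) = self.lstm(x, (h0, c0))
        return out   # (batch, seq_len, 2 * hidden_size)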