ConvLSTM_pytorch
Does the stateful implementation work similarly to LSTM statefulness?
Thank you for the implementation. Can it be used to train video classifiers where the input sequences have variable length? Also, can the statefulness of the network be used to run inference on single frames (by reusing the states from the previous input), similar to the LSTM implementation? Currently I use the following definition of ConvLSTM, where the input sequence length needs to be defined. If testing happens on a different sequence length, the metrics are impacted (classification worsens):
```python
import torch
from torch import nn


class ConvLSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # One convolution computes all four gates from [input, previous hidden].
        # stride must stay 1 (and padding = kernel_size // 2) so the gates keep
        # the input's spatial size and match the previous cell state.
        self.Gates = nn.Conv2d(input_size + hidden_size, 4 * hidden_size,
                               kernel_size=kernel_size, stride=stride, padding=padding)
        torch.nn.init.xavier_normal_(self.Gates.weight)
        torch.nn.init.constant_(self.Gates.bias, 0)

    def forward(self, input_, prev_state):
        batch_size = input_.size(0)
        spatial_size = input_.size()[2:]

        # Start from zero hidden and cell states, on the input's device and dtype.
        if prev_state is None:
            state_size = [batch_size, self.hidden_size] + list(spatial_size)
            prev_state = (
                torch.zeros(state_size, device=input_.device, dtype=input_.dtype),
                torch.zeros(state_size, device=input_.device, dtype=input_.dtype),
            )

        prev_hidden, prev_cell = prev_state

        # Concatenate along the channel dimension and split into the four gates.
        stacked_inputs = torch.cat((input_, prev_hidden), 1)
        gates = self.Gates(stacked_inputs)
        in_gate, remember_gate, out_gate, cell_gate = gates.chunk(4, 1)

        in_gate = torch.sigmoid(in_gate)
        remember_gate = torch.sigmoid(remember_gate)
        out_gate = torch.sigmoid(out_gate)
        cell_gate = torch.tanh(cell_gate)

        # Standard LSTM cell and hidden-state updates.
        cell = (remember_gate * prev_cell) + (in_gate * cell_gate)
        hidden = out_gate * torch.tanh(cell)

        return hidden, cell
```
and I instantiate the layer with `conv_lstm = ConvLSTMCell(input_size, hidden_mem_size)`.
But I am stuck with fixed input sequence lengths and statelessness. Does your implementation overcome these problems?
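A minimal sketch, assuming clips shaped `(batch, time, channels, height, width)`: because a cell like the one above processes one frame per call, the sequence length is set only by the loop that drives it, and "statefulness" amounts to keeping the returned `(hidden, cell)` tuple between calls. The cell below is a compact copy of the one above (omitting the custom weight initialisation), and `StreamingClassifier`, `run`, and the pooling + linear head are hypothetical names chosen for illustration, not part of ConvLSTM_pytorch or the code above.

```python
import torch
from torch import nn


class ConvLSTMCell(nn.Module):
    """Compact copy of the cell above; zero state when prev_state is None."""

    def __init__(self, input_size, hidden_size, kernel_size=3, padding=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.Gates = nn.Conv2d(input_size + hidden_size, 4 * hidden_size,
                               kernel_size=kernel_size, padding=padding)

    def forward(self, input_, prev_state):
        if prev_state is None:
            b, _, h, w = input_.shape
            zeros = input_.new_zeros((b, self.hidden_size, h, w))
            prev_state = (zeros, zeros)
        prev_hidden, prev_cell = prev_state
        gates = self.Gates(torch.cat((input_, prev_hidden), 1))
        i, f, o, g = gates.chunk(4, 1)
        cell = torch.sigmoid(f) * prev_cell + torch.sigmoid(i) * torch.tanh(g)
        hidden = torch.sigmoid(o) * torch.tanh(cell)
        return hidden, cell


class StreamingClassifier(nn.Module):
    """Hypothetical wrapper: ConvLSTM cell + global average pooling + linear head."""

    def __init__(self, in_channels, hidden_size, num_classes):
        super().__init__()
        self.cell = ConvLSTMCell(in_channels, hidden_size)
        self.head = nn.Linear(hidden_size, num_classes)

    def forward(self, frame, state=None):
        # frame: (batch, channels, height, width); state: (hidden, cell) or None.
        hidden, cell = self.cell(frame, state)
        logits = self.head(hidden.mean(dim=(2, 3)))  # pool over height and width
        return logits, (hidden, cell)

    def run(self, clip, state=None):
        # clip: (batch, time, channels, height, width); time may differ per call
        # (assumed >= 1). Returns logits of the last frame and the final state.
        for t in range(clip.size(1)):
            logits, state = self(clip[:, t], state)
        return logits, state


torch.manual_seed(0)
model = StreamingClassifier(in_channels=3, hidden_size=8, num_classes=5).eval()

with torch.no_grad():
    # Clips of different lengths go through the same model.
    short_logits, _ = model.run(torch.randn(2, 4, 3, 16, 16))
    clip = torch.randn(2, 9, 3, 16, 16)
    full_logits, _ = model.run(clip)

    # Stateful inference: process the first 5 frames, keep the state,
    # then continue with the remaining frames in a separate call.
    _, state = model.run(clip[:, :5])
    streamed_logits, _ = model.run(clip[:, 5:], state)

print(short_logits.shape, full_logits.shape)
print(torch.allclose(full_logits, streamed_logits))
```

Carrying the state across calls gives the same logits as processing the whole clip at once, which is what single-frame inference needs: call `model(frame, state)` once per incoming frame and keep the returned state. Within one batch all clips in this sketch still share one length; mixing lengths in a batch would need padding plus reading each clip's output at its true last frame, which the sketch does not do.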