ATFM icon indicating copy to clipboard operation
ATFM copied to clipboard

Applying multiplicative LSTM in ATFM model for short term prediction

Open basharathussain opened this issue 4 years ago • 0 comments

I need your help incorporating an mLSTM model. I added the following implementation. File: # model.layer.conv_lstm

class mLSTMCell(nn.Module):
    """Multiplicative LSTM (mLSTM) cell over embedded integer inputs.

    The cell embeds the input indices, builds a multiplicative intermediate
    state ``m_t = (W_mx x_t) * (W_mh h_{t-1})`` and uses ``m_t`` in place of
    the previous hidden state in every gate of a standard LSTM update.

    NOTE(review): the cell is built from ``nn.Embedding`` and ``nn.Linear``
    layers, so it consumes integer index tensors and 2-D ``(batch, hidden)``
    states. It is not a drop-in replacement for a convolutional cell that
    receives 4-D feature maps -- confirm against the caller.

    Args:
        input_size: Number of distinct input indices (embedding rows).
        hidden_size: Size of the hidden and cell states.
        embed_size: Size of the input embedding.
        output_size: Output size of the ``decoder`` projection.
        in_channels: Optional metadata stored on the module as-is.
        height: Optional metadata stored on the module as-is.
        width: Optional metadata stored on the module as-is.
        lstm_channels: Optional metadata stored on the module as-is.
    """

    def __init__(self, input_size, hidden_size, embed_size, output_size,
                 in_channels=None, height=None, width=None,
                 lstm_channels=None):
        super(mLSTMCell, self).__init__()

        # These were previously read from undefined names (NameError); they
        # are now optional keyword arguments that default to None.
        self.in_channels = in_channels
        self.height = height
        self.width = width
        self.lstm_channels = lstm_channels

        self.hidden_size = hidden_size
        # input embedding
        self.encoder = nn.Embedding(input_size, embed_size)
        # gate weights applied to the multiplicative state m_t
        self.weight_fm = nn.Linear(hidden_size, hidden_size)
        self.weight_im = nn.Linear(hidden_size, hidden_size)
        self.weight_cm = nn.Linear(hidden_size, hidden_size)
        self.weight_om = nn.Linear(hidden_size, hidden_size)
        # gate weights applied to the embedded input
        self.weight_fx = nn.Linear(embed_size, hidden_size)
        self.weight_ix = nn.Linear(embed_size, hidden_size)
        self.weight_cx = nn.Linear(embed_size, hidden_size)
        self.weight_ox = nn.Linear(embed_size, hidden_size)
        # multiplicative weights
        self.weight_mh = nn.Linear(hidden_size, hidden_size)
        self.weight_mx = nn.Linear(embed_size, hidden_size)
        # decoder (registered on the module; forward() does not apply it)
        self.decoder = nn.Linear(hidden_size, output_size)

    def forward(self, input, state):
        """Run one mLSTM step.

        Args:
            input: Integer index tensor passed to the embedding layer.
            state: Tuple ``(h_0, c_0)`` of previous hidden and cell states.

        Returns:
            Tuple ``(hx, cx)`` with the new hidden and cell states.
        """
        h_0, c_0 = state
        # encode the input indices
        inp = self.encoder(input)
        # multiplicative intermediate state
        m_t = self.weight_mx(inp) * self.weight_mh(h_0)
        # forget, input and output gates
        f_g = torch.sigmoid(self.weight_fx(inp) + self.weight_fm(m_t))
        i_g = torch.sigmoid(self.weight_ix(inp) + self.weight_im(m_t))
        o_g = torch.sigmoid(self.weight_ox(inp) + self.weight_om(m_t))
        # candidate cell state
        c_tilda = torch.tanh(self.weight_cx(inp) + self.weight_cm(m_t))
        # current cell state and hidden state
        cx = f_g * c_0 + i_g * c_tilda
        hx = o_g * torch.tanh(cx)
        return hx, cx

    def init_hidden(self):
        """Return zero ``(h_0, c_0)`` of shape ``(1, hidden_size)``.

        The states are created on the same device and with the same dtype as
        the module's parameters, so the cell works on CPU and GPU alike.
        """
        ref = self.weight_mh.weight
        shape = (1, self.hidden_size)
        h_0 = torch.zeros(shape, dtype=ref.dtype, device=ref.device)
        c_0 = torch.zeros(shape, dtype=ref.dtype, device=ref.device)
        return h_0, c_0

and change the lstm_layer function as follows:

def lstm_layer(self, inputs):
    """Run the recurrent cell over channel-wise slices of ``inputs``.

    ``inputs`` is split along dim 1 into chunks of ``self.in_channels``
    channels, and the chunks are fed to ``self._lstm_cell`` in reverse
    order. For ``model_type == 'gru'`` the state is a single tensor; for
    every other model type it is an ``(h, c)`` tuple and the hidden output
    is ``state[0]``.

    Args:
        inputs: 4-D tensor; its size is unpacked as ``(n, c, h, w)``.

    Returns:
        The hidden output of the last processed step when
        ``self.all_hidden`` is false; otherwise all hidden outputs
        concatenated along dim 1, most recent step first.

    Raises:
        ValueError: If ``self.last_conv`` is set but ``self.conv_channels``
            is None.
    """
    n_in, _, h_in, w_in = inputs.size()

    # Validate once up front instead of on every loop iteration.
    if self.last_conv and self.conv_channels is None:
        raise ValueError('Parameter Out Channel is needed to enable last_conv')

    def zero_state():
        # Follow the device/dtype of the inputs rather than forcing CUDA, so
        # the layer also runs on CPU and on whichever GPU holds the data.
        return torch.zeros(n_in, self.lstm_channels, h_in, w_in,
                           dtype=inputs.dtype, device=inputs.device)

    is_gru = self.model_type == 'gru'
    state = zero_state() if is_gru else (zero_state(), zero_state())

    hidden_list = []
    # Process the chunks in reverse order.
    for step_input in reversed(torch.split(inputs, self.in_channels, dim=1)):
        state = self._lstm_cell(step_input, state)
        hidden = state if is_gru else state[0]
        if self.last_conv:
            hidden = self._conv_layer(hidden)
        hidden_list.append(hidden)

    if not self.all_hidden:
        return hidden_list[-1]
    hidden_list.reverse()
    return torch.cat(hidden_list, 1)

def forward(self, inputs):
    """Apply optional dropout, then run the recurrent layer per ``self.mode``.

    Modes:
        'merge': run ``lstm_layer`` on the whole input.
        'cpt': split the input with ``split_cpt(inputs, self.cpt)``, run
            ``lstm_layer`` on each part and concatenate the results along
            dim 1.

    Args:
        inputs: Input tensor forwarded to the dropout / recurrent layers.

    Returns:
        The output of ``lstm_layer`` ('merge') or the concatenation of the
        per-part outputs ('cpt').

    Raises:
        ValueError: If mode is 'cpt' but ``self.cpt`` is None, or if
            ``self.mode`` is not a supported mode.
    """
    if self.dropout_rate > 0:
        inputs = self._dropout_layer(inputs)

    if self.mode == 'merge':
        return self.lstm_layer(inputs)

    if self.mode == 'cpt':
        if self.cpt is None:
            raise ValueError('Parameter \'cpt\' is required in mode \'cpt\' of ConvLSTM')
        outputs = [self.lstm_layer(part) for part in split_cpt(inputs, self.cpt)]
        return torch.cat(outputs, 1)

    # Raising a bare string is itself a TypeError in Python 3; raise a real
    # exception so the caller sees the intended message.
    raise ValueError('Invalid LSTM mode: {}'.format(self.mode))

I am getting an error and it is not working. If you could help me implement mLSTM, I would be grateful.

In case you need the modified source code or any other information, please write here. Thank you. Basharat

basharathussain avatar Aug 15 '20 14:08 basharathussain