Nested-LSTM-NLSTM-Pytorch

Subtle difference between this implementation and the original paper

Open · dugu9sword opened this issue 5 years ago · 1 comment

Hi, it seems that this implementation differs from the model described in the original paper. In line #93:

    cell = torch.cat(((remember_gate * c_cur), (in_gate * cell_gate)),1)

I will refer to (remember_gate * c_cur) and (in_gate * cell_gate) as A and B, respectively.

In your code, A and B are concatenated and used as the input to the nested cell. In the original paper, A is used as the hidden state of the nested cell and B is used as its input: $\tilde{h}_{t-1} = f_t \odot c_{t-1}$ and $\tilde{x}_t = i_t \odot g_t$.

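In other words, the nested cell is not an ordinary recurrent LSTM cell. The hidden state $\tilde{h}_t$ that it produces at step $t$ becomes the outer cell state $c_t$. The hidden state fed back into it at step $t+1$ is $f_{t+1} \odot c_t$, which is not the same.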

The description in the paper's abstract may be a little confusing, but the main text of the paper and a blog post (https://distill.pub/2019/memorization-in-rnns/) clarify this design.

dugu9sword avatar Sep 11 '19 14:09 dugu9sword

Below is my implementation:

import torch
from torch import nn
from torch.nn import functional, init
from weight_drop import WeightDrop



class LSTMCell(nn.Module):
    """A basic LSTM cell.

    Gate pre-activations are ``input_ @ weight_ih + h_0 @ weight_hh + bias``,
    split into four chunks (f, i, o, g) of ``hidden_size`` columns each.
    """

    def __init__(self, input_size, hidden_size, use_bias=True):
        """
        Most parts are copied from torch.nn.LSTMCell.

        Args:
            input_size: Number of features in each input vector.
            hidden_size: Number of features in the hidden and cell state.
            use_bias: If False, no bias parameter is created and the gate
                pre-activations contain only the two matrix products.
        """

        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.use_bias = use_bias
        # Weights are stored as (in_features, 4 * hidden_size) so that
        # forward() right-multiplies them: x @ W.
        self.weight_ih = nn.Parameter(
            torch.FloatTensor(input_size, 4 * hidden_size))
        self.weight_hh = nn.Parameter(
            torch.FloatTensor(hidden_size, 4 * hidden_size))
        if use_bias:
            self.bias = nn.Parameter(torch.FloatTensor(4 * hidden_size))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """
        Initialize parameters following the way proposed in the paper.

        ``weight_ih`` is orthogonal, and ``weight_hh`` is four identity
        matrices stacked horizontally (one per gate). The bias, if present,
        is set to zero.
        """

        init.orthogonal_(self.weight_ih.data)
        with torch.no_grad():
            # copy_ writes into the parameter's existing storage, so its
            # device and dtype are kept. The previous set_ call rebound the
            # storage to a new CPU tensor, which breaks after .cuda() or
            # .double().
            self.weight_hh.copy_(torch.eye(self.hidden_size).repeat(1, 4))

        if self.use_bias:
            init.constant_(self.bias.data, val=0)

    def forward(self, input_, hx):
        """
        Args:
            input_: A (batch, input_size) tensor containing input
                features.
            hx: A tuple (h_0, c_0), which contains the initial hidden
                and cell state, where the size of both states is
                (batch, hidden_size).
        Returns:
            h_1, c_1: Tensors containing the next hidden and cell state.
        """

        h_0, c_0 = hx
        if self.bias is not None:
            # addmm broadcasts the (4 * hidden_size,) bias over the batch.
            wh_b = torch.addmm(self.bias, h_0, self.weight_hh)
        else:
            wh_b = torch.mm(h_0, self.weight_hh)
        wi = torch.mm(input_, self.weight_ih)
        f, i, o, g = torch.split(wh_b + wi, self.hidden_size, dim=1)
        c_1 = torch.sigmoid(f) * c_0 + torch.sigmoid(i) * torch.tanh(g)
        h_1 = torch.sigmoid(o) * torch.tanh(c_1)
        return h_1, c_1

    def __repr__(self):
        s = '{name}({input_size}, {hidden_size})'
        return s.format(name=self.__class__.__name__, **self.__dict__)




class NLSTMCell(nn.Module):
    """A nested LSTM cell.

    The outer gates (f, i, o, g) are computed as in an ordinary LSTM. Instead
    of the usual ``c_1 = sigmoid(f) * c_0 + sigmoid(i) * tanh(g)``, the two
    terms are handed to an inner ``LSTMCell``:

    * ``sigmoid(f) * c_0`` is the inner cell's hidden state, and
    * ``sigmoid(i) * tanh(g)`` is the inner cell's input.

    The inner cell's new hidden state becomes the outer cell state ``c_1``.
    Its new cell state ``c_inner_1`` is carried to the next step.
    """

    def __init__(self, input_size, hidden_size, use_bias=True):
        """
        Most parts are copied from torch.nn.LSTMCell.

        Args:
            input_size: Number of features in each input vector.
            hidden_size: Number of features in the hidden, cell and inner
                cell state.
            use_bias: If False, the outer gates have no bias. The inner
                LSTMCell is always created with a bias.
        """

        super(NLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.use_bias = use_bias
        # Weights are stored as (in_features, 4 * hidden_size) so that
        # forward() right-multiplies them: x @ W.
        self.weight_ih = nn.Parameter(
            torch.FloatTensor(input_size, 4 * hidden_size))
        self.weight_hh = nn.Parameter(
            torch.FloatTensor(hidden_size, 4 * hidden_size))
        if use_bias:
            self.bias = nn.Parameter(torch.FloatTensor(4 * hidden_size))
        else:
            self.register_parameter('bias', None)
        self.inner_lstm_cell = LSTMCell(hidden_size, hidden_size, True)
        self.reset_parameters()

    def reset_parameters(self):
        """
        Initialize parameters following the way proposed in the paper.

        ``weight_ih`` is orthogonal, and ``weight_hh`` is four identity
        matrices stacked horizontally. The bias, if present, is zero. The
        inner cell is re-initialized as well.
        """

        init.orthogonal_(self.weight_ih.data)
        with torch.no_grad():
            # copy_ keeps the parameter's own storage, device and dtype. The
            # previous set_ call rebound the storage to a new CPU tensor.
            self.weight_hh.copy_(torch.eye(self.hidden_size).repeat(1, 4))

        if self.use_bias:
            init.constant_(self.bias.data, val=0)

        self.inner_lstm_cell.reset_parameters()

    def forward(self, input_, hx):
        """
        Args:
            input_: A (batch, input_size) tensor containing input
                features.
            hx: A tuple (h_0, c_0, c_inner_0), which contains the initial
                hidden, cell and inner cell state, where the size of all
                three states is (batch, hidden_size).
        Returns:
            h_1, c_1, c_inner_1: Tensors containing the next hidden, cell
                and inner cell state.
        """

        h_0, c_0, c_inner_0 = hx
        if self.bias is not None:
            # addmm broadcasts the (4 * hidden_size,) bias over the batch.
            wh_b = torch.addmm(self.bias, h_0, self.weight_hh)
        else:
            wh_b = torch.mm(h_0, self.weight_hh)
        wi = torch.mm(input_, self.weight_ih)
        f, i, o, g = torch.split(wh_b + wi, self.hidden_size, dim=1)

        # Nested memory: the forget-gated old state is the inner hidden
        # state, and the input-gated candidate is the inner input.
        inner_hidden = torch.sigmoid(f) * c_0
        inner_input = torch.sigmoid(i) * torch.tanh(g)
        h_inner_1, c_inner_1 = self.inner_lstm_cell(
            inner_input, (inner_hidden, c_inner_0))

        c_1 = h_inner_1
        h_1 = torch.sigmoid(o) * torch.tanh(c_1)
        return h_1, c_1, c_inner_1

    def __repr__(self):
        s = '{name}({input_size}, {hidden_size})'
        return s.format(name=self.__class__.__name__, **self.__dict__)






class NLSTM(nn.Module):
    """A module that runs multiple steps of a nested LSTM cell over a sequence.

    Each layer's cell is called as ``cell(input_=x_t, hx=(h, c, c_inner))``
    and must return a triple ``(h, c, c_inner)``, as ``NLSTMCell`` does.
    """

    def __init__(self, cell_class, input_size, hidden_size, num_layers,
                 use_bias=True, batch_first=False, dropout=0, weight_dropout=0, **kwargs):
        """
        Args:
            cell_class: Cell constructor, called once per layer as
                ``cell_class(input_size=..., hidden_size=..., **kwargs)``.
            input_size: Number of input features for the first layer.
            hidden_size: Hidden size of every layer. It is also the input size
                of layers after the first.
            num_layers: Number of stacked layers.
            use_bias: Stored on the module. NOTE(review): it is not forwarded
                to ``cell_class``, so it currently has no effect on the cells.
            batch_first: If True, inputs and outputs are (batch, time, ...)
                instead of (time, batch, ...).
            dropout: Dropout probability applied to each layer's output
                before it is fed to the next layer. The last layer's output
                is returned undropped.
            weight_dropout: If non-zero, each cell is wrapped with
                ``WeightDrop`` on its ``weight_hh``.
            **kwargs: Extra keyword arguments forwarded to ``cell_class``.
        """
        super(NLSTM, self).__init__()
        self.num_layers = num_layers
        self.cell_class = cell_class
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.use_bias = use_bias
        self.batch_first = batch_first
        self.dropout = dropout
        self.weight_drops = None if weight_dropout == 0 else []

        for layer in range(num_layers):
            layer_input_size = input_size if layer == 0 else hidden_size
            cell = cell_class(input_size=layer_input_size,
                              hidden_size=hidden_size,
                              **kwargs)
            setattr(self, 'cell_{}'.format(layer), cell)
        self.dropout_layer = nn.Dropout(dropout)
        self.reset_parameters()
        if self.weight_drops is not None:
            for layer in range(num_layers):
                self.weight_drops.append(WeightDrop(self.get_cell(layer), ['weight_hh'], weight_dropout))

    def get_cell(self, layer):
        """Return the cell module of the given layer index."""
        return getattr(self, 'cell_{}'.format(layer))

    def reset_parameters(self):
        """Re-initialize the parameters of every layer's cell."""
        for layer in range(self.num_layers):
            cell = self.get_cell(layer)
            cell.reset_parameters()

    @staticmethod
    def _forward_rnn(cell, input_, length, hx):
        """Run ``cell`` over all time steps of ``input_`` (time-major).

        For batch entries whose ``length`` has been reached, all three states
        are frozen at their last valid value.

        Returns:
            output: Stacked hidden states, (time, batch, hidden_size).
            hx: Final ``(h, c, c_inner)`` triple.
        """
        device = input_.device
        max_time = input_.size(0)
        output = []
        for time in range(max_time):
            h_next, c_next, cin_next = cell(input_=input_[time], hx=hx)
            # 1 where this step is inside the sequence, 0 where it is padding.
            mask = (time < length).float().unsqueeze(1).expand_as(h_next).to(device)
            h_next = h_next * mask + hx[0] * (1 - mask)
            c_next = c_next * mask + hx[1] * (1 - mask)
            # Padded steps keep the previous *inner* state (hx[2]). The
            # previous code used hx[1], which overwrote it with the outer
            # cell state.
            cin_next = cin_next * mask + hx[2] * (1 - mask)
            hx = (h_next, c_next, cin_next)
            output.append(h_next)
        output = torch.stack(output, 0)
        return output, hx

    def forward(self, input_, length=None, hx=None):
        """
        Args:
            input_: (time, batch, input_size) tensor, or
                (batch, time, input_size) if ``batch_first``.
            length: Optional (batch,) tensor of valid sequence lengths.
                Defaults to the full length for every entry.
            hx: Optional ``(h_0, c_0, c_inner_0)`` triple, each
                (batch, hidden_size), used as the initial state of every
                layer. Defaults to zeros.
        Returns:
            output: Hidden states of the last layer, in the same layout
                as ``input_``.
            (h_n, c_n, cin_n): Final states of each layer, each
                (num_layers, batch, hidden_size).
        """
        if self.batch_first:
            input_ = input_.transpose(0, 1)
        max_time, batch_size, _ = input_.size()
        if length is None:
            length = torch.LongTensor([max_time] * batch_size)
        if hx is None:
            zeros = input_.new_zeros(batch_size, self.hidden_size)
            hx = (zeros, zeros, zeros)
        h_n = []
        c_n = []
        cin_n = []
        layer_output = None

        # APPLY WEIGHT DROP
        if self.weight_drops:
            for weight_drop in self.weight_drops:
                weight_drop._setweights()

        for layer in range(self.num_layers):
            cell = self.get_cell(layer)
            layer_output, (layer_h_n, layer_c_n, layer_cin_n) = NLSTM._forward_rnn(
                cell=cell, input_=input_, length=length, hx=hx)
            input_ = self.dropout_layer(layer_output)
            h_n.append(layer_h_n)
            c_n.append(layer_c_n)
            cin_n.append(layer_cin_n)
        output = layer_output
        h_n = torch.stack(h_n, 0)
        c_n = torch.stack(c_n, 0)
        cin_n = torch.stack(cin_n, 0)

        if self.batch_first:
            output = output.transpose(0, 1)
        return output, (h_n, c_n, cin_n)

dugu9sword avatar Sep 11 '19 14:09 dugu9sword