Nested-LSTM-NLSTM-Pytorch
subtle difference between this implementation and the original paper
Hi, it seems that this implementation differs from the model described in the original paper. In line #93:
cell = torch.cat(((remember_gate * c_cur), (in_gate * cell_gate)), 1)
I will refer to (remember_gate * c_cur) as A and (in_gate * cell_gate) as B. In your code, A and B are concatenated and used as the input to the nested cell, whereas in the original paper A is employed as the hidden state of the nested cell and B as its input.
In other words, the nested cell is not equivalent to a traditional LSTM cell, since the hidden state $h_t$ it produces and the hidden state it is fed at step $t+1$ are not the same.
The description in the abstract of the paper may be a little confusing, but the main text of the paper and another blog post (https://distill.pub/2019/memorization-in-rnns/) clarify this design.
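For concreteness, this is how I read the memory update in the paper (my own notation; $\odot$ is element-wise multiplication, $i_t, f_t, o_t$ are the sigmoid-activated outer gates, $g_t$ is the tanh-activated candidate, and $\tilde{c}_t$ is the inner cell state):

$$\tilde{x}_t = i_t \odot g_t, \qquad \tilde{h}_{t-1} = f_t \odot c_{t-1}$$

$$(\tilde{h}_t, \tilde{c}_t) = \mathrm{LSTM}_{\text{inner}}(\tilde{x}_t, \tilde{h}_{t-1}, \tilde{c}_{t-1}), \qquad c_t = \tilde{h}_t, \qquad h_t = o_t \odot \tanh(c_t)$$

That is, B ($= \tilde{x}_t$) enters the inner cell as its input and A ($= \tilde{h}_{t-1}$) replaces its hidden state; inside the nested cell only $\tilde{c}_t$ is carried over from the previous step, which is exactly why it is not a standard LSTM cell.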
Below is my implementation:
import torch
from torch import nn
from torch.nn import init

from weight_drop import WeightDrop  # local helper used to apply dropout to the recurrent weights


class LSTMCell(nn.Module):
    """A basic LSTM cell."""

    def __init__(self, input_size, hidden_size, use_bias=True):
        """
        Most parts are copied from torch.nn.LSTMCell.
        """
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.use_bias = use_bias
        self.weight_ih = nn.Parameter(
            torch.FloatTensor(input_size, 4 * hidden_size))
        self.weight_hh = nn.Parameter(
            torch.FloatTensor(hidden_size, 4 * hidden_size))
        if use_bias:
            self.bias = nn.Parameter(torch.FloatTensor(4 * hidden_size))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """
        Initialize parameters following the way proposed in the paper.
        """
        init.orthogonal_(self.weight_ih.data)
        # Hidden-to-hidden weights: four identity blocks, one per gate.
        weight_hh_data = torch.eye(self.hidden_size)
        weight_hh_data = weight_hh_data.repeat(1, 4)
        with torch.no_grad():
            self.weight_hh.set_(weight_hh_data)
        # Alternative: the default torch.nn.LSTMCell initialization.
        # import math
        # stdv = 1.0 / math.sqrt(self.hidden_size)
        # for weight in self.parameters():
        #     init.uniform_(weight, -stdv, stdv)
        # The bias is just set to zero vectors.
        if self.use_bias:
            init.constant_(self.bias.data, val=0)

    def forward(self, input_, hx):
        """
        Args:
            input_: A (batch, input_size) tensor containing input features.
            hx: A tuple (h_0, c_0), which contains the initial hidden
                and cell state, where the size of both states is
                (batch, hidden_size).

        Returns:
            h_1, c_1: Tensors containing the next hidden and cell state.
        """
        h_0, c_0 = hx
        batch_size = h_0.size(0)
        bias_batch = (self.bias.unsqueeze(0)
                      .expand(batch_size, *self.bias.size()))
        # Gate pre-activations: (h_0 @ W_hh + b) + (x @ W_ih), split into f, i, o, g.
        wh_b = torch.addmm(bias_batch, h_0, self.weight_hh)
        wi = torch.mm(input_, self.weight_ih)
        f, i, o, g = torch.split(wh_b + wi, self.hidden_size, dim=1)
        c_1 = torch.sigmoid(f) * c_0 + torch.sigmoid(i) * torch.tanh(g)
        h_1 = torch.sigmoid(o) * torch.tanh(c_1)
        return h_1, c_1

    def __repr__(self):
        s = '{name}({input_size}, {hidden_size})'
        return s.format(name=self.__class__.__name__, **self.__dict__)
class NLSTMCell(nn.Module):
    """A Nested LSTM cell whose memory update is computed by an inner LSTM cell."""

    def __init__(self, input_size, hidden_size, use_bias=True):
        """
        Most parts are copied from torch.nn.LSTMCell.
        """
        super(NLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.use_bias = use_bias
        self.weight_ih = nn.Parameter(
            torch.FloatTensor(input_size, 4 * hidden_size))
        self.weight_hh = nn.Parameter(
            torch.FloatTensor(hidden_size, 4 * hidden_size))
        if use_bias:
            self.bias = nn.Parameter(torch.FloatTensor(4 * hidden_size))
        else:
            self.register_parameter('bias', None)
        self.inner_lstm_cell = LSTMCell(hidden_size, hidden_size, True)
        self.reset_parameters()

    def reset_parameters(self):
        """
        Initialize parameters following the way proposed in the paper.
        """
        init.orthogonal_(self.weight_ih.data)
        weight_hh_data = torch.eye(self.hidden_size)
        weight_hh_data = weight_hh_data.repeat(1, 4)
        with torch.no_grad():
            self.weight_hh.set_(weight_hh_data)
        # Alternative: the default torch.nn.LSTMCell initialization.
        # import math
        # stdv = 1.0 / math.sqrt(self.hidden_size)
        # for weight in self.parameters():
        #     init.uniform_(weight, -stdv, stdv)
        # The bias is just set to zero vectors.
        if self.use_bias:
            init.constant_(self.bias.data, val=0)
        self.inner_lstm_cell.reset_parameters()

    def forward(self, input_, hx):
        """
        Args:
            input_: A (batch, input_size) tensor containing input features.
            hx: A tuple (h_0, c_0, c_inner_0) containing the initial hidden
                state, outer cell state, and inner cell state, each of size
                (batch, hidden_size).

        Returns:
            h_1, c_1, c_inner_1: Tensors containing the next hidden state,
                outer cell state, and inner cell state.
        """
        h_0, c_0, c_inner_0 = hx
        batch_size = h_0.size(0)
        bias_batch = (self.bias.unsqueeze(0)
                      .expand(batch_size, *self.bias.size()))
        wh_b = torch.addmm(bias_batch, h_0, self.weight_hh)
        wi = torch.mm(input_, self.weight_ih)
        f, i, o, g = torch.split(wh_b + wi, self.hidden_size, dim=1)
        # A standard LSTM would compute
        #     c_1 = torch.sigmoid(f) * c_0 + torch.sigmoid(i) * torch.tanh(g).
        # Here, A = sigmoid(f) * c_0 becomes the inner cell's hidden state
        # and B = sigmoid(i) * tanh(g) becomes its input.
        inner_hidden = torch.sigmoid(f) * c_0
        inner_input = torch.sigmoid(i) * torch.tanh(g)
        h_inner_1, c_inner_1 = self.inner_lstm_cell(
            inner_input, (inner_hidden, c_inner_0))
        # The outer cell state is the hidden output of the inner cell.
        c_1 = h_inner_1
        h_1 = torch.sigmoid(o) * torch.tanh(c_1)
        return h_1, c_1, c_inner_1

    def __repr__(self):
        s = '{name}({input_size}, {hidden_size})'
        return s.format(name=self.__class__.__name__, **self.__dict__)
class NLSTM(nn.Module):
    """A module that runs multiple steps of a nested LSTM."""

    def __init__(self, cell_class, input_size, hidden_size, num_layers,
                 use_bias=True, batch_first=False, dropout=0,
                 weight_dropout=0, **kwargs):
        super(NLSTM, self).__init__()
        self.num_layers = num_layers
        self.cell_class = cell_class
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.use_bias = use_bias
        self.batch_first = batch_first
        self.dropout = dropout
        self.weight_drops = None if weight_dropout == 0 else []
        for layer in range(num_layers):
            layer_input_size = input_size if layer == 0 else hidden_size
            cell = cell_class(input_size=layer_input_size,
                              hidden_size=hidden_size,
                              **kwargs)
            setattr(self, 'cell_{}'.format(layer), cell)
        self.dropout_layer = nn.Dropout(dropout)
        self.reset_parameters()
        if self.weight_drops is not None:
            for layer in range(num_layers):
                self.weight_drops.append(
                    WeightDrop(self.get_cell(layer), ['weight_hh'],
                               weight_dropout))

    def get_cell(self, layer):
        return getattr(self, 'cell_{}'.format(layer))

    def reset_parameters(self):
        for layer in range(self.num_layers):
            cell = self.get_cell(layer)
            cell.reset_parameters()

    @staticmethod
    def _forward_rnn(cell, input_, length, hx):
        device = input_.device
        max_time = input_.size(0)
        output = []
        for time in range(max_time):
            h_next, c_next, cin_next = cell(input_=input_[time], hx=hx)
            # For sequences that have already ended, keep the previous state
            # instead of the newly computed one.
            mask = (time < length).float().unsqueeze(1).expand_as(h_next).to(device)
            h_next = (h_next * mask + hx[0] * (1 - mask)).to(device)
            c_next = (c_next * mask + hx[1] * (1 - mask)).to(device)
            cin_next = (cin_next * mask + hx[2] * (1 - mask)).to(device)
            hx_next = (h_next, c_next, cin_next)
            output.append(h_next)
            hx = hx_next
        output = torch.stack(output, 0)
        return output, hx

    def forward(self, input_, length=None, hx=None):
        if self.batch_first:
            input_ = input_.transpose(0, 1)
        max_time, batch_size, _ = input_.size()
        if length is None:
            length = torch.LongTensor([max_time] * batch_size)
        if hx is None:
            hx = input_.data.new(batch_size, self.hidden_size).zero_()
            hx = (hx, hx, hx)
        h_n = []
        c_n = []
        cin_n = []
        layer_output = None
        # Apply weight drop to the recurrent weights before unrolling.
        if self.weight_drops:
            for weight_drop in self.weight_drops:
                weight_drop._setweights()
        for layer in range(self.num_layers):
            cell = self.get_cell(layer)
            layer_output, (layer_h_n, layer_c_n, layer_cin_n) = NLSTM._forward_rnn(
                cell=cell, input_=input_, length=length, hx=hx)
            input_ = self.dropout_layer(layer_output)
            h_n.append(layer_h_n)
            c_n.append(layer_c_n)
            cin_n.append(layer_cin_n)
        output = layer_output
        h_n = torch.stack(h_n, 0)
        c_n = torch.stack(c_n, 0)
        cin_n = torch.stack(cin_n, 0)
        if self.batch_first:
            output = output.transpose(0, 1)
        return output, (h_n, c_n, cin_n)
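For completeness, here is a minimal usage sketch of the module above. The sequence length, batch size, and layer sizes are placeholder values I picked for illustration, and it assumes the weight_drop module imported at the top is available on the path:

import torch

seq_len, batch_size, input_size, hidden_size = 35, 8, 100, 200

# Two stacked NLSTM layers; weight_dropout=0 disables WeightDrop entirely.
model = NLSTM(cell_class=NLSTMCell, input_size=input_size,
              hidden_size=hidden_size, num_layers=2,
              batch_first=False, dropout=0.3, weight_dropout=0)

x = torch.randn(seq_len, batch_size, input_size)     # (time, batch, features)
lengths = torch.LongTensor([seq_len] * batch_size)   # all sequences are full length

output, (h_n, c_n, c_inner_n) = model(x, length=lengths)
print(output.shape)  # torch.Size([35, 8, 200])
print(h_n.shape)     # torch.Size([2, 8, 200]), one final hidden state per layer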