pytorch-lstm-by-hand
Hi, is there any support for Bi-LSTM?
Hi, thanks for the helpful work.
Could I ask whether there is any plan to support a bidirectional LSTM with custom stacks?
Hi. I was not planning on it, but if it is helpful I might as well support it.
Also, feel free to open a PR with the feature if you want.
Hi, I tried to implement it myself. However, I cannot figure out the output format.
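For reference, torch.nn.LSTM with bidirectional=True and batch_first=False returns an output of shape (seq, batch, 2 * hidden_size): the first hidden_size channels are the forward direction and the last hidden_size channels are the backward direction, both aligned to input time step t. A minimal sketch with hypothetical sizes, just to show the layout the code below tries to reproduce:

import torch
import torch.nn as nn

# Hypothetical sizes, only to illustrate the bidirectional output layout.
lstm = nn.LSTM(input_size=100, hidden_size=60, bidirectional=True)
x = torch.rand(512, 10, 100)          # (seq, batch, input_size)
out, (hn, cn) = lstm(x)

print(out.shape)                      # torch.Size([512, 10, 120]) = (seq, batch, 2 * hidden_size)
forward_part = out[:, :, :60]         # forward direction, state at time t
backward_part = out[:, :, 60:]        # backward direction, also the state at time t
print(hn.shape, cn.shape)             # each torch.Size([2, 10, 60]) = (num_directions, batch, hidden_size)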
import torch
import torch.nn as nn
import pdb


class CustomLSTM(nn.LSTM):
    def __init__(self, input_size, hidden_size, num_layers=1, bias=True,
                 batch_first=False, dropout=0, bidirectional=False):
        # proj_size=0 is available from PyTorch 1.8
        super(CustomLSTM, self).__init__(input_size, hidden_size, num_layers=num_layers,
                                         bias=bias, batch_first=batch_first,
                                         dropout=dropout, bidirectional=bidirectional)

    def forward(self, x, init_states=None, exporting_onnx=False):
        if exporting_onnx:
            assert self.num_layers == 1
            bs, seq, _ = x.size() if self.batch_first else (x.size(1), x.size(0), x.size(2))
            sz = self.hidden_size

            # forward direction: t = 0 .. seq - 1
            if init_states is None:
                h_t, c_t = (torch.zeros(bs, self.hidden_size).to(x.device),
                            torch.zeros(bs, self.hidden_size).to(x.device))
            hidden_seq_forward = []
            for t in range(seq):
                x_t = x[:, t, :] if self.batch_first else x[t, :, :]
                # packed gate order in the nn.LSTM weights is (i, f, g, o), each block of size hidden_size
                i_t = x_t @ self.weight_ih_l0[sz*0:sz*1, :].transpose(0, 1) + self.bias_ih_l0[sz*0:sz*1] + \
                      h_t @ self.weight_hh_l0[sz*0:sz*1, :].transpose(0, 1) + self.bias_hh_l0[sz*0:sz*1]
                f_t = x_t @ self.weight_ih_l0[sz*1:sz*2, :].transpose(0, 1) + self.bias_ih_l0[sz*1:sz*2] + \
                      h_t @ self.weight_hh_l0[sz*1:sz*2, :].transpose(0, 1) + self.bias_hh_l0[sz*1:sz*2]
                g_t = x_t @ self.weight_ih_l0[sz*2:sz*3, :].transpose(0, 1) + self.bias_ih_l0[sz*2:sz*3] + \
                      h_t @ self.weight_hh_l0[sz*2:sz*3, :].transpose(0, 1) + self.bias_hh_l0[sz*2:sz*3]
                o_t = x_t @ self.weight_ih_l0[sz*3:sz*4, :].transpose(0, 1) + self.bias_ih_l0[sz*3:sz*4] + \
                      h_t @ self.weight_hh_l0[sz*3:sz*4, :].transpose(0, 1) + self.bias_hh_l0[sz*3:sz*4]
                i_t = torch.sigmoid(i_t)
                f_t = torch.sigmoid(f_t)
                g_t = torch.tanh(g_t)
                o_t = torch.sigmoid(o_t)
                c_t = f_t * c_t + i_t * g_t
                h_t = o_t * torch.tanh(c_t)
                hidden_seq_forward.append(h_t.unsqueeze(0))

            # reverse direction: t = seq - 1 .. 0, using the *_reverse weights
            if init_states is None:
                h_t, c_t = (torch.zeros(bs, self.hidden_size).to(x.device),
                            torch.zeros(bs, self.hidden_size).to(x.device))
            hidden_seq_reverse = []
            for t in list(reversed(range(seq))):
                x_t = x[:, t, :] if self.batch_first else x[t, :, :]
                i_t = x_t @ self.weight_ih_l0_reverse[sz*0:sz*1, :].transpose(0, 1) + self.bias_ih_l0_reverse[sz*0:sz*1] + \
                      h_t @ self.weight_hh_l0_reverse[sz*0:sz*1, :].transpose(0, 1) + self.bias_hh_l0_reverse[sz*0:sz*1]
                f_t = x_t @ self.weight_ih_l0_reverse[sz*1:sz*2, :].transpose(0, 1) + self.bias_ih_l0_reverse[sz*1:sz*2] + \
                      h_t @ self.weight_hh_l0_reverse[sz*1:sz*2, :].transpose(0, 1) + self.bias_hh_l0_reverse[sz*1:sz*2]
                g_t = x_t @ self.weight_ih_l0_reverse[sz*2:sz*3, :].transpose(0, 1) + self.bias_ih_l0_reverse[sz*2:sz*3] + \
                      h_t @ self.weight_hh_l0_reverse[sz*2:sz*3, :].transpose(0, 1) + self.bias_hh_l0_reverse[sz*2:sz*3]
                o_t = x_t @ self.weight_ih_l0_reverse[sz*3:sz*4, :].transpose(0, 1) + self.bias_ih_l0_reverse[sz*3:sz*4] + \
                      h_t @ self.weight_hh_l0_reverse[sz*3:sz*4, :].transpose(0, 1) + self.bias_hh_l0_reverse[sz*3:sz*4]
                i_t = torch.sigmoid(i_t)
                f_t = torch.sigmoid(f_t)
                g_t = torch.tanh(g_t)
                o_t = torch.sigmoid(o_t)
                c_t = f_t * c_t + i_t * g_t
                h_t = o_t * torch.tanh(c_t)  # [bs, self.hidden_size]
                hidden_seq_reverse.append(h_t.unsqueeze(0))

            # stack hidden_seq_forward and hidden_seq_reverse to hidden_seq
            # (this is the part I cannot figure out; hidden_seq is not built yet)
            hidden_seq = torch.cat(hidden_seq, dim=0)  # [seq, bs, self.hidden_size]
            if self.batch_first:
                hidden_seq = hidden_seq.transpose(0, 1).contiguous()
            return hidden_seq, (_, _)  # placeholder: the final (h_n, c_n) states are not assembled here
        else:
            return super().forward(x)


if __name__ == "__main__":
    model = CustomLSTM(100, 60, bidirectional=True)
    x = torch.rand(512, 10, 100)
    model.eval()
    y1, (hn, cn) = model(x, None, False)
    print(y1.shape)
    y2, (hn, cn) = model(x, None, True)
    print(y2.shape)
    pdb.set_trace()
Could I ask for a suggestion around the # stack hidden_seq_forward and hidden_seq_reverse to hidden_seq step?
If I use
# stack hidden_seq_forward and hidden_seq_reverse to hidden_seq
hidden_seq_forward = torch.cat(hidden_seq_forward, dim=0) # [seq, bs, self.hidden_size]
hidden_seq_reverse = torch.cat(hidden_seq_reverse, dim=0) # [seq, bs, self.hidden_size]
print(hidden_seq_forward.shape, hidden_seq_reverse.shape)
hidden_seq = torch.cat([hidden_seq_forward, hidden_seq_reverse], dim=2)
print(hidden_seq.shape)
then y1 == y2 in the __main__ block gives a lot of False entries.
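One likely cause, offered as a hedged sketch rather than a confirmed fix: hidden_seq_reverse is appended while looping from t = seq - 1 down to 0, so its first element holds the backward state of the last time step, whereas nn.LSTM stores the backward direction's output time-aligned with the input. Flipping the reverse sequence back to forward time order before the feature-dimension concat should line the two halves up; and since the two code paths accumulate floating-point error differently, torch.allclose is a better check than exact equality. A sketch of the stacking step inside the exporting_onnx branch, under those assumptions:

# stack hidden_seq_forward and hidden_seq_reverse to hidden_seq
hidden_seq_forward = torch.cat(hidden_seq_forward, dim=0)        # [seq, bs, hidden_size], t = 0 .. seq - 1
hidden_seq_reverse = torch.cat(hidden_seq_reverse[::-1], dim=0)  # flip back so index t holds the backward state at time t
hidden_seq = torch.cat([hidden_seq_forward, hidden_seq_reverse], dim=2)  # [seq, bs, 2 * hidden_size]
if self.batch_first:
    hidden_seq = hidden_seq.transpose(0, 1).contiguous()

# and in __main__, compare up to floating-point tolerance instead of y1 == y2:
# print(torch.allclose(y1, y2, atol=1e-6))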