script_lnlstm crashes when used bidirectionally
Dear PyTorch team, I recently tried out these versions of LSTMs with layer normalization, which I found through the PyTorch forums. Using script_lnlstm with layer normalization, however, causes the program to crash once loss.backward() is called:
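For reference, the model is constructed roughly like this (a minimal sketch; script_lnlstm and LSTMState are the helpers from the custom_lstms.py linked above, and the module path in the import is my assumption):

import torch
from custom_lstms import script_lnlstm, LSTMState  # module path assumed

# TorchScript-compiled LSTM with layer normalization; the crash only
# appears with bidirectional=True.
rnn = script_lnlstm(input_size=3, hidden_size=7, num_layers=4,
                    bidirectional=True)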
===
Traceback (most recent call last):
  File "mydir/custom_lstms.py", line 508, in
    return torch.stack(tensors, dim), backward
def unbind(self,
           dim: int=0):
    def backward(grad_outputs: List[Tensor]):
        grad_self = torch.stack(grad_outputs, dim)
                    ~~~~~~~~~~~ <--- HERE
        return grad_self, None
    return torch.unbind(self, dim), backward
def cat(tensors: List[Tensor],
        dim: int=0):
    size = len(tensors)
    split_sizes = [0] * size
    for i in range(size):
=====
The following test function reproduces this error:
def test_script_stacked_lnlstm_bidirectional(seq_len, batch, input_size,
                                             hidden_size, num_layers):
    inp = torch.randn(seq_len, batch, input_size)
    # Two states per layer: one for the forward and one for the
    # reverse direction.
    states = [[LSTMState(torch.randn(batch, hidden_size),
                         torch.randn(batch, hidden_size)),
               LSTMState(torch.randn(batch, hidden_size),
                         torch.randn(batch, hidden_size))]
              for _ in range(num_layers)]
    print("inp.size(): " + str(inp.size()))
    print("states: " + str(states))
    rnn = script_lnlstm(input_size, hidden_size, num_layers,
                        bidirectional=True)
    # just a smoke test
    out, out_state = rnn(inp, states)
    # Adding a loss function, computing the loss, and calling
    # backward() causes the program to crash.
    loss_function = torch.nn.L1Loss()
    out_desired = torch.ones_like(out)
    loss = loss_function(out, out_desired)
    loss.backward()
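For completeness, calling the test with small, arbitrary sizes (the concrete numbers below are my choice) is enough to trigger the traceback shown above:

test_script_stacked_lnlstm_bidirectional(
    seq_len=5, batch=2, input_size=3, hidden_size=7, num_layers=4)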
======
I also had to make a fix to the "reverse" function to even get to this point:
def reverse(lst):
    # type: (List[Tensor]) -> List[Tensor]
    # See: https://github.com/pytorch/pytorch/issues/27543
    # return lst[::-1]  # This fails with a bidirectional LSTM.
    # Alternative implementation that avoids the negative-step slice:
    copy_list = lst.copy()
    copy_list.reverse()
    return copy_list
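As a quick sanity check that the replacement behaves like the original slice, the fixed reverse can be scripted on its own and compared against eager reversal (a sketch; it only relies on list.copy() and list.reverse() being supported in TorchScript):

import torch
from typing import List
from torch import Tensor

@torch.jit.script
def reverse_fixed(lst):
    # type: (List[Tensor]) -> List[Tensor]
    # copy() + reverse() sidesteps the negative-step slice (lst[::-1])
    # that issue 27543 reports as problematic under TorchScript.
    copy_list = lst.copy()
    copy_list.reverse()
    return copy_list

tensors = [torch.randn(2, 3) for _ in range(4)]
out = reverse_fixed(tensors)
assert all(torch.equal(a, b) for a, b in zip(out, tensors[::-1]))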
=====
The problem only occurs when bidirectional=True is set; without it, everything works. Any ideas on how to fix this?