torchinfo
remove, or add a new flag for removing, "(recursive)" rows in the reported table
TL;DR: I just want the library to show the model's parameters only and stop showing the "(recursive)" rows for loop-based models such as LSTMs and GRUs.
For loop-based models such as LSTMs and sequence-to-sequence models, "(recursive)" rows appear in the summary table, as in this example. The number of "(recursive)" rows grows with the number of loop iterations, which leads to a very long summary table.
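To make the behavior concrete, here is a minimal self-contained sketch (the `LoopModel` class and its shapes are my own illustration, not from the issue): a single `nn.LSTMCell` invoked in a Python loop is reported once with its shapes, and then once more per extra iteration as a "(recursive)" row.

```python
import torch
import torch.nn as nn
from torchinfo import summary

class LoopModel(nn.Module):
    """Toy loop-based model: one LSTMCell applied `steps` times."""
    def __init__(self, steps: int):
        super().__init__()
        self.cell = nn.LSTMCell(20, 10)
        self.steps = steps

    def forward(self, x):
        h = torch.zeros(x.size(0), 10)
        c = torch.zeros(x.size(0), 10)
        for _ in range(self.steps):
            # each repeated call to self.cell after the first one
            # shows up as a "(recursive)" row in the summary table
            h, c = self.cell(x, (h, c))
        return h

summary(LoopModel(steps=100), input_size=(1, 20))
```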
AFAIK, there are two workarounds here, and both have problems:
- Set the number of loops smaller, maybe one or two. However, I then have to run the code twice just for the model's summary (i.e., once to view the summary table, once for the actual run later). See the pseudocode below.
```python
import torch
import torch.nn as nn

class ConvAttnRNN(nn.Module):
    def __init__(self, max_length: int):
        super().__init__()
        # no trailing comma here: `nn.LSTMCell(20, 10),` would assign a tuple
        self.lstm = nn.LSTMCell(20, 10)
        self.max_length = max_length
        self.eos_token = torch.tensor([1, 10])
        ...

    def forward(self, image, caption):
        for i in range(self.max_length):
            predict, _ = self.lstm(caption)  # nn.LSTMCell returns (h, c)
            if torch.equal(predict, self.eos_token):  # early exit on end-of-sequence
                break

# first run, just to view the summary table
model = ConvAttnRNN(5)
# second run, the actual configuration
model = ConvAttnRNN(100)
```
- Additionally, in sequence-to-sequence problems we usually introduce an "end-of-sequence" signal to exit the loop early, so the `break` statement fires depending on the input, which torchinfo generates as a random tensor. As a result, the reported table differs between runs, which is quite irritating (a possible mitigation is sketched below).
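One mitigation I can think of for the run-to-run variance (my own assumption, not from the issue): pin the input instead of letting torchinfo draw a fresh random tensor, either by seeding the RNG or by passing a concrete `input_data` tensor to `summary()`. A sketch, reusing the hypothetical `LoopModel` from above:

```python
import torch
from torchinfo import summary

# summary() accepts a concrete `input_data` tensor instead of `input_size`,
# so the same input (and hence the same early-break path) is used every run.
torch.manual_seed(0)  # also pins any remaining internal randomness
fixed_input = torch.randn(1, 20)
summary(LoopModel(steps=100), input_data=fixed_input)
```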
In conclusion, I think we should add a parameter to the `summary` function to control whether or not the "(recursive)" rows are shown.
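In the meantime, a user-side stop-gap sketch (my own assumption, not part of torchinfo's API, again reusing the hypothetical `LoopModel`): `summary()` returns a `ModelStatistics` object whose `str()` is the rendered table, so the "(recursive)" rows can be filtered out before printing.

```python
from torchinfo import summary

# verbose=0 suppresses automatic printing; str(stats) is the full table.
stats = summary(LoopModel(steps=100), input_size=(1, 20), verbose=0)
filtered = "\n".join(
    line for line in str(stats).splitlines() if "(recursive)" not in line
)
print(filtered)  # same table, minus the "(recursive)" rows
```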
Adding this option makes sense to me. I definitely don't think it should be the default, though: there's a lot more value added by being able to see which layers are recursive and which are not. (For backwards-compatibility reasons we also shouldn't change this default behavior.)
I'll leave implementation detail comments on the PR itself.