
remove, or add a new flag for removing, "(recursive)" rows in the reported table

Open VinhLoiIT opened this issue 3 years ago • 1 comment

TL;DR: I just want the library to show the model's parameters only and stop showing the "(recursive)" rows in loop-based models such as LSTM and GRU.

For loop-based models such as LSTM and sequence-to-sequence models, "(recursive)" rows are shown in the summary table, as in this example. The number of "(recursive)" rows grows with the number of loop iterations, which leads to a very long summary table.

AFAIK, there are two workarounds here:

  1. Set the number of loops to a smaller value, maybe one or two. However, I then have to run the code twice just for the model's summary (i.e., once to view the summary table, once for the actual run later). See the pseudocode below.
 import torch
 from torch import nn

 class ConvAttnRNN(nn.Module):
     def __init__(self, max_length: int):
         super().__init__()
         self.lstm = nn.LSTMCell(20, 10)  # no trailing comma, or this becomes a tuple
         self.max_length = max_length
         self.eos_token = torch.tensor([1, 10])
         ...

     def forward(self, image, caption):
         hidden = None
         for i in range(self.max_length):
             # LSTMCell returns the (hidden, cell) state tuple
             hidden = self.lstm(caption, hidden)
             predict = hidden[0]
             if torch.equal(predict, self.eos_token):  # early exit on end-of-sequence
                 break

# first run: small max_length, just to view the summary table
model = ConvAttnRNN(5)

# second run: the real max_length for actual use
model = ConvAttnRNN(100)
  2. Additionally, in sequence-to-sequence problems we usually introduce an "end-of-sequence" signal to break out of the loop early, so the break statement fires based on the inputs, which during a summary run is actually a random tensor. This leads to the reported table differing between runs, which is quite irritating.

In conclusion, I think we should add a parameter to the summary function to control whether or not the "(recursive)" rows are shown.
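In the meantime, one possible stopgap (a sketch of my own, not an existing torchinfo feature) is to post-process the rendered summary string and drop any table rows containing "(recursive)". The sample table fragment below is hypothetical, for illustration only:

```python
def strip_recursive_rows(summary_text: str) -> str:
    """Remove rows marked '(recursive)' from a torchinfo-style summary table."""
    lines = summary_text.splitlines()
    return "\n".join(line for line in lines if "(recursive)" not in line)

# Hypothetical fragment of a summary table, mimicking torchinfo's layout:
sample = "\n".join([
    "LSTMCell: 1-1      [1, 10]     1,280",
    "LSTMCell: 1-1      [1, 10]     (recursive)",
    "LSTMCell: 1-1      [1, 10]     (recursive)",
])
print(strip_recursive_rows(sample))
```

This keeps the parameter counts visible on the first occurrence of each layer while removing the repeated loop-iteration rows, at the cost of losing the information about how many times each layer was invoked.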

VinhLoiIT avatar Jul 14 '21 14:07 VinhLoiIT

Adding this option makes sense to me. I definitely don't think it should be the default though - there's a lot of value added by being able to see which layers are recursive and which are not. (For backwards compatibility reasons we also shouldn't change this default behavior.)

I'll leave implementation detail comments on the PR itself.

TylerYep avatar Jul 20 '21 06:07 TylerYep