
Wrong recursive tag and duplicated/unordered entries

Open · CappellatoAlessio opened this issue 2 years ago

Describe the bug
summary() wrongly reports modules as (recursive), and several entries are duplicated and listed in no logical order.

To Reproduce
Steps to reproduce the behavior:

  1. conda create -n recursivetest -c defaults -c pytorch -c conda-forge python pytorch torchvision torchaudio [cpuonly] torchinfo
  2. conda activate recursivetest
  3. Save the following recursivetest.py file (minimal example):
import torch
from torch import nn
from torchinfo import summary

class RecursiveTest(nn.Module):
    def __init__(self):
        super().__init__()
        self.out_conv0 = nn.Conv2d(3, 8, 5, padding='same')
        self.out_bn0 = nn.BatchNorm2d(8)

        # Two ModuleDicts, each holding three distinct conv/bn pairs (no module is shared).
        self.block0 = nn.ModuleDict()
        for i in range(1, 4):
            self.block0.add_module(f"in_conv{i}", nn.Conv2d(8, 8, 3, padding="same", dilation=2 ** i))
            self.block0.add_module(f"in_bn{i}", nn.BatchNorm2d(8))

        self.block1 = nn.ModuleDict()
        for i in range(4, 7):
            self.block1.add_module(f"in_conv{i}", nn.Conv2d(8, 8, 3, padding="same", dilation=2 ** (7 - i)))
            self.block1.add_module(f"in_bn{i}", nn.BatchNorm2d(8))

        self.out_conv7 = nn.Conv2d(8, 1, 1, padding='same')
        self.out_bn7 = nn.BatchNorm2d(1)

    def forward(self, x):
        # Every submodule below is called exactly once per forward pass.
        x = self.out_conv0(x)
        x = torch.relu(self.out_bn0(x))

        for i in range(1, 4):
            x = self.block0[f"in_conv{i}"](x)
            x = torch.relu(self.block0[f"in_bn{i}"](x))

        for i in range(4, 7):
            x = self.block1[f"in_conv{i}"](x)
            x = torch.relu(self.block1[f"in_bn{i}"](x))

        x = self.out_conv7(x)
        x = torch.relu(self.out_bn7(x))
        return x


if __name__ == '__main__':
    batch_size = 2
    data_shape = (3, 128, 128)
    random_data = torch.rand((batch_size, *data_shape))
    my_nn = RecursiveTest()
    print(my_nn)
    summary(my_nn, input_data=[random_data], row_settings=('depth', 'var_names'))
  4. Run python recursivetest.py
  5. See the output:
RecursiveTest(
  (out_conv0): Conv2d(3, 8, kernel_size=(5, 5), stride=(1, 1), padding=same)
  (out_bn0): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (block0): ModuleDict(
    (in_conv1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=same, dilation=(2, 2))
    (in_bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (in_conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=same, dilation=(4, 4))
    (in_bn2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (in_conv3): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=same, dilation=(8, 8))
    (in_bn3): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (block1): ModuleDict(
    (in_conv4): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=same, dilation=(8, 8))
    (in_bn4): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (in_conv5): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=same, dilation=(4, 4))
    (in_bn5): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (in_conv6): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=same, dilation=(2, 2))
    (in_bn6): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (out_conv7): Conv2d(8, 1, kernel_size=(1, 1), stride=(1, 1), padding=same)
  (out_bn7): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
==========================================================================================
Layer (type (var_name):depth-idx)        Output Shape              Param #
==========================================================================================
RecursiveTest (RecursiveTest)            [2, 1, 128, 128]          --
├─Conv2d (out_conv0): 1-5                [2, 8, 128, 128]          (recursive)
├─BatchNorm2d (out_bn0): 1-6             [2, 8, 128, 128]          (recursive)
├─ModuleDict (block0): 1-3               --                        1,800
│    └─Conv2d (in_conv1): 2-7            [2, 8, 128, 128]          (recursive)
│    └─BatchNorm2d (in_bn1): 2-8         [2, 8, 128, 128]          (recursive)
│    └─Conv2d (in_conv2): 2-9            [2, 8, 128, 128]          (recursive)
│    └─BatchNorm2d (in_bn2): 2-10        [2, 8, 128, 128]          (recursive)
│    └─Conv2d (in_conv3): 2-11           [2, 8, 128, 128]          (recursive)
│    └─BatchNorm2d (in_bn3): 2-12        [2, 8, 128, 128]          (recursive)
├─ModuleDict (block1): 1-4               --                        1,800
├─Conv2d (out_conv0): 1-5                [2, 8, 128, 128]          (recursive)
├─BatchNorm2d (out_bn0): 1-6             [2, 8, 128, 128]          (recursive)
├─ModuleDict (block0): 1-3               --                        1,800
│    └─Conv2d (in_conv1): 2-7            [2, 8, 128, 128]          (recursive)
│    └─BatchNorm2d (in_bn1): 2-8         [2, 8, 128, 128]          (recursive)
│    └─Conv2d (in_conv2): 2-9            [2, 8, 128, 128]          (recursive)
│    └─BatchNorm2d (in_bn2): 2-10        [2, 8, 128, 128]          (recursive)
│    └─Conv2d (in_conv3): 2-11           [2, 8, 128, 128]          (recursive)
│    └─BatchNorm2d (in_bn3): 2-12        [2, 8, 128, 128]          (recursive)
├─ModuleDict (block1): 1-4               --                        1,800
│    └─Conv2d (in_conv4): 2-13           [2, 8, 128, 128]          584
│    └─BatchNorm2d (in_bn4): 2-14        [2, 8, 128, 128]          16
│    └─Conv2d (in_conv5): 2-15           [2, 8, 128, 128]          584
│    └─BatchNorm2d (in_bn5): 2-16        [2, 8, 128, 128]          16
│    └─Conv2d (in_conv6): 2-17           [2, 8, 128, 128]          584
│    └─BatchNorm2d (in_bn6): 2-18        [2, 8, 128, 128]          16
├─Conv2d (out_conv7): 1-7                [2, 1, 128, 128]          9
├─BatchNorm2d (out_bn7): 1-8             [2, 1, 128, 128]          2
==========================================================================================
Total params: 4,235
Trainable params: 4,235
Non-trainable params: 0
Total mult-adds (M): 212.37
==========================================================================================
Input size (MB): 0.39
Forward/backward pass size (MB): 13.11
Params size (MB): 0.01
Estimated Total Size (MB): 13.51
==========================================================================================

Expected behavior
Each nn.Module (every one of them is used only once) should appear exactly once, preferably in traversal order, with depth-idx, Output Shape, and Param # reported consistently:

==========================================================================================
Layer (type (var_name):depth-idx)        Output Shape              Param #
==========================================================================================
RecursiveTest (RecursiveTest)            [2, 1, 128, 128]          --
├─Conv2d (out_conv0): 1-1                [2, 8, 128, 128]          608
├─BatchNorm2d (out_bn0): 1-2             [2, 8, 128, 128]          16
├─ModuleDict (block0): 1-3               --                        1,800
│    └─Conv2d (in_conv1): 2-1            [2, 8, 128, 128]          584
│    └─BatchNorm2d (in_bn1): 2-2         [2, 8, 128, 128]          16
│    └─Conv2d (in_conv2): 2-3            [2, 8, 128, 128]          584
│    └─BatchNorm2d (in_bn2): 2-4         [2, 8, 128, 128]          16
│    └─Conv2d (in_conv3): 2-5            [2, 8, 128, 128]          584
│    └─BatchNorm2d (in_bn3): 2-6         [2, 8, 128, 128]          16
├─ModuleDict (block1): 1-4               --                        1,800
│    └─Conv2d (in_conv4): 2-7            [2, 8, 128, 128]          584
│    └─BatchNorm2d (in_bn4): 2-8         [2, 8, 128, 128]          16
│    └─Conv2d (in_conv5): 2-9            [2, 8, 128, 128]          584
│    └─BatchNorm2d (in_bn5): 2-10        [2, 8, 128, 128]          16
│    └─Conv2d (in_conv6): 2-11           [2, 8, 128, 128]          584
│    └─BatchNorm2d (in_bn6): 2-12        [2, 8, 128, 128]          16
├─Conv2d (out_conv7): 1-5                [2, 1, 128, 128]          9
├─BatchNorm2d (out_bn7): 1-6             [2, 1, 128, 128]          2
==========================================================================================
Total params: 4,235
Trainable params: 4,235
Non-trainable params: 0
Total mult-adds (M): 212.37
==========================================================================================
Input size (MB): 0.39
Forward/backward pass size (MB): 13.11
Params size (MB): 0.01
Estimated Total Size (MB): 13.51
==========================================================================================

Desktop:

  • OS: Windows 10
  • Version 20H2

Additional context
By adding a print(id(self)) to the forward(self, input) of nn.Conv2d (and of nn._BatchNorm) in ~\Anaconda3\envs\recursivetest\Lib\site-packages\torch\nn\modules\conv.py (and batchnorm.py), I confirmed that every module is in fact used only once (all ids are unique).
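
For what it's worth, the same check can be done without patching the installed PyTorch sources: a forward hook on every submodule can count how many times each one runs. Below is a minimal sketch along those lines (an alternative to the print(id(self)) patch, reusing the RecursiveTest class from the script above and only plain torch APIs, not torchinfo):

import torch
from collections import Counter

call_counts = Counter()

def make_hook(name):
    def hook(module, inputs, output):
        # One entry per (name, id(module)) pair; a count above 1 would indicate real reuse.
        call_counts[(name, id(module))] += 1
    return hook

model = RecursiveTest()
for name, module in model.named_modules():
    if name:  # skip the root module itself
        module.register_forward_hook(make_hook(name))

with torch.no_grad():
    model(torch.rand(2, 3, 128, 128))

# Each conv/bn hook fires exactly once, so none of the modules is actually recursive.
assert all(count == 1 for count in call_counts.values())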

CappellatoAlessio · Aug 05 '22 17:08

And things get even worse if you:

  1. Run summary(my_nn, input_data=[random_data], row_settings=('depth', 'var_names'), verbose=2)
  2. See output:
==========================================================================================
Layer (type (var_name):depth-idx)        Output Shape              Param #
==========================================================================================
RecursiveTest (RecursiveTest)            [2, 1, 128, 128]          --
├─Conv2d (out_conv0): 1-5                [2, 8, 128, 128]          (recursive)
│    └─weight                                                      ├─600
│    └─bias                                                        └─8
├─BatchNorm2d (out_bn0): 1-6             [2, 8, 128, 128]          (recursive)
│    └─weight                                                      ├─8
│    └─bias                                                        └─8
├─ModuleDict (block0): 1-3               --                        1,800
│    └─in_conv1.weight                                             ├─576
│    └─in_conv1.bias                                               ├─8
│    └─in_bn1.weight                                               ├─8
│    └─in_bn1.bias                                                 ├─8
│    └─in_conv2.weight                                             ├─576
│    └─in_conv2.bias                                               ├─8
│    └─in_bn2.weight                                               ├─8
│    └─in_bn2.bias                                                 ├─8
│    └─in_conv3.weight                                             ├─576
│    └─in_conv3.bias                                               ├─8
│    └─in_bn3.weight                                               ├─8
│    └─in_bn3.bias                                                 └─8
│    └─Conv2d (in_conv1): 2-7            [2, 8, 128, 128]          (recursive)
│    │    └─weight                                                 ├─576
│    │    └─bias                                                   └─8
│    └─BatchNorm2d (in_bn1): 2-8         [2, 8, 128, 128]          (recursive)
│    │    └─weight                                                 ├─8
│    │    └─bias                                                   └─8
│    └─Conv2d (in_conv2): 2-9            [2, 8, 128, 128]          (recursive)
│    │    └─weight                                                 ├─576
│    │    └─bias                                                   └─8
│    └─BatchNorm2d (in_bn2): 2-10        [2, 8, 128, 128]          (recursive)
│    │    └─weight                                                 ├─8
│    │    └─bias                                                   └─8
│    └─Conv2d (in_conv3): 2-11           [2, 8, 128, 128]          (recursive)
│    │    └─weight                                                 ├─576
│    │    └─bias                                                   └─8
│    └─BatchNorm2d (in_bn3): 2-12        [2, 8, 128, 128]          (recursive)
│    │    └─weight                                                 ├─8
│    │    └─bias                                                   └─8
├─ModuleDict (block1): 1-4               --                        1,800
│    └─in_conv4.weight                                             ├─576
│    └─in_conv4.bias                                               ├─8
│    └─in_bn4.weight                                               ├─8
│    └─in_bn4.bias                                                 ├─8
│    └─in_conv5.weight                                             ├─576
│    └─in_conv5.bias                                               ├─8
│    └─in_bn5.weight                                               ├─8
│    └─in_bn5.bias                                                 ├─8
│    └─in_conv6.weight                                             ├─576
│    └─in_conv6.bias                                               ├─8
│    └─in_bn6.weight                                               ├─8
│    └─in_bn6.bias                                                 └─8
├─Conv2d (out_conv0): 1-5                [2, 8, 128, 128]          (recursive)
│    └─weight                                                      ├─600
│    └─bias                                                        └─8
├─BatchNorm2d (out_bn0): 1-6             [2, 8, 128, 128]          (recursive)
│    └─weight                                                      ├─8
│    └─bias                                                        └─8
├─ModuleDict (block0): 1-3               --                        1,800
│    └─in_conv1.weight                                             ├─576
│    └─in_conv1.bias                                               ├─8
│    └─in_bn1.weight                                               ├─8
│    └─in_bn1.bias                                                 ├─8
│    └─in_conv2.weight                                             ├─576
│    └─in_conv2.bias                                               ├─8
│    └─in_bn2.weight                                               ├─8
│    └─in_bn2.bias                                                 ├─8
│    └─in_conv3.weight                                             ├─576
│    └─in_conv3.bias                                               ├─8
│    └─in_bn3.weight                                               ├─8
│    └─in_bn3.bias                                                 └─8
│    └─Conv2d (in_conv1): 2-7            [2, 8, 128, 128]          (recursive)
│    │    └─weight                                                 ├─576
│    │    └─bias                                                   └─8
│    └─BatchNorm2d (in_bn1): 2-8         [2, 8, 128, 128]          (recursive)
│    │    └─weight                                                 ├─8
│    │    └─bias                                                   └─8
│    └─Conv2d (in_conv2): 2-9            [2, 8, 128, 128]          (recursive)
│    │    └─weight                                                 ├─576
│    │    └─bias                                                   └─8
│    └─BatchNorm2d (in_bn2): 2-10        [2, 8, 128, 128]          (recursive)
│    │    └─weight                                                 ├─8
│    │    └─bias                                                   └─8
│    └─Conv2d (in_conv3): 2-11           [2, 8, 128, 128]          (recursive)
│    │    └─weight                                                 ├─576
│    │    └─bias                                                   └─8
│    └─BatchNorm2d (in_bn3): 2-12        [2, 8, 128, 128]          (recursive)
│    │    └─weight                                                 ├─8
│    │    └─bias                                                   └─8
├─ModuleDict (block1): 1-4               --                        1,800
│    └─in_conv4.weight                                             ├─576
│    └─in_conv4.bias                                               ├─8
│    └─in_bn4.weight                                               ├─8
│    └─in_bn4.bias                                                 ├─8
│    └─in_conv5.weight                                             ├─576
│    └─in_conv5.bias                                               ├─8
│    └─in_bn5.weight                                               ├─8
│    └─in_bn5.bias                                                 ├─8
│    └─in_conv6.weight                                             ├─576
│    └─in_conv6.bias                                               ├─8
│    └─in_bn6.weight                                               ├─8
│    └─in_bn6.bias                                                 └─8
│    └─Conv2d (in_conv4): 2-13           [2, 8, 128, 128]          584
│    │    └─weight                                                 ├─576
│    │    └─bias                                                   └─8
│    └─BatchNorm2d (in_bn4): 2-14        [2, 8, 128, 128]          16
│    │    └─weight                                                 ├─8
│    │    └─bias                                                   └─8
│    └─Conv2d (in_conv5): 2-15           [2, 8, 128, 128]          584
│    │    └─weight                                                 ├─576
│    │    └─bias                                                   └─8
│    └─BatchNorm2d (in_bn5): 2-16        [2, 8, 128, 128]          16
│    │    └─weight                                                 ├─8
│    │    └─bias                                                   └─8
│    └─Conv2d (in_conv6): 2-17           [2, 8, 128, 128]          584
│    │    └─weight                                                 ├─576
│    │    └─bias                                                   └─8
│    └─BatchNorm2d (in_bn6): 2-18        [2, 8, 128, 128]          16
│    │    └─weight                                                 ├─8
│    │    └─bias                                                   └─8
├─Conv2d (out_conv7): 1-7                [2, 1, 128, 128]          9
│    └─weight                                                      ├─8
│    └─bias                                                        └─1
├─BatchNorm2d (out_bn7): 1-8             [2, 1, 128, 128]          2
│    └─weight                                                      ├─1
│    └─bias                                                        └─1
==========================================================================================
Total params: 4,235
Trainable params: 4,235
Non-trainable params: 0
Total mult-adds (M): 212.37
==========================================================================================
Input size (MB): 0.39
Forward/backward pass size (MB): 13.11
Params size (MB): 0.01
Estimated Total Size (MB): 13.51
==========================================================================================

CappellatoAlessio · Aug 05 '22 17:08

Thanks for reporting this issue! PRs to fix this are much appreciated.

TylerYep · Aug 05 '22 18:08

Hi,

I have a potential solution that fixes this issue and passes the existing test cases. While working on it, I found some other cases that are problematic both for the current implementation and for my solution. But to be really certain about my solution, I need to ask a few general questions, e.g. regarding the add_missing_layers() function in torchinfo.py.

How should I proceed? Should I continue this general discussion here, or open a separate issue for it?

mert-kurttutan · Aug 25 '22 16:08

Feel free to open the PR (even if it is a draft), and we can discuss the issues on the PR itself.

TylerYep · Aug 25 '22 17:08