pytorch-summary

summary() breaks with nn.Sequential()

Open · ShuvenduRoy opened this issue 6 years ago · 1 comment

This model does not work with `summary()`:

class Discriminator(nn.Module):
    """Convolutional discriminator that maps an image to a 1-channel score map.

    The network has four down-sampling stages. Each stage is a 4x4 / stride-2 /
    padding-1 convolution, then an optional ``InstanceNorm2d``, then
    ``LeakyReLU(0.2)``. Every stage except the first is normalized. The stages
    are followed by an asymmetric zero pad and a final 4x4 convolution down to
    one output channel.

    NOTE(review): the ``InstanceNorm2d`` layers use PyTorch's default
    arguments, so ``affine=False`` and their ``weight`` attribute is ``None``.
    A per-module hook that unconditionally calls ``module.weight.size()``
    (as the torchsummary hook in the reported traceback does) fails with
    ``AttributeError: 'NoneType' object has no attribute 'size'``. That is a
    limitation of the hook, not of this model.

    NOTE(review): each stage roughly halves the spatial size. With a 28x28
    input, the feature map is already 1x1 before the last normalized stage.
    Larger inputs (e.g. 64x64) leave a meaningful spatial map.

    Args:
        in_channels: number of channels in the input image (default 3).
    """

    def __init__(self, in_channels=3):
        super().__init__()

        # One entry per down-sampling stage: (in channels, out channels, normalize?)
        stages = (
            (in_channels, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
        )

        # Layers are appended in exactly the order they are registered in the
        # Sequential, so module indices (and state_dict keys) stay stable.
        layers = []
        for c_in, c_out, use_norm in stages:
            layers.append(nn.Conv2d(c_in, c_out, 4, stride=2, padding=1))
            if use_norm:
                layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Pad one pixel on the left and top only, then reduce to a single channel.
        layers += [
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Run ``img`` through the discriminator and return the score map."""
        scores = self.model(img)
        return scores

Traceback (most recent call last):
  File ".\test.py", line 118, in <module>
    summary(model, (3, 28, 28))
  File "C:\Users\bikas\Anaconda3\lib\site-packages\torchsummary\torchsummary.py", line 57, in summary
    model(x)
  File "C:\Users\bikas\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File ".\test.py", line 93, in forward
    return self.model(img)
  File "C:\Users\bikas\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "C:\Users\bikas\Anaconda3\lib\site-packages\torch\nn\modules\container.py", line 91, in forward
    input = module(input)
  File "C:\Users\bikas\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 493, in __call__
    hook_result = hook(self, input, result)
  File "C:\Users\bikas\Anaconda3\lib\site-packages\torchsummary\torchsummary.py", line 26, in hook
    params += torch.prod(torch.LongTensor(list(module.weight.size())))
AttributeError: 'NoneType' object has no attribute 'size'

— ShuvenduRoy, May 24 '18, 12:05

Did you find a solution for this error?

— rushin682, Feb 20 '20, 01:02