pytorch-summary
fix: Fixed hook for DenseNet
Hello there,
I recently found that the main function (summary) does not work with the torchvision implementation of DenseNet. Running
from torchvision.models import densenet121
from torchsummary import summary
model = densenet121().eval().cuda()
summary(model, (3, 224, 224))
would yield
~/Documents/pytorch-summary/torchsummary/torchsummary.py in hook(module, input, output)
36 m_key = "%s-%i" % (class_name, module_idx + 1)
37 summary[m_key] = OrderedDict()
---> 38 summary[m_key]["input_shape"] = list(input[0].size())
39 summary[m_key]["input_shape"][0] = batch_size
40 if isinstance(output, (list, tuple)):
AttributeError: 'list' object has no attribute 'size'
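The cause is the _DenseLayer type in the architecture. In the torchvision version used here, each _DenseLayer is called with the list of feature maps accumulated so far in its dense block, so input[0] in the forward hook is a list rather than a tensor. Since _DenseLayer does not inherit from torch.nn.Sequential or torch.nn.ModuleList, it is not skipped when hooks are registered on all of the model's submodules.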
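To see why the hook breaks without needing torchvision or a GPU, here is a minimal sketch. ListInputLayer is a hypothetical stand-in for _DenseLayer that mimics the calling convention visible in the traceback (layer(features), where features is a list of tensors). A forward hook registered on it receives that list as input[0], so calling .size() on it fails the same way as reported.

```python
import torch
from torch import nn


class ListInputLayer(nn.Module):
    """Hypothetical stand-in for torchvision's _DenseLayer: its forward
    takes a *list* of feature maps and concatenates them along channels."""

    def __init__(self, in_channels: int, growth: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels, growth, kernel_size=3, padding=1)

    def forward(self, features):
        return self.conv(torch.cat(features, dim=1))


seen_input_types = []
error_message = None


def naive_hook(module, inputs, output):
    # `inputs` is the tuple of positional arguments given to forward,
    # so inputs[0] is the list of tensors, not a tensor.
    seen_input_types.append(type(inputs[0]))
    inputs[0].size()  # raises AttributeError because lists have no .size()


layer = ListInputLayer(in_channels=8, growth=4)
handle = layer.register_forward_hook(naive_hook)
try:
    layer([torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8)])
except AttributeError as err:
    error_message = str(err)
    print("hook failed:", error_message)
finally:
    handle.remove()
```

The exception raised inside the hook propagates out of the module call, which is why the original traceback ends inside torchsummary's hook function.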
This PR fixes it by checking whether the current module has any child modules before registering a hook, so that modules with children, such as _DenseLayer, are skipped.
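A minimal sketch of that idea, not the PR's actual code: the hypothetical helper register_leaf_hooks attaches a forward hook only to modules without children. The toy TinyDenseBlock and ListInputLayer classes are also hypothetical; they reproduce DenseNet's list-passing pattern, and with leaf-only hooking the hook only ever sees plain tensors.

```python
import torch
from torch import nn


class ListInputLayer(nn.Module):
    """Hypothetical stand-in for torchvision's _DenseLayer (takes a list)."""

    def __init__(self, in_channels: int, growth: int) -> None:
        super().__init__()
        self.norm = nn.BatchNorm2d(in_channels)
        self.conv = nn.Conv2d(in_channels, growth, kernel_size=3, padding=1)

    def forward(self, features):
        return self.conv(self.norm(torch.cat(features, dim=1)))


class TinyDenseBlock(nn.Module):
    """Hypothetical two-layer dense block: each layer sees all previous maps."""

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = ListInputLayer(4, 4)
        self.layer2 = ListInputLayer(8, 4)

    def forward(self, x):
        features = [x]
        for layer in (self.layer1, self.layer2):
            features.append(layer(features))
        return torch.cat(features, dim=1)


def register_leaf_hooks(model: nn.Module, records: list):
    """Hook only modules without children (hypothetical helper)."""

    def hook(module, inputs, output):
        # Leaf modules such as Conv2d/BatchNorm2d receive plain tensors.
        records.append((module.__class__.__name__, list(inputs[0].size())))

    handles = []
    for module in model.modules():
        if len(list(module.children())) == 0:  # skip container-like modules
            handles.append(module.register_forward_hook(hook))
    return handles


records = []
block = TinyDenseBlock().eval()
handles = register_leaf_hooks(block, records)
with torch.no_grad():
    out = block(torch.randn(1, 4, 8, 8))
for h in handles:
    h.remove()
print(records)
```

One trade-off of this design: modules that have children no longer get their own row in the summary; only their leaf submodules are reported.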
Hope this helps! Cheers
Hi, I have installed your branch using pip install git+https://github.com/frgfm/pytorch-summary.git, but I still get the error "AttributeError: 'list' object has no attribute 'size'" when testing the DenseNet backbone. Could you help me with this?
Hi @kaixinbear,
Thanks for pointing it out. Could you provide the exact running code you used to reproduce the error?
Also, just checking: did you add any flags or options to the pip install git+ command?
By default, it installs the master branch, but since this is a fix, my changes are on the densenet-fix branch (cf. the PR header).
Cheers!
I just reran pip install git+https://github.com/frgfm/pytorch-summary.git:
Collecting git+https://github.com/frgfm/pytorch-summary.git
Cloning https://github.com/frgfm/pytorch-summary.git to /tmp/pip-req-build-9fh9x80m
Running command git clone -q https://github.com/frgfm/pytorch-summary.git /tmp/pip-req-build-9fh9x80m
Requirement already satisfied (use --upgrade to upgrade): torchsummary==1.5.1 from git+https://github.com/frgfm/pytorch-summary.git in ./anaconda3/lib/python3.7/site-packages
Building wheels for collected packages: torchsummary
Building wheel for torchsummary (setup.py) ... done
Created wheel for torchsummary: filename=torchsummary-1.5.1-cp37-none-any.whl size=2850 sha256=1011ad3c5742d616c11a7b211663761d4ce739545519d23b0a4e62fcf8289120
Stored in directory: /tmp/pip-ephem-wheel-cache-3k2g7tqz/wheels/27/2e/88/8c4cca542c91043b0d6c5ca666e59a6abdeb8fe05e2a198db8
$ python
Python 3.7.4 (default, Aug 13 2019, 20:35:49)
[GCC 7.3.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torchsummary
>>> from torchvision.models import densenet121
>>> from torchsummary import summary
>>> model = densenet121().eval().cuda()
>>> summary(model, (3, 224, 224))
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/kaixin1/anaconda3/lib/python3.7/site-packages/torchsummary/torchsummary.py", line 72, in summary
model(*x)
File "/home/kaixin1/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
result = self.forward(*input, **kwargs)
File "/home/kaixin1/anaconda3/lib/python3.7/site-packages/torchvision/models/densenet.py", line 194, in forward
features = self.features(x)
File "/home/kaixin1/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
result = self.forward(*input, **kwargs)
File "/home/kaixin1/anaconda3/lib/python3.7/site-packages/torch/nn/modules/container.py", line 100, in forward
input = module(input)
File "/home/kaixin1/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
result = self.forward(*input, **kwargs)
File "/home/kaixin1/anaconda3/lib/python3.7/site-packages/torchvision/models/densenet.py", line 111, in forward
new_features = layer(features)
File "/home/kaixin1/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 534, in __call__
hook_result = hook(self, input, result)
File "/home/kaixin1/anaconda3/lib/python3.7/site-packages/torchsummary/torchsummary.py", line 19, in hook
summary[m_key]["input_shape"] = list(input[0].size())
AttributeError: 'list' object has no attribute 'size'
Thanks for the code @kaixinbear !
But as I mentioned previously, the command pip install git+https://github.com/frgfm/pytorch-summary.git installs the master branch of my fork, which is identical to the master branch of the original repo.
If you wish to test this PR's modifications, you need to install from my densenet-fix branch. To do so, first uninstall any existing installation (otherwise pip may keep the old copy, as the "Requirement already satisfied" line in your log shows), then run
pip install git+https://github.com/frgfm/pytorch-summary@densenet-fix
Then run your code again; it should work.
If not, please paste your torch version and your OS.
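For such reports, a small snippet like the following can collect the requested details. It is a sketch that relies only on the standard library (platform and importlib.metadata, the latter requiring Python 3.8+), so it runs even if importing torch itself fails; the helper name environment_report is hypothetical.

```python
import importlib.metadata
import platform


def environment_report() -> dict:
    """Collect versions useful for a bug report (hypothetical helper)."""
    report = {"os": platform.platform(), "python": platform.python_version()}
    for dist in ("torch", "torchvision", "torchsummary"):
        try:
            report[dist] = importlib.metadata.version(dist)
        except importlib.metadata.PackageNotFoundError:
            report[dist] = None  # distribution not installed
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```

Note that both master and densenet-fix report torchsummary 1.5.1, so the version alone does not tell which branch is installed.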
Well, I got it. Thanks for your kind reply.
Hi there, I ran the following commands but still end up with the error discussed above:
!pip install git+https://github.com/frgfm/pytorch-summary@densenet-fix
from torchvision.models import densenet121
from torchsummary import summary
model = densenet121().eval().cuda()
summary(model, (3, 224, 224))
Any help would be appreciated
Hello @GregorKerr1996, thanks for reporting it! My apologies: to properly install the content of this PR, you should do the following:
pip uninstall torchsummary
git clone https://github.com/frgfm/pytorch-summary.git
cd pytorch-summary && git checkout densenet-fix
pip install -e .
Let me know if the error persists!
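Since pip can keep an older installation of the same version (as the "Requirement already satisfied" line earlier in this thread shows), it may help to check which copy Python actually imports. A minimal sketch using importlib.util.find_spec from the standard library; the helper name module_origin is hypothetical. After pip install -e . the reported path should point inside the cloned pytorch-summary directory rather than site-packages.

```python
import importlib.util
from typing import Optional


def module_origin(name: str) -> Optional[str]:
    """Return the file a top-level module would be imported from, or None."""
    spec = importlib.util.find_spec(name)
    return spec.origin if spec is not None else None


if __name__ == "__main__":
    # After `pip install -e .`, this should point into the cloned repository.
    print(module_origin("torchsummary"))
```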
Besides, if you are interested, I made a Python library of my own for personal use that adds operation-count estimates. I benchmarked my implementation on torchvision models, and the results are similar to those of other OPs estimation libraries.
Here it is: https://github.com/frgfm/torch-scan
@sksq96 @Naireen Summary with DenseNet has been a well-known issue for a long time. Could you please check out this PR and, if it looks good, merge it? I'd love to get the summary from the library installed through pip, rather than manually installing it from git and switching to this branch.
@harshraj22 This repo doesn't seem to be maintained anymore, but feel free to check out https://github.com/frgfm/torch-scan, which includes other features as well.