pytorch-summary
Fix the multi-output, dict-input, parameter-counting, and calculation-overflow problems.
Update report
- Fix the parameter-count bug that occurs when there is more than one output variable, in both the sequence case and the dict case (mentioned in #162).
- Make multiple output variables split into multiple lines.
- Remove the last line break of `summary_string()`.
- Enable argument `device` to accept both `str` and `torch.device`.
- Fix a bug when the model requires `batch_size` to be a specific number.
- Fix a bug caused by multiple input cases when `dtypes=None`.
- Add text auto wrap when the layer name is too long.
- Support counting all parameters instead of only `weight` and `bias` (a different solution to #142, #148).
- Drop `np.sum`/`np.prod` to fix the overflow problem when calculating the total size (mentioned in #158).
- Fix the bug caused by layers with dict input values (mentioned in #162).
- Add docstring.
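The overflow fix listed above can be illustrated with a minimal sketch. It assumes the size calculation only needs element counts and a fixed number of bytes per element. The helper name `total_size_mb` and the example shapes are hypothetical and are not taken from this repository. The idea is that Python's built-in integers have arbitrary precision, so summing and multiplying with `sum` and `math.prod` cannot wrap around the way a fixed-width integer accumulator can.

```python
import math


def total_size_mb(shapes, bytes_per_element=4):
    """Hypothetical helper: total size in MB of tensors with the given shapes.

    Assumes every element takes `bytes_per_element` bytes (4 for float32).
    """
    # math.prod and sum work on plain Python ints, which never overflow.
    total_elements = sum(math.prod(shape) for shape in shapes)
    return total_elements * bytes_per_element / (1024 ** 2)


# Two deliberately huge (hypothetical) output shapes.
big_shapes = [(65536, 65536, 4), (65536, 65536, 4)]
elements = sum(math.prod(shape) for shape in big_shapes)

print(elements)              # 34359738368, i.e. 2**35
print(elements > 2**31 - 1)  # True: too large for a signed 32-bit integer
print(total_size_mb(big_shapes))
```

`math.prod` requires Python 3.8 or later; on older versions `functools.reduce(operator.mul, shape, 1)` gives the same result.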
Example for verifying this update
The following code is not compatible with the base repository:
import torch
import torch.nn as nn
from torchsummary import summary


class VeryLongNameSimpleMultiConv(nn.Module):
    def __init__(self):
        super(VeryLongNameSimpleMultiConv, self).__init__()
        self.features_1 = nn.Sequential(
            nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )
        self.features_2 = nn.Sequential(
            nn.Conv2d(1, 2, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        x1 = self.features_1(x)
        x2 = self.features_2(x)
        return x1, x2


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VeryLongNameSimpleMultiConv().to(device)
summary(model, (1, 16, 16))
Now the output is:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1            [-1, 1, 16, 16]              10
              ReLU-2            [-1, 1, 16, 16]               0
            Conv2d-3            [-1, 2, 16, 16]              20
              ReLU-4            [-1, 2, 16, 16]               0
VeryLong...ltiConv-5            [-1, 1, 16, 16]               0
                                [-1, 2, 16, 16]
================================================================
Total params: 30
Trainable params: 30
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.02
Params size (MB): 0.00
Estimated Total Size (MB): 0.02
----------------------------------------------------------------
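The parameter counts in this table can be checked by hand. The sketch below is only an illustration, not code from this repository: `conv2d_param_count` is a hypothetical helper, and it assumes a square kernel, `groups=1`, a bias term, and 4 bytes per `float32` parameter.

```python
def conv2d_param_count(in_channels, out_channels, kernel_size, bias=True):
    """Hypothetical helper: parameter count of a Conv2d layer.

    Assumes a square kernel and groups=1.
    """
    weights = out_channels * in_channels * kernel_size * kernel_size
    return weights + (out_channels if bias else 0)


first = conv2d_param_count(1, 1, 3)   # Conv2d-1: 1*1*3*3 weights + 1 bias
second = conv2d_param_count(1, 2, 3)  # Conv2d-3: 2*1*3*3 weights + 2 biases
total = first + second

print(first, second, total)  # 10 20 30

# Assumption: each parameter is a 4-byte float32 value.
params_mb = total * 4 / 1024 ** 2
print(f"{params_mb:.2f}")    # 0.00
```

The `ReLU` layers and the top-level module hold no parameters of their own, which is why their rows show 0.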