pytorch-summary
Why do I get '2' as batch size?
Hey,
This is a really great tool to visualize the model. However, I was trying to see how my decoder in the VAE is working, where the input is the latent space (dim = (2,2)). When I get the output, I see an extra 2 there, like this:
summary(decoder, (2,2))
Output is:
DECODER
torch.Size([2, 2, 2])
My decoder is initialized like this:
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        c = capacity
        self.fc = nn.Linear(in_features=latent_dims, out_features=c*2*7*7)
        self.conv2 = nn.ConvTranspose2d(in_channels=c*2, out_channels=c, kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.ConvTranspose2d(in_channels=c, out_channels=1, kernel_size=4, stride=2, padding=1)
        self.adapt = nn.AdaptiveMaxPool1d(input_len)

    def forward(self, x):
        print("DECODER")
        print(x.shape)  #1
        x = self.fc(x)
        x = x.reshape(-1, x.shape[0], x.shape[1])
        x = self.adapt(x)
        # unflatten batch of feature vectors to a batch of multi-channel feature maps
        x = x.view(x.size(0), capacity*2, axis_transfer, axis_transfer)
        x = F.relu(self.conv2(x))
        # last layer before output is sigmoid, since we are using BCE as reconstruction loss
        x = torch.sigmoid(self.conv1(x))
        return x
Do let me know.
torchsummary uses a tensor with a batch size of 2 to test the network and collect the information for each layer.
See the code here:
https://github.com/sksq96/pytorch-summary/blob/011b2bd0ec7153d5842c1b37d1944fc6a7bf5feb/torchsummary/torchsummary.py#L58
Even if you set batch_size in the input arguments, this value is only used for calculating the flow size. The network is still tested with the batch-size-2 tensor.
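To make the shape in the question concrete, here is a minimal sketch in plain Python. It is not torchsummary's code: probe_shape and PROBE_BATCH are hypothetical names, and the only assumption is the one stated above, namely that a fixed batch dimension of 2 is prepended to the given input size.

```python
# Sketch only: mimics the shape of the test tensor, not torchsummary's actual code.
PROBE_BATCH = 2  # fixed batch size described in the reply above (assumption)


def probe_shape(input_size, probe_batch=PROBE_BATCH):
    """Return the shape of the tensor the model would be tested with."""
    return (probe_batch, *input_size)


latent_probe = probe_shape((2, 2))      # latent input from the question
image_probe = probe_shape((3, 28, 28))  # image-like input

print(latent_probe)  # (2, 2, 2)
print(image_probe)   # (2, 3, 28, 28)
```

Under this assumption, the torch.Size([2, 2, 2]) printed inside forward is the (2,2) latent shape with the test batch dimension in front, and an image-like input would get the same leading 2.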
This behavior may cause errors when the network requires the input batch size to be a specific value. To fix this problem, I modified the code so that the testing tensor uses batch_size when this value is not None; see
https://github.com/sksq96/pytorch-summary/pull/165/files#diff-ebda1cc7f304708e45ef4e19fb0484036eff8eb3c4b47a2886ca1cf0f731c0bbR118
Actually, it seems that the author has not maintained this package for a long time. I recommend trying an alternative such as torchinfo.
Thanks a lot. I wanted to ask why it takes '1' as the batch size when I input a shape similar to an image, like (3,28,28). In that case, I don't see '2' as the batch size.
I will definitely check out torchinfo.
I do not understand your question. In your previous posts, you have not mentioned any case with a batch size of 1.
By the way, I do not understand "I don't see '2' as batch size." either, because you mentioned that your output is
DECODER
torch.Size([2, 2, 2])
Why do you say you do not see 2 as the batch size? It is clear that the first element of the returned shape is 2.
Here is a tip: if you are using
torchsummary.summary(..., input_size=...)
you should not let input_size be something like [3, 28, 28]. That would cause errors. Instead, you should use ((3, 28, 28), ) or (3, 28, 28). The official implementation is quite unstable in some cases.
@cainmagi Thanks! torchinfo works!