pytorch-summary

Issue: autoencoder cannot be used with `summary`
When I try to run `summary` on a non-convolutional (fully connected) autoencoder architecture:
import torch.nn as nn
import torch
from torch.autograd import Variable
import sys
from torchsummary import summary
class Autoencoder(nn.Module):
    """Fully connected autoencoder built from stacked ``nn.Linear`` layers.

    The encoder has ``int(n_layers / 2)`` linear layers. Starting from
    ``output_dim``, each one multiplies the width by ``size_ratio`` (truncated
    with ``int``). The decoder walks the same widths back up to ``output_dim``.
    Every hidden linear layer is followed by ``activation``; the final decoder
    layer is always followed by a sigmoid.

    Args:
        input_dim: number of input features fed to the first encoder layer.
        output_dim: number of features produced by the decoder; also the base
            width from which the hidden sizes are derived.
        n_layers: total number of linear layers. Odd values are rounded down,
            and the value must be at least 2.
        size_ratio: per-layer shrink factor used by the encoder.
        activation: hidden activation, one of ``'relu'``, ``'tanh'``,
            ``'sigmoid'``, ``'leakyrelu'``.

    Raises:
        ValueError: if ``activation`` is not recognised, or if ``n_layers`` is
            too small to build at least one encoder and one decoder layer.
    """

    def __init__(self, input_dim, output_dim, n_layers=4, size_ratio=0.5, activation='relu'):
        super(Autoencoder, self).__init__()

        def get_activation(name):
            """Return a fresh activation module for ``name``; raise on unknown names."""
            # Factories rather than instances so every layer gets its own module.
            factories = {
                'relu': lambda: nn.ReLU(True),  # in-place ReLU
                'tanh': nn.Tanh,
                'sigmoid': nn.Sigmoid,
                'leakyrelu': nn.LeakyReLU,
            }
            factory = factories.get(name)
            if factory is None:
                # Previously this returned None, which nn.Sequential accepts
                # silently and which only blows up later inside forward().
                raise ValueError(
                    'unknown activation {!r}; expected one of {}'.format(
                        name, sorted(factories)))
            return factory()

        # int(x / 2) (not x // 2) keeps accepting float n_layers such as 4.0.
        n_half = int(n_layers / 2)
        if n_half < 1:
            raise ValueError('n_layers must be at least 2, got {!r}'.format(n_layers))

        # Encoder: each layer shrinks the width by size_ratio.
        widths = [output_dim]
        encoder_layers = []
        in_features = input_dim
        for _ in range(n_half):
            out_features = int(widths[-1] * size_ratio)
            widths.append(out_features)
            encoder_layers.append(nn.Linear(in_features, out_features))
            encoder_layers.append(get_activation(activation))
            in_features = out_features

        # Decoder: mirror the widths from the bottleneck back up to output_dim
        # (the last entry of the reversed list is output_dim itself).
        decoder_widths = widths[::-1]
        decoder_layers = []
        for i in range(n_half):
            decoder_layers.append(nn.Linear(decoder_widths[i], decoder_widths[i + 1]))
            is_last = i == n_half - 1
            # The output layer always uses a sigmoid, regardless of `activation`.
            decoder_layers.append(get_activation('sigmoid' if is_last else activation))

        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and decode it back to ``output_dim`` features."""
        x = self.encoder(x)
        x = self.decoder(x)
        return x
# Repro from the report: build the model on the GPU and print its summary.
# NOTE: `(4396)` is just the int 4396 -- parentheses alone do not make a tuple --
# so torchsummary's `for in_size in input_size` fails with
# "TypeError: 'int' object is not iterable". Pass the one-element tuple `(4396,)` instead.
autoencoder = Autoencoder(input_dim=4396, output_dim=4396, n_layers=4, size_ratio=0.5, activation='sigmoid').cuda()
summary(autoencoder, (4396))
I get the following error:
Traceback (most recent call last):
File "test.py", line 54, in <module>
summary(autoencoder, (4396))
File "/home/raqueld/.local/lib/python3.6/site-packages/torchsummary/torchsummary.py", line 60, in summary
x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]
TypeError: 'int' object is not iterable
Does this package support only CNNs and RNNs? What about feedforward neural networks and autoencoders?
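Your invocation is incorrect: `(4396)` is just the integer `4396` in parentheses, not a tuple, so `summary` fails when it tries to iterate over the input size. Add a trailing comma to make it a one-element tuple: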
# The trailing comma makes a 1-tuple; `(4396)` alone would be a plain int.
summary(autoencoder, input_size=(4396,))