pytorch-summary
Not working with Double Precision Networks
When trying to use the summary function on a network after calling net.double(), an error is returned:
RuntimeError: Expected object of type torch.FloatTensor but found type torch.DoubleTensor for argument #2 'mat2'
The calling code is:
import torch
from torchsummary import summary

net = FCNet(obs_space=checkpoint['env'].observation_space.shape[0],
            action_space=checkpoint['env'].action_space.n,
            shape=netConfig['shape'], activation=netConfig['activation'],
            dropout_rate=netConfig['dropout'], bias=netConfig['bias'])
if TENSOR_TYPE == torch.FloatTensor:
    net = net.float()
elif TENSOR_TYPE == torch.DoubleTensor:
    net = net.double()
netArch = buildNetArch(net)
summary(net, input_size=(1, checkpoint['env'].observation_space.shape[0]))
I've gotten this error outside of torchsummary in the past when trying to pass in tensors of the wrong type. It seems like torchsummary passes a FloatTensor through the network to get the metrics, but this doesn't work on a network specifically made to use DoubleTensors. Is this correct?
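The suspected cause can be reproduced with plain PyTorch, without torchsummary. This is a minimal sketch that assumes a small nn.Sequential as a hypothetical stand-in for FCNet (whose definition isn't shown). It feeds a float32 input to a network converted with .double(), then shows that casting the input to the dtype of the network's parameters avoids the error. The exact error wording differs between PyTorch versions, so the sketch only checks that a RuntimeError is raised.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for FCNet: a tiny fully connected net, converted to float64.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)).double()

# A float32 dummy input (batch of 2), like the one torchsummary appears to build.
x_float = torch.rand(2, 4)

try:
    net(x_float)
    mismatch_raised = False
except RuntimeError as err:
    # The message wording varies by PyTorch version; the failure itself does not.
    mismatch_raised = True
    print("RuntimeError:", err)

# Casting the input to the dtype of the network's parameters avoids the mismatch.
param_dtype = next(net.parameters()).dtype
out = net(x_float.to(param_dtype))
print(out.dtype)  # torch.float64
```

The same reasoning applies to any dtype: the dummy input only has to match whatever dtype the parameters were converted to.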
If so, I could try to fix it myself.
Same issue with networks built on half-precision (float16) tensors.
Check out pull request #89; I just added a dtype option to torchsummary.
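The PR's exact parameter name and signature aren't shown in this thread. The following is therefore a hypothetical, self-contained sketch of the same idea, not the PR's code. The helper name dtype_aware_summary and its arguments are invented for illustration. It infers the dtype (and device) from the model's first parameter unless one is passed explicitly, builds a random dummy input of that dtype, and uses forward hooks on leaf modules to record each layer's output shape and parameter count.

```python
import torch
import torch.nn as nn


def dtype_aware_summary(model, input_size, batch_size=2, dtype=None):
    """Hypothetical helper: per-layer output shapes and parameter counts.

    If dtype is None, it is taken from the model's first parameter, so a
    network converted with .double() or .half() gets a matching dummy input.
    """
    first_param = next(model.parameters())
    if dtype is None:
        dtype = first_param.dtype
    device = first_param.device
    rows, hooks = [], []

    def make_hook(name):
        def hook(module, inputs, output):
            # Count only this module's own parameters, not its children's.
            n_params = sum(p.numel() for p in module.parameters(recurse=False))
            rows.append((name, tuple(output.shape), n_params))
        return hook

    # Hook leaf modules only, so container modules are not double-counted.
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:
            hooks.append(module.register_forward_hook(make_hook(name)))

    dummy = torch.rand(batch_size, *input_size, dtype=dtype, device=device)
    try:
        with torch.no_grad():
            model(dummy)
    finally:
        # Always detach the hooks, even if the forward pass fails.
        for h in hooks:
            h.remove()
    return rows


net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)).double()
for name, shape, n_params in dtype_aware_summary(net, input_size=(4,)):
    print(f"{name:>4}  {str(shape):<10}  {n_params}")
```

Inferring the dtype from the parameters means the calling code from the original report would not need to pass anything extra. An explicit dtype argument remains useful for models whose parameters are mixed or whose inputs are not floating point.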