pytorch-summary

Not working with Double Precision Networks

Open · cSchubes opened this issue 6 years ago · 2 comments

When I call the summary function on a network after calling net.double(), it raises the following error:

RuntimeError: Expected object of type torch.FloatTensor but found type torch.DoubleTensor for argument #2 'mat2'

The calling code is:

import torch
from torchsummary import summary

# FCNet, buildNetArch, checkpoint, netConfig and TENSOR_TYPE are defined elsewhere in my project.
net = FCNet(obs_space=checkpoint['env'].observation_space.shape[0], action_space=checkpoint['env'].action_space.n,
          shape=netConfig['shape'], activation=netConfig['activation'], dropout_rate=netConfig['dropout'],
          bias=netConfig['bias'])

if TENSOR_TYPE == torch.FloatTensor:
    net = net.float()
elif TENSOR_TYPE == torch.DoubleTensor:
    net = net.double()
netArch = buildNetArch(net)
summary(net, input_size=(1, checkpoint['env'].observation_space.shape[0]))

I've gotten this error outside of torchsummary before, when passing in tensors of the wrong type. It seems that torchsummary passes a FloatTensor through the network to collect its metrics, which fails for a network whose parameters are DoubleTensors. Is that correct?
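A minimal sketch of the suspected mismatch, outside torchsummary. It assumes, based on a reading of torchsummary's source, that summary() builds a random float32 batch of size 2 by default. The small nn.Sequential model is a hypothetical stand-in for FCNet, and forward_ok is an illustrative helper, not part of any library.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the issue's FCNet; the real definition is not shown.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)).double()


def forward_ok(model: nn.Module, x: torch.Tensor) -> bool:
    """Return True if a forward pass succeeds, False if it raises a RuntimeError."""
    try:
        model(x)
        return True
    except RuntimeError:
        return False


# Roughly what torchsummary does by default (assumption): a random float32 batch of size 2.
x_float = torch.rand(2, 4)

# float32 input into float64 weights fails, as in the reported traceback.
mismatch_ok = forward_ok(net, x_float)

# Casting the dummy input to the parameters' dtype lets the same pass succeed.
param_dtype = next(net.parameters()).dtype
matched_ok = forward_ok(net, x_float.to(param_dtype))

print(mismatch_ok, matched_ok)  # False True
```

The exact error text varies between PyTorch versions, so the sketch only checks whether a RuntimeError is raised.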

If so, I could try to fix it myself.

cSchubes · Nov 07 '18 16:11

Same issue with networks built on half-precision (float16) tensors.
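The same fix generalizes to float16: the dummy input should follow the model's parameter dtype and device rather than being hard-coded to float32. Below is a sketch with a hypothetical helper, make_dummy_input, which is not part of torchsummary. It samples in float32 and then casts, because some older PyTorch builds cannot sample float16 directly on CPU.

```python
import torch
import torch.nn as nn


def make_dummy_input(model: nn.Module, input_size, batch_size: int = 2) -> torch.Tensor:
    """Build a random batch whose dtype and device follow the model's first parameter.

    Hypothetical helper, not part of torchsummary.
    """
    param = next(model.parameters())
    # Sample in float32, then cast; .to() handles both dtype and device.
    x = torch.rand(batch_size, *input_size)
    return x.to(device=param.device, dtype=param.dtype)


half_net = nn.Linear(4, 2).half()
dummy = make_dummy_input(half_net, (4,))
print(dummy.dtype, tuple(dummy.shape))  # torch.float16 (2, 4)
```

This only builds the input. It does not run the half-precision forward pass, which some older PyTorch versions do not support on CPU.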

NathanSegerlind · Apr 25 '19 19:04

Check out pull request #89; I just added dtype support to torchsummary's summary().
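For a torchsummary version without dtype support, one possible workaround is to wrap the model so that every incoming tensor is cast to the model's dtype before the forward pass. CastInput below is a hypothetical wrapper, not part of torchsummary or PyTorch, and the model is a stand-in for FCNet.

```python
import torch
import torch.nn as nn


class CastInput(nn.Module):
    """Hypothetical wrapper: cast every incoming tensor to `dtype` before the wrapped model."""

    def __init__(self, model: nn.Module, dtype: torch.dtype):
        super().__init__()
        self.model = model
        self.dtype = dtype

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x.to(self.dtype))


net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)).double()
wrapped = CastInput(net, torch.float64)

# A float32 batch, like the one torchsummary is assumed to build, reaches the model as float64.
out = wrapped(torch.rand(2, 4))
print(out.dtype, tuple(out.shape))  # torch.float64 (2, 2)
```

Passing `wrapped` instead of `net` to summary(...) should then avoid the mismatch. This has not been tested against torchsummary itself and assumes its default float32 dummy input.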

DrStoop · Aug 12 '19 07:08