pytorch-summary

`torchsummary()` extensions: `input_initializer` and `dtype` arguments

Open: DrStoop opened this issue 5 years ago

Background

I needed to summarize the OpenAIGPTDoubleHeadsModel from huggingface/pytorch-transformers, which takes multiple torch.zeros() tensors with dtype=torch.int64 as (dummy) input. This is currently not supported by the pytorch-summary tool, so I extended it.

Extensions:

  • added a dtype argument to torchsummary()
  • added an input_initializer argument to torchsummary()
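The extended signature is not shown in this issue, so the following is only a minimal sketch, assuming PyTorch is installed, of how per-input `dtype` and `input_initializer` values could be used to build the dummy tensors. The helper name `make_dummy_inputs` and its "single value or one entry per input" handling are hypothetical, not the code of this change. It defaults to `torch.zeros` because `torch.rand` only supports floating-point dtypes and therefore cannot create `torch.int64` inputs.

```python
import torch


def make_dummy_inputs(input_size, batch_size=2, dtype=torch.float32,
                      input_initializer=torch.zeros):
    """Hypothetical helper: build one dummy tensor per input shape.

    input_size is one shape tuple or a list of shape tuples without the
    batch dimension. dtype and input_initializer are either a single
    value applied to every input or a list with one entry per input.
    """
    if isinstance(input_size, tuple):
        input_size = [input_size]
    n_inputs = len(input_size)
    dtypes = list(dtype) if isinstance(dtype, (list, tuple)) else [dtype] * n_inputs
    initializers = (list(input_initializer)
                    if isinstance(input_initializer, (list, tuple))
                    else [input_initializer] * n_inputs)
    # Each initializer is called like torch.zeros/torch.ones/torch.rand:
    # size arguments first (batch dimension prepended), dtype as keyword
    return [init(batch_size, *shape, dtype=dt)
            for shape, dt, init in zip(input_size, dtypes, initializers)]


# Example: four int64 inputs shaped like the ones in the test call
inputs = make_dummy_inputs([(2, 78), (2,), (2, 78), ()], dtype=torch.int64)
for t in inputs:
    print(tuple(t.shape), t.dtype)
```

With these arguments the helper returns four all-zero `int64` tensors of shapes (2, 2, 78), (2, 2), (2, 2, 78) and (2,); passing lists instead, e.g. `dtype=[torch.int64, torch.float32]` and `input_initializer=[torch.zeros, torch.rand]`, mixes dtypes and initializers per input.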

Bugfixes:

  • changed the batch_size default value from -1 to 2 so that it is actually used and a correct total_input_size is returned
  • computing total_input_size raised a TypeError when input_size is a list of shapes:
  File "/home/developer/AmI/pytorch-summary/torchsummary/torchsummary.py", line 96, in summary
    total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.))
  File "/conda/envs/rapids/lib/python3.6/site-packages/numpy/core/fromnumeric.py", line 2772, in prod
    initial=initial)
  File "/conda/envs/rapids/lib/python3.6/site-packages/numpy/core/fromnumeric.py", line 86, in _wrapreduction
    return ufunc.reduce(obj, axis, dtype, out, **passkwargs)
TypeError: can't multiply sequence by non-int of type 'tuple'
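In the traceback, np.prod(input_size) receives the whole list of shapes. Because the shapes have different lengths, the NumPy version shown ends up multiplying the shape tuples themselves, which raises the TypeError. One way to support several inputs is to sum the element counts per input, sketched below in plain Python. The function name and the `bytes_per_element` parameter are assumptions for illustration: the original formula hard-codes 4 bytes (float32), whereas `torch.int64` inputs take 8 bytes per element.

```python
import operator
from functools import reduce


def total_input_size_mb(input_size, batch_size=2, bytes_per_element=4):
    """Sketch: memory of all dummy inputs in MB, for one or several inputs.

    input_size is one shape tuple or a list of shape tuples without the
    batch dimension. bytes_per_element is 4 for float32 (as in the original
    formula) and 8 for int64.
    """
    if isinstance(input_size, tuple):
        input_size = [input_size]
    # Element count per input; reduce(..., 1) returns 1 for the scalar shape ()
    n_elements = sum(reduce(operator.mul, shape, 1) for shape in input_size)
    return abs(n_elements * batch_size * bytes_per_element / (1024 ** 2))


print(total_input_size_mb([(2, 78), (2,), (2, 78), ()], bytes_per_element=8))
```

Note that with batch_size=-1 the abs() call turns the result into the size of a single sample, regardless of how many samples the dummy input actually holds.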

Testing:

summary(model=model, input_size=[(2, 78), (2,), (2, 78), ()])

To inspect the shapes and dtypes of the input tensors, you can use this code snippet:

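# Fetch one batch from the existing training DataLoader
dummy_input = next(iter(train_dataloader))
# Print the shape and dtype of every tensor in that batch
for i, tensor in enumerate(dummy_input):
    print("dummy_input[{}]:".format(i))
    print(tensor.shape)
    print(tensor.dtype)
    print()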
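The printed shapes can be turned into the input_size list used in the test call by dropping the leading batch dimension of each tensor, since that list is given per sample. The helper name `summary_args_from_batch` is hypothetical, and `SimpleNamespace` objects with illustrative shapes stand in for real tensors so the sketch runs without PyTorch; with a real batch, `tensor.shape` is a `torch.Size`, which `tuple()` converts the same way.

```python
from types import SimpleNamespace


def summary_args_from_batch(batch):
    """Hypothetical helper: per-input shapes (batch dim dropped) and dtypes."""
    input_size = [tuple(tensor.shape)[1:] for tensor in batch]
    dtypes = [tensor.dtype for tensor in batch]
    return input_size, dtypes


# Illustrative stand-ins for one batch of 8 samples
fake_batch = [
    SimpleNamespace(shape=(8, 2, 78), dtype="int64"),
    SimpleNamespace(shape=(8, 2), dtype="int64"),
    SimpleNamespace(shape=(8, 2, 78), dtype="int64"),
    SimpleNamespace(shape=(8,), dtype="int64"),
]
input_size, dtypes = summary_args_from_batch(fake_batch)
print(input_size)  # [(2, 78), (2,), (2, 78), ()]
```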

P.S.

Thanks for the tool :+1:! I guess I'll be using it quite often; it's nice and simple and gives a great overview.

DrStoop, Aug 12 '19 07:08