pytorch-summary

input in Long type #77

Open sungjchi opened this issue 5 years ago • 7 comments

My model has an input of dtype Long that is passed to a torch.nn.Embedding layer. Since summary creates all input data as torch.FloatTensor, passing these inputs to an embedding layer raises a runtime error. Should there be an argument to choose the data type of each input?

In the code below, the third input is passed through an embedding layer and therefore causes a runtime error.

(Pdb) summary(m,[(170,),(800,80),(170,),(800,)])
*** RuntimeError: Expected tensor for argument #1 'indices' to have scalar type Long; but got CUDAType instead (while checking arguments for embedding)
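A minimal reproduction of the error above, independent of torchsummary: `nn.Embedding` rejects float index tensors and accepts integer ones. The vocabulary size (100), embedding width (8) and batch of 2 are illustrative assumptions; the sequence length 170 is borrowed from the issue's input shape. The exact error text differs between PyTorch versions.

```python
import torch
import torch.nn as nn

# Illustrative embedding: 100-token vocabulary, 8-dimensional vectors.
embed = nn.Embedding(num_embeddings=100, embedding_dim=8)

# Random float values, similar to the dummy inputs a summary tool generates.
float_input = torch.rand(2, 170)

float_failed = False
try:
    embed(float_input)  # indices must be an integer tensor, so this raises
except RuntimeError as err:
    float_failed = True
    print("float input rejected:", err)

# Casting to int64 (Long) gives valid indices: torch.rand is in [0, 1),
# so every value truncates to index 0.
long_input = float_input.long()
out = embed(long_input)
print(out.shape)  # torch.Size([2, 170, 8])
```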

sungjchi • Jun 27 '19

Check out pull request #89; I just extended it...

DrStoop • Aug 12 '19

@sungjaechi this should now work with the new version of pytorch-summary (note that you'll have to clone it rather than install it via pip). Let me know if it solves your issue.

Naireen • Jan 13 '20

Just for reference, I had the same issue using version 1.5.1 from pip. I fixed it by tweaking the dtype in <me>.local/lib/python3.6/site-packages/torchsummary/torchsummary.py

Traceback (most recent call last):
  File "<my program>", line 114, in <module>
    torchsummary.summary(<my program>, input_size=(<my input>,))
  File "<me>.local/lib/python3.6/site-packages/torchsummary/torchsummary.py", line 72, in summary
    model(*x)
  File "<me>.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "<my program>", line 26, in forward
    x = self.embed(x)
  File "<me>.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "<me>.local/lib/python3.6/site-packages/torch/nn/modules/sparse.py", line 114, in forward
    self.norm_type, self.scale_grad_by_freq, self.sparse)
  File "<me>.local/lib/python3.6/site-packages/torch/nn/functional.py", line 1484, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected tensor for argument #1 'indices' to have scalar type Long; but got torch.FloatTensor instead (while checking arguments for embedding)
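The dtype tweak mentioned above amounts to building each dummy input with its own dtype instead of one global float type. This is a minimal sketch of that idea, not torchsummary's actual code: `make_dummy_inputs` and `TwoInputModel` are hypothetical names, the shapes loosely follow the issue's example, and the batch dimension of 2 is an assumption (it keeps layers such as BatchNorm from seeing a single sample).

```python
import torch
import torch.nn as nn

def make_dummy_inputs(input_size, dtypes=None, device="cpu"):
    """Hypothetical helper: one random tensor per input shape, each cast to its own dtype."""
    if isinstance(input_size, tuple):  # a single shape means a single input
        input_size = [input_size]
    if dtypes is None:  # default: every input is float, as before the tweak
        dtypes = [torch.float32] * len(input_size)
    # Batch of 2 random values in [0, 1), then cast; integer dtypes truncate to 0.
    return [torch.rand(2, *size).to(dtype=dtype, device=device)
            for size, dtype in zip(input_size, dtypes)]

class TwoInputModel(nn.Module):
    """Hypothetical toy model: float features plus token indices."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(80, 8)
        self.embed = nn.Embedding(100, 8)

    def forward(self, feats, tokens):
        # feats: (B, 800, 80) float; tokens: (B, 170) long
        return self.proj(feats).mean(dim=1) + self.embed(tokens).mean(dim=1)

model = TwoInputModel()
x = make_dummy_inputs([(800, 80), (170,)], dtypes=[torch.float32, torch.long])
out = model(*x)
print(out.shape)  # torch.Size([2, 8])
```

Casting random values in [0, 1) to an integer type yields all-zero indices, which are valid for any embedding table with at least one row, so the forward pass succeeds without knowing the vocabulary size.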

londumas • Feb 19 '20

No solutions for this? I have the same issue.

manuelblancovalentin • Apr 22 '21

Hi, why don't we merge this and publish to PyPI?

Leo-LiHao • Feb 16 '22

Hi, I'm interested. What are your research interests?


sreenithakasarapu • Feb 17 '22


I'm currently working on some NLP projects. I switched to torchinfo instead (as suggested in the README), which solved my problem.

Leo-LiHao • Feb 17 '22