
Support for other integer types by MessagePassing

Open DomInvivo opened this issue 2 years ago • 8 comments

🐛 Describe the bug

Why does PyG enforce edge_index to be of type long? Certain graphs, such as molecules, can work properly with int16: with an average of around 20 atoms per molecule, node indices will rarely surpass 32767 unless batch sizes larger than 1000 are used. Instead, we could simply ensure that there are no negative numbers (which appear when the edge index overflows), or that the maximum of a given datatype is not exceeded.

Also, some hardware does not support long; TPUs and IPUs, for example, are limited to int32.

https://github.com/pyg-team/pytorch_geometric/blob/412ae53d0897660a2968283ed2cc60b5928c1229/torch_geometric/nn/conv/message_passing.py#L183
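As a rough illustration of the potential saving (a hypothetical snippet, not PyG code, assuming PyTorch is installed), downcasting a small molecular edge_index from int64 to int32 halves the memory spent on indices:

```python
import torch

# Hypothetical example: a small molecular graph, ~20 atoms, whose node
# indices fit comfortably within int32 (and even int16) range.
num_nodes, num_edges = 20, 40
edge_index64 = torch.randint(0, num_nodes, (2, num_edges), dtype=torch.int64)
edge_index32 = edge_index64.to(torch.int32)

# Downcasting halves the bytes spent on the index tensor.
bytes64 = edge_index64.element_size() * edge_index64.numel()
bytes32 = edge_index32.element_size() * edge_index32.numel()
```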

Environment

  • PyG version: All
  • PyTorch version: All
  • OS: All
  • Python version: All
  • CUDA/cuDNN version: All
  • How you installed PyTorch and PyG (conda, pip, source): Any
  • Any other relevant information (e.g., version of torch-scatter): Any

DomInvivo avatar Jul 29 '22 14:07 DomInvivo

This is mostly due to how PyTorch works, e.g., index_select only works with indices of dtype=torch.long. Not much we can do about it currently, sorry!

rusty1s avatar Jul 30 '22 15:07 rusty1s

Since PyTorch 1.8.0, index_select supports IntTensor as well as LongTensor, according to the docs.
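A quick check of that claim (assuming PyTorch >= 1.8 is installed):

```python
import torch

x = torch.randn(5, 3)
idx = torch.tensor([0, 2], dtype=torch.int32)

# index_select is one of the few indexing ops that accepts IntTensor
# indices in addition to LongTensor.
out = torch.index_select(x, 0, idx)
```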

DomInvivo avatar Jul 30 '22 18:07 DomInvivo

Oh, my bad. I just tested with torch.short and was convinced it only works for torch.long. I guess we can then start looking into supporting both torch.long and torch.int.

rusty1s avatar Jul 30 '22 18:07 rusty1s

I can take a look at this: as a first attempt I'll try to relax the check for edge_index.dtype == torch.long and add some additional tests in test/nn/conv/test_message_passing.py and see what breaks.
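A sketch of what the relaxed check might look like (hypothetical helper name and dtype set, not the actual patch):

```python
import torch

# Hypothetical relaxed dtype check: accept any signed integer dtype
# for edge_index instead of insisting on torch.long.
ALLOWED_INDEX_DTYPES = (torch.int16, torch.int32, torch.int64)

def check_edge_index(edge_index: torch.Tensor) -> torch.Tensor:
    if edge_index.dtype not in ALLOWED_INDEX_DTYPES:
        raise ValueError(
            f"Expected an integer edge_index, got {edge_index.dtype}")
    return edge_index

edge_index = torch.tensor([[0, 1, 1], [1, 0, 2]], dtype=torch.int32)
checked = check_edge_index(edge_index)  # accepted
```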

hatemhelal avatar Aug 15 '22 18:08 hatemhelal

Sounds good. Thanks! We also added support to make use of torch.scatter_reduce (see utils/scatter.py) which will help in this transition.
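For reference, a minimal torch.scatter_reduce call (available since PyTorch 1.12) with the usual int64 index:

```python
import torch

src = torch.tensor([1., 2., 3., 4.])
index = torch.tensor([0, 1, 0, 1])  # defaults to int64

# Sum the entries of src into two buckets selected by index:
# bucket 0 gets 1+3, bucket 1 gets 2+4.
out = torch.zeros(2).scatter_reduce(0, index, src, reduce="sum")
```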

rusty1s avatar Aug 15 '22 19:08 rusty1s

> I can take a look at this: as a first attempt I'll try to relax the check for edge_index.dtype == torch.long and add some additional tests in test/nn/conv/test_message_passing.py and see what breaks.

Hi @howardjp, I've tried this and it breaks at torch_scatter.scatter_add.

I also tried torch.scatter_reduce as suggested by @rusty1s; it currently does not support IntTensor.

Finally, I also implemented scatter_reduce based on torch.scatter. Sadly, torch.scatter does not support IntTensor either. This issue was mentioned in https://github.com/pytorch/pytorch/issues/61819 and https://github.com/pytorch/pytorch/issues/51323, neither of which has been resolved yet. It seems only index_select supports IntTensor.
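The failure is easy to reproduce in a few lines (a sketch, assuming a current PyTorch release where scatter_add_ still insists on int64 indices):

```python
import torch

src = torch.ones(4)
index = torch.tensor([0, 1, 0, 1], dtype=torch.int32)
out = torch.zeros(2)

# scatter_add_ rejects non-int64 index tensors with a RuntimeError.
error = None
try:
    out.scatter_add_(0, index, src)
except RuntimeError as exc:
    error = str(exc)
```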

EdisonLeeeee avatar Aug 18 '22 09:08 EdisonLeeeee

Thanks for digging into this. It looks like we need to wait for the PyTorch team to catch up, sorry :(

rusty1s avatar Aug 18 '22 16:08 rusty1s

See #5281 for a proposal to relax the type assertion in the message passing interface. This patch effectively lets the execution backend decide which integer types to support when executing the aggregation step. That said, with the default CPU backend the error message changes to:

E           RuntimeError: scatter(): Expected dtype int64 for index

hatemhelal avatar Aug 26 '22 10:08 hatemhelal

https://github.com/pytorch/pytorch/issues/51323 seems fixed now, but not other nasty things like https://github.com/pytorch/pytorch/issues/56975 or https://github.com/pytorch/pytorch/issues/61819 ...

vadimkantorov avatar Aug 03 '23 00:08 vadimkantorov

Could I add try/except code to avoid the RuntimeError? It would introduce redundant code, though, and I'm not sure whether such a modification is appropriate. @vadimkantorov
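Rather than catching the RuntimeError after the fact, one workaround sketch (a hypothetical wrapper, assuming upcasting the index is acceptable for the target hardware) is to upcast any non-long index before the scatter call:

```python
import torch

def scatter_add_any_int(out, dim, index, src):
    # Hypothetical wrapper: scatter_add_ only accepts int64 indices,
    # so upcast smaller integer index dtypes before dispatching.
    if index.dtype != torch.int64:
        index = index.to(torch.int64)
    return out.scatter_add_(dim, index, src)

index = torch.tensor([0, 1, 0, 1], dtype=torch.int32)
result = scatter_add_any_int(torch.zeros(2), 0, index, torch.ones(4))
```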

robertparley avatar Sep 06 '23 07:09 robertparley