pytorch_geometric
Support for other integer types in MessagePassing
🐛 Describe the bug
Why does PyG enforce edge_index to be of type long? Certain graphs, such as molecules, can work properly with int16: they have an average of 20 atoms and will rarely surpass 32767 nodes unless batch sizes larger than 1000 are used. Instead of requiring long, we could simply ensure that there are no negative numbers (which appear when the edge index overflows), or that the maximum of a specific datatype is not reached.
Also, some hardware does not support long, such as TPUs and IPUs, which are limited to int32.
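The overflow check suggested above could be sketched roughly as follows (check_edge_index is a hypothetical helper, not part of PyG):

```python
import torch

def check_edge_index(edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Hypothetical guard: allow smaller integer dtypes for edge_index, as long
    as the node count fits the dtype and no overflow (negative entry) occurred."""
    info = torch.iinfo(edge_index.dtype)
    if num_nodes - 1 > info.max:
        raise ValueError(f"{num_nodes} nodes do not fit into {edge_index.dtype}")
    if edge_index.numel() > 0 and edge_index.min() < 0:
        raise ValueError("edge_index contains negative entries (overflow?)")
    return edge_index

# A molecule-sized graph indexed with int16 passes the check:
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.int16)
check_edge_index(edge_index, num_nodes=3)
```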
https://github.com/pyg-team/pytorch_geometric/blob/412ae53d0897660a2968283ed2cc60b5928c1229/torch_geometric/nn/conv/message_passing.py#L183
Environment
- PyG version: All
- PyTorch version: All
- OS: All
- Python version: All
- CUDA/cuDNN version: All
- How you installed PyTorch and PyG (conda, pip, source): Any
- Any other relevant information (e.g., version of torch-scatter): Any
This is mostly due to how PyTorch works, e.g., index_select only works with indices of dtype=torch.long. Not much we can do about it currently, sorry!
Since PyTorch 1.8.0, index_select supports IntTensor as well as LongTensor, according to the docs.
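This is quick to verify (assuming PyTorch >= 1.8):

```python
import torch

x = torch.tensor([10.0, 20.0, 30.0, 40.0])

# Since PyTorch 1.8, index_select accepts int32 indices in addition to int64:
idx32 = torch.tensor([0, 2], dtype=torch.int32)
out = torch.index_select(x, 0, idx32)
print(out)  # tensor([10., 30.])
```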
Oh, my bad. I just tested with torch.short and was convinced it only works for torch.long. I guess we can then start looking into supporting both torch.long and torch.int.
I can take a look at this: as a first attempt, I'll try to relax the check for edge_index.dtype == torch.long, add some additional tests in test/nn/conv/test_message_passing.py, and see what breaks.
Sounds good. Thanks! We also added support to make use of torch.scatter_reduce (see utils/scatter.py), which will help in this transition.
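For reference, a minimal sum aggregation via torch.scatter_reduce looks like this (available as a Tensor method since PyTorch 1.12; note that its index argument currently still has to be int64):

```python
import torch

# Messages flowing along 4 edges into 3 target nodes:
messages = torch.tensor([1.0, 2.0, 3.0, 4.0])
target = torch.tensor([0, 0, 1, 2])  # int64, as scatter_reduce requires

# include_self=True (the default) folds the initial zeros into the sum:
out = torch.zeros(3).scatter_reduce(0, target, messages, reduce="sum")
print(out)  # tensor([3., 3., 4.])
```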
Hi @howardjp, I've tried this, and it breaks at torch_scatter.scatter_add.
I also tried to use torch.scatter_reduce as suggested by @rusty1s; it currently does not support IntTensor.
Finally, I also implemented scatter_reduce based on torch.scatter. Sadly, torch.scatter does not support IntTensor either. This limitation was mentioned in https://github.com/pytorch/pytorch/issues/61819 and https://github.com/pytorch/pytorch/issues/51323, which have not yet been resolved. It seems only index_select supports IntTensor.
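The limitation is easy to reproduce with plain torch.Tensor.scatter_add_ (behavior as of current PyTorch releases; upcasting the index is the usual workaround):

```python
import torch

src = torch.ones(4)
idx32 = torch.tensor([0, 0, 1, 2], dtype=torch.int32)

try:
    torch.zeros(3).scatter_add_(0, idx32, src)  # rejected: index must be int64
except RuntimeError as e:
    print("scatter_add_:", e)

# Casting the index to int64 makes the same call succeed:
out = torch.zeros(3).scatter_add_(0, idx32.long(), src)
print(out)  # tensor([2., 1., 1.])
```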
Thanks for digging into this. It looks like we need to wait for the PyTorch team to catch up, sorry :(
See #5281 for a proposal to relax the type assertion in the message passing interface. This patch effectively lets the execution backend decide which integer types to support when executing the aggregation step. That said, with the default CPU backend the error message changes to:
E RuntimeError: scatter(): Expected dtype int64 for index
https://github.com/pytorch/pytorch/issues/51323 seems fixed now, but not other nasty things like https://github.com/pytorch/pytorch/issues/56975 or https://github.com/pytorch/pytorch/issues/61819 ...
Could I add try...except code to avoid the RuntimeError? It would introduce redundant code, though, and I don't know if such a modification is appropriate.
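Rather than wrapping every call in try...except, one option is an explicit dtype check before the aggregation, which keeps the fallback in one place (aggregate_sum below is a hypothetical sketch, not PyG's actual implementation):

```python
import torch

def aggregate_sum(messages: torch.Tensor, index: torch.Tensor,
                  dim_size: int) -> torch.Tensor:
    """Hypothetical aggregation helper: upcast non-int64 indices once,
    instead of catching the RuntimeError after the fact."""
    if index.dtype != torch.long:
        index = index.long()  # scatter kernels currently require int64
    out = torch.zeros(dim_size, dtype=messages.dtype)
    return out.scatter_add_(0, index, messages)

# Works transparently with an int16 edge index:
out = aggregate_sum(torch.ones(4), torch.tensor([0, 0, 1, 2], dtype=torch.int16), 3)
print(out)  # tensor([2., 1., 1.])
```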
@vadimkantorov