Issue in src.new_zeros(size).scatter_add_(dim, index, src) for Heterogeneous Model

Open pauvilasoler opened this issue 1 year ago • 4 comments

🐛 Describe the bug

Hi,

I am running into some issues when trying to run a Heterogeneous GNN on a custom-made dataset.

Context:

My dataset is a list of HeteroData objects (i.e. heterogeneous graphs), each of which has 1 node of type 'Ego' and 25 nodes of type 'Alter'. There are two edge types: ('Alter', 'to', 'Ego'), with exactly 25 edges per graph since each of the 25 'Alters' is connected to the 'Ego', and ('Alter', 'to', 'Alter'), whose number varies from graph to graph. The first edge type has attributes, whereas the second does not.
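For readers who want to picture the layout, here is a minimal sketch of one such graph using plain torch tensors rather than an actual HeteroData object. The feature sizes and the three Alter-to-Alter edges are made up for illustration; only the node counts and edge types come from the description above.

```python
import torch

NUM_ALTERS = 25
EGO_DIM, ALTER_DIM, EDGE_DIM = 4, 6, 3  # hypothetical feature sizes

# Node features, one row per node of each type.
x_dict = {
    "Ego": torch.randn(1, EGO_DIM),
    "Alter": torch.randn(NUM_ALTERS, ALTER_DIM),
}

# Every Alter points at the single Ego: row 0 holds source (Alter) indices,
# row 1 holds destination (Ego) indices, counted from zero within each type.
alter_to_ego = torch.stack([
    torch.arange(NUM_ALTERS),
    torch.zeros(NUM_ALTERS, dtype=torch.long),
])

# A variable number of Alter-to-Alter edges; three made-up edges here.
alter_to_alter = torch.tensor([[0, 3, 7], [3, 7, 0]])

edge_index_dict = {
    ("Alter", "to", "Ego"): alter_to_ego,
    ("Alter", "to", "Alter"): alter_to_alter,
}

# Only the Alter-to-Ego edges carry attributes, one row per edge.
edge_attr_dict = {("Alter", "to", "Ego"): torch.randn(NUM_ALTERS, EDGE_DIM)}
```

Each edge index has shape [2, num_edges], and the attribute tensor has one row per Alter-to-Ego edge.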

More specifically, this is what the data looks like:

image

Regarding the model, I am using the Heterogeneous Convolution Wrapper (HeteroConv) that you can see below:

from torch_geometric.nn.conv import HeteroConv, GATConv, GCNConv, GraphConv
import torch.nn as nn
from torch.nn import Module, Linear, ReLU
from torch.optim import Adam
import torch.nn.functional as F
import torch

class Model(nn.Module):
    def __init__(self, n_conv_layers, hidden_channels, out_channels):
        super().__init__()
        
        self.gat = GATConv((-1, -1), hidden_channels, add_self_loops=False, aggr='mean')
        self.gcn = GraphConv(-1, hidden_channels)
        
        self.convs = nn.ModuleList()

        for i in range(n_conv_layers):
            hetero_conv = HeteroConv({('Alter', 'to', 'Alter'): self.gcn, ('Alter', 'to', 'Ego'): self.gat}, aggr="sum")
            self.convs.append(hetero_conv)

        self.relu = ReLU()

        self.linear = Linear(hidden_channels, out_channels)

        self.optimizer = Adam(params=self.parameters())

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)



    def forward(self, x_dict, edge_index_dict, edge_attributes_dict):
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict, edge_attributes_dict)
            x_dict = {key: x.relu() for key, x in x_dict.items()}
        return self.linear(x_dict['Ego']) # predictions are made on the Ego



    def train_model(self, train_data):
        self.train()
        preds = [] # predictions
        ys = []
        losses = []
        for data in train_data:
            self.optimizer.zero_grad() # reset gradients for every graph, not just once
            out = self(data.x_dict, data.edge_indices_dict, data.edge_attributes_dict)
            loss = self.config.loss(out, data['Ego'].y) # forward() already returns the 'Ego' predictions
            loss.backward()
            self.optimizer.step()
            losses.append(loss)
            preds.append(out)
            ys.append(data['Ego'].y)
        return losses, preds, ys



    def test_model(self, test_data):
        self.eval() # nn.Module has no test() method; eval() switches to evaluation mode
        preds = []
        ys = []
        losses = []
        with torch.no_grad(): # no gradients are needed for evaluation
            for data in test_data:
                out = self(data.x_dict, data.edge_index_dict, data.edge_attributes_dict)
                loss = self.config.loss(out, data['Ego'].y) # forward() already returns the 'Ego' predictions
                losses.append(loss)
                preds.append(out)
                ys.append(data['Ego'].y)
        return losses, preds, ys

However, the issue arises when training the model as in:


train_data = graphs  # note that graphs is the dataset and is a list of HeteroData objects like the one above

model = Model(n_conv_layers=1, hidden_channels=64, out_channels=1)

model.train_model(train_data)

Here is the error message:

image

I have noticed that a similar issue was raised in https://github.com/pyg-team/pytorch_geometric/issues/4588 but the solutions provided there are not working for my (Heterogeneous) case.

On top of this, an additional exception is raised which would seem to me to be related:

image

I would appreciate any help or ideas.

Thanks a lot!

Versions

Environment (yaml)

name: predicting-GNNs

channels:

  • defaults

dependencies:

  • anyio=3.5.0=py39haa95532_0
  • argon2-cffi=21.3.0=pyhd3eb1b0_0
  • argon2-cffi-bindings=21.2.0=py39h2bbff1b_0
  • asttokens=2.0.5=pyhd3eb1b0_0
  • async-lru=2.0.4=py39haa95532_0
  • attrs=23.1.0=py39haa95532_0
  • babel=2.11.0=py39haa95532_0
  • backcall=0.2.0=pyhd3eb1b0_0
  • beautifulsoup4=4.12.2=py39haa95532_0
  • blas=1.0=mkl
  • bleach=4.1.0=pyhd3eb1b0_0
  • bottleneck=1.3.5=py39h080aedc_0
  • brotli=1.0.9=h2bbff1b_7
  • brotli-bin=1.0.9=h2bbff1b_7
  • brotli-python=1.0.9=py39hd77b12b_7
  • bzip2=1.0.8=he774522_0
  • ca-certificates=2023.08.22=haa95532_0
  • certifi=2023.11.17=py39haa95532_0
  • cffi=1.16.0=py39h2bbff1b_0
  • colorama=0.4.6=py39haa95532_0
  • comm=0.1.2=py39haa95532_0
  • contourpy=1.2.0=py39h59b6b97_0
  • cryptography=41.0.7=py39h89fc84f_0
  • cycler=0.11.0=pyhd3eb1b0_0
  • debugpy=1.6.7=py39hd77b12b_0
  • defusedxml=0.7.1=pyhd3eb1b0_0
  • exceptiongroup=1.0.4=py39haa95532_0
  • executing=0.8.3=pyhd3eb1b0_0
  • filelock=3.13.1=py39haa95532_0
  • fonttools=4.25.0=pyhd3eb1b0_0
  • freetype=2.12.1=ha860e81_0
  • fsspec=2023.10.0=py39haa95532_0
  • giflib=5.2.1=h8cc25b3_3
  • gmpy2=2.1.2=py39h7f96b67_0
  • h5py=3.9.0=py39hfc34f40_0
  • hdf5=1.12.1=h51c971a_3
  • icc_rt=2022.1.0=h6049295_2
  • icu=73.1=h6c2663c_0
  • importlib-metadata=6.0.0=py39haa95532_0
  • importlib_metadata=6.0.0=hd3eb1b0_0
  • importlib_resources=6.1.0=py39haa95532_0
  • intel-openmp=2023.1.0=h59b6b97_46320
  • ipykernel=6.25.0=py39h9909e9c_0
  • ipython=8.15.0=py39haa95532_0
  • jedi=0.18.1=py39haa95532_1
  • jinja2=3.1.2=py39haa95532_0
  • jpeg=9e=h2bbff1b_1
  • json5=0.9.6=pyhd3eb1b0_0
  • jsonschema=4.19.2=py39haa95532_0
  • jsonschema-specifications=2023.7.1=py39haa95532_0
  • jupyter-lsp=2.2.0=py39haa95532_0
  • jupyter_client=8.6.0=py39haa95532_0
  • jupyter_core=5.5.0=py39haa95532_0
  • jupyter_events=0.8.0=py39haa95532_0
  • jupyter_server=2.10.0=py39haa95532_0
  • jupyter_server_terminals=0.4.4=py39haa95532_1
  • jupyterlab=4.0.8=py39haa95532_0
  • jupyterlab_pygments=0.1.2=py_0
  • jupyterlab_server=2.25.1=py39haa95532_0
  • kiwisolver=1.4.4=py39hd77b12b_0
  • krb5=1.20.1=h5b6d351_0
  • lerc=3.0=hd77b12b_0
  • libbrotlicommon=1.0.9=h2bbff1b_7
  • libbrotlidec=1.0.9=h2bbff1b_7
  • libbrotlienc=1.0.9=h2bbff1b_7
  • libclang=14.0.6=default_hb5a9fac_1
  • libclang13=14.0.6=default_h8e68704_1
  • libdeflate=1.17=h2bbff1b_1
  • libffi=3.4.4=hd77b12b_0
  • libpng=1.6.39=h8cc25b3_0
  • libpq=12.15=h906ac69_1
  • libsodium=1.0.18=h62dcd97_0
  • libtiff=4.5.1=hd77b12b_0
  • libuv=1.44.2=h2bbff1b_0
  • libwebp=1.3.2=hbc33d0d_0
  • libwebp-base=1.3.2=h2bbff1b_0
  • lz4-c=1.9.4=h2bbff1b_0
  • markupsafe=2.1.1=py39h2bbff1b_0
  • matplotlib=3.8.0=py39haa95532_0
  • matplotlib-base=3.8.0=py39h4ed8f06_0
  • matplotlib-inline=0.1.6=py39haa95532_0
  • mistune=2.0.4=py39haa95532_0
  • mkl=2023.1.0=h6b88ed4_46358
  • mkl-service=2.4.0=py39h2bbff1b_1
  • mkl_fft=1.3.8=py39h2bbff1b_0
  • mkl_random=1.2.4=py39h59b6b97_0
  • mpc=1.1.0=h7edee0f_1
  • mpfr=4.0.2=h62dcd97_1
  • mpir=3.0.0=hec2e145_1
  • mpmath=1.3.0=py39haa95532_0
  • munkres=1.1.4=py_0
  • nbclient=0.8.0=py39haa95532_0
  • nbconvert=7.10.0=py39haa95532_0
  • nbformat=5.9.2=py39haa95532_0
  • nest-asyncio=1.5.6=py39haa95532_0
  • networkx=3.1=py39haa95532_0
  • ninja=1.10.2=haa95532_5
  • ninja-base=1.10.2=h6d14046_5
  • notebook=7.0.6=py39haa95532_0
  • notebook-shim=0.2.3=py39haa95532_0
  • numexpr=2.8.7=py39h2cd9be0_0
  • numpy=1.26.2=py39h055cbcc_0
  • numpy-base=1.26.2=py39h65a83cf_0
  • openjpeg=2.4.0=h4fc8c34_0
  • openssl=3.0.12=h2bbff1b_0
  • overrides=7.4.0=py39haa95532_0
  • packaging=23.1=py39haa95532_0
  • pandas=1.3.5=py39h6214cd6_0
  • pandocfilters=1.5.0=pyhd3eb1b0_0
  • parso=0.8.3=pyhd3eb1b0_0
  • pickleshare=0.7.5=pyhd3eb1b0_1003
  • pillow=10.0.1=py39h045eedc_0
  • pip=23.3.1=py39haa95532_0
  • platformdirs=3.10.0=py39haa95532_0
  • ply=3.11=py39haa95532_0
  • prometheus_client=0.14.1=py39haa95532_0
  • prompt-toolkit=3.0.36=py39haa95532_0
  • pure_eval=0.2.2=pyhd3eb1b0_0
  • pycparser=2.21=pyhd3eb1b0_0
  • pygments=2.15.1=py39haa95532_1
  • pyopenssl=23.2.0=py39haa95532_0
  • pypdf2=2.10.5=py39haa95532_0
  • pyqt=5.15.10=py39hd77b12b_0
  • pyqt5-sip=12.13.0=py39h2bbff1b_0
  • pysocks=1.7.1=py39haa95532_0
  • python=3.9.18=h1aa4202_0
  • python-dateutil=2.8.2=pyhd3eb1b0_0
  • python-fastjsonschema=2.16.2=py39haa95532_0
  • python-json-logger=2.0.7=py39haa95532_0
  • pytorch=2.1.0=cpu_py39hb0bdfb8_0
  • pytz=2023.3.post1=py39haa95532_0
  • pywin32=305=py39h2bbff1b_0
  • pywinpty=2.0.10=py39h5da7b33_0
  • pyyaml=6.0.1=py39h2bbff1b_0
  • pyzmq=25.1.0=py39hd77b12b_0
  • qt-main=5.15.2=h19c9488_10
  • referencing=0.30.2=py39haa95532_0
  • requests=2.31.0=py39haa95532_0
  • rfc3339-validator=0.1.4=py39haa95532_0
  • rfc3986-validator=0.1.1=py39haa95532_0
  • rpds-py=0.10.6=py39h062c2fa_0
  • send2trash=1.8.2=py39haa95532_0
  • setuptools=68.0.0=py39haa95532_0
  • sip=6.7.12=py39hd77b12b_0
  • six=1.16.0=pyhd3eb1b0_1
  • sniffio=1.2.0=py39haa95532_1
  • soupsieve=2.5=py39haa95532_0
  • sqlite=3.41.2=h2bbff1b_0
  • stack_data=0.2.0=pyhd3eb1b0_0
  • sympy=1.12=py39haa95532_0
  • tbb=2021.8.0=h59b6b97_0
  • terminado=0.17.1=py39haa95532_0
  • tinycss2=1.2.1=py39haa95532_0
  • tk=8.6.12=h2bbff1b_0
  • tomli=2.0.1=py39haa95532_0
  • tornado=6.3.3=py39h2bbff1b_0
  • traitlets=5.7.1=py39haa95532_0
  • typing-extensions=4.7.1=py39haa95532_0
  • typing_extensions=4.7.1=py39haa95532_0
  • tzdata=2023c=h04d1e81_0
  • vc=14.2=h21ff451_1
  • vs2015_runtime=14.27.29016=h5e58377_2
  • wcwidth=0.2.5=pyhd3eb1b0_0
  • webencodings=0.5.1=py39haa95532_1
  • websocket-client=0.58.0=py39haa95532_4
  • wheel=0.41.2=py39haa95532_0
  • win_inet_pton=1.1.0=py39haa95532_0
  • winpty=0.4.3=4
  • xz=5.4.5=h8cc25b3_0
  • yaml=0.2.5=he774522_0
  • zeromq=4.3.4=hd77b12b_0
  • zipp=3.11.0=py39haa95532_0
  • zlib=1.2.13=h8cc25b3_0
  • zstd=1.5.5=hd43e919_0
  • pip:
    • ase==3.22.1
    • bokeh==3.3.2
    • charset-normalizer==3.3.2
    • cython==3.0.6
    • decorator==4.4.2
    • dynetx==0.3.2
    • future==0.18.3
    • googledrivedownloader==0.4
    • gpytorch==1.11
    • idna==3.6
    • igraph==0.11.3
    • isodate==0.6.1
    • jaxtyping==0.2.25
    • joblib==1.3.2
    • linear-operator==0.5.2
    • llvmlite==0.41.1
    • ndlib==5.1.1
    • netdispatch==0.1.0
    • nomkl==0.0.3
    • numba==0.58.1
    • psutil==5.9.6
    • pyparsing==3.1.1
    • python-igraph==0.11.3
    • python-louvain==0.16
    • rdflib==7.0.0
    • scikit-learn==1.3.2
    • scipy==1.11.4
    • texttable==1.7.0
    • threadpoolctl==3.2.0
    • torch-cluster==1.6.3+pt21cpu
    • torch-geometric==2.4.0
    • torch-geometric-temporal==0.54.0
    • torch-scatter==2.1.2+pt21cpu
    • torch-sparse==0.6.18+pt21cpu
    • torch-spline-conv==1.2.2+pt21cpu
    • tqdm==4.66.1
    • typeguard==2.13.3
    • urllib3==2.1.0
    • xyzservices==2023.10.1

prefix: C:\Users\pvs\miniconda3\envs\predicting-GNNs

pauvilasoler Jun 04 '24 13:06

What does data.validate() return for you?

rusty1s Jun 06 '24 07:06

data.validate() returns True for every HeteroData object in the Dataset.

I suspect the issue lies in how the edge indices are encoded: the indices for the Alters run from 1 to 25, whereas they should probably run from 0 to 24.

As an example, here is what the edge indices for the relationship ('Alter', 'to', 'Ego') look like:

[[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25], [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]]
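A quick way to test this suspicion is to compare the index range against the number of nodes and shift the source row down by one. The sketch below uses plain torch; `is_in_range` is a hypothetical helper, not part of PyG.

```python
import torch

NUM_ALTERS = 25  # from the dataset description

# The edge index as posted: sources run 1..25, destinations are all 0.
edge_index = torch.stack([
    torch.arange(1, NUM_ALTERS + 1),
    torch.zeros(NUM_ALTERS, dtype=torch.long),
])


def is_in_range(index_row: torch.Tensor, num_nodes: int) -> bool:
    """True if every index addresses an existing node (0 .. num_nodes - 1)."""
    return bool(index_row.min() >= 0 and index_row.max() < num_nodes)


print(is_in_range(edge_index[0], NUM_ALTERS))  # False: index 25 has no node

# Shift only the Alter (source) row; the Ego row is already zero-based.
fixed = edge_index.clone()
fixed[0] -= 1
print(is_in_range(fixed[0], NUM_ALTERS))  # True
```

The same shift would presumably be needed for both rows of the ('Alter', 'to', 'Alter') edge index if it was built with the same numbering.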

pauvilasoler Jun 06 '24 08:06

It was indeed this issue.
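For anyone landing here with the same error, the failure mode can be illustrated with plain torch. The traceback is not preserved in this thread, so the sketch below only shows how the call from the title behaves in general when an index equals the output size; the hidden size and the `aggregate` helper are arbitrary choices for the example.

```python
import torch

num_alters, hidden = 25, 4            # hidden size is an arbitrary choice
src = torch.ones(num_alters, hidden)  # one message per edge


def aggregate(index: torch.Tensor) -> torch.Tensor:
    # Same pattern as the call in the issue title: sum each row of `src`
    # into the output row named by the matching entry of `index`.
    index = index.view(-1, 1).expand_as(src)
    return src.new_zeros(num_alters, hidden).scatter_add_(0, index, src)


one_based = torch.arange(1, num_alters + 1)  # 1..25, as in the dataset
try:
    aggregate(one_based)
except RuntimeError as err:
    # Index 25 does not exist in an output with 25 rows (valid: 0..24).
    print("failed:", err)

# After shifting to zero-based indices the same call succeeds.
out = aggregate(one_based - 1)
print(out.shape)
```

On CPU the out-of-range index raises a RuntimeError; on GPU the same mistake typically surfaces as a less readable device-side assertion.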

However, shouldn't data.validate() return False in cases like these where the indices are wrongly encoded?

Thanks a lot anyway!

pauvilasoler Jun 10 '24 08:06

data.validate() just checks for invalid edges. It cannot automatically detect whether edges are semantically incorrect.
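To illustrate the distinction, here is a sketch in plain torch of an edge list that passes a generic bounds check yet is still wrong for this dataset. Both helpers are hypothetical; the "each Alter links to the Ego exactly once" rule is an assumption taken from the dataset description, which is exactly the kind of knowledge a generic validator does not have.

```python
import torch


def in_bounds(index_row: torch.Tensor, num_nodes: int) -> bool:
    """Generic check: every index refers to an existing node."""
    return bool(index_row.min() >= 0 and index_row.max() < num_nodes)


def each_node_once(index_row: torch.Tensor, num_nodes: int) -> bool:
    """Dataset-specific rule (assumed): every Alter appears exactly once."""
    counts = torch.bincount(index_row, minlength=num_nodes)
    return bool(counts.numel() == num_nodes and (counts == 1).all())


num_alters = 25
# One-based indices clipped into range: 1, 2, ..., 24, 24.
# Node 0 is never used and node 24 is used twice.
wrong = torch.arange(1, num_alters + 1).clamp(max=num_alters - 1)

print(in_bounds(wrong, num_alters))       # True: nothing is out of range
print(each_node_once(wrong, num_alters))  # False: the wiring is still wrong
```

A check like `each_node_once` has to be written per dataset, since only the author knows which connections are intended.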

rusty1s Jun 24 '24 10:06