graph_classification
graph_classification copied to clipboard
Benchmarking GNNs with PyTorch Lightning: Open Graph Benchmarks and image classification from superpixels
Description
This repository is supposed to be a place for curated, high quality benchmarks of Graph Neural Networks, implemented with PyTorch Lightning and Hydra.
Only datasets big enough to provide good measures are taken into consideration.
Built with lightning-hydra-template.
Datasets
- Open Graph Benchmarks (graph property prediction)
- Image classification from superpixels (MNIST, FashionMNIST, CIFAR10)
How to run
Install dependencies
# clone project
git clone https://github.com/ashleve/graph_classification
cd graph_classification
# [OPTIONAL] create conda environment
conda create -n myenv python=3.8
conda activate myenv
# install pytorch and pytorch geometric according to instructions
# https://pytorch.org/get-started/
# https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html
# install requirements
pip install -r requirements.txt
Train model with default configuration
# train on CPU
python run.py trainer.gpus=0
# train on GPU
python run.py trainer.gpus=1
Train model with chosen experiment configuration from configs/experiment/
python run.py experiment=GAT/gat_ogbg_molpcba
python run.py experiment=GraphSAGE/graphsage_mnist_sp75
python run.py experiment=GraphSAGE/graphsage_cifar10_sp100
You can override any parameter from command line like this
python run.py trainer.max_epochs=20 datamodule.batch_size=64
Methodology
For each experiment, we run a series of 10 random hparams runs, and 5 optimization runs, using Optuna bayesian sampler. The hyperparameter search configs are available under configs/hparams_search.
After finding best hyperparameters, each experiment was repeated 5 times with different random seeds. The only exception are the ogbg-molhiv
experiments, which were repeated 10 times each (because of high varience of results).
The results were averaged and reported in the table below.
Results
Architecture | MNIST-sp75 | FashionMNIST-sp75 | CIFAR10-sp100 | ogbg-molhiv | ogbg-molcpba |
---|---|---|---|---|---|
GCN | 0.955 ± 0.014 | 0.835 ± 0.016 | 0.518 ± 0.007 | 0.755 ± 0.019 | 0.231 ± 0.003 |
GIN | 0.966 ± 0.008 | 0.861 ± 0.012 | 0.512 ± 0.020 | 0.757 ± 0.025 | 0.240 ± 0.001 |
GAT | 0.976 ± 0.008 | 0.889 ± 0.003 | 0.617 ± 0.005 | 0.751 ± 0.026 | 0.234 ± 0.003 |
GraphSAGE | 0.981 ± 0.005 | 0.897 ± 0.012 | 0.629 ± 0.012 | 0.761 ± 0.025 | 0.256 ± 0.004 |
The +-
denotes standard deviation across all seeds.