neural-graphs
Official source code for "Graph Neural Networks for Learning Equivariant Representations of Neural Networks". In ICLR 2024 (oral).
Graph Neural Networks for Learning Equivariant Representations of Neural Networks
Official implementation for:
Graph Neural Networks for Learning Equivariant Representations of Neural Networks
Miltiadis Kofinas*, Boris Knyazev, Yan Zhang, Yunlu Chen, Gertjan J. Burghouts, Efstratios Gavves, Cees G. M. Snoek, David W. Zhang*
ICLR 2024 (oral)
https://arxiv.org/abs/2403.12143/
*Joint first and last authors
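The core idea of the paper is to treat a neural network's parameters as a graph: neurons become nodes (with biases as node features) and weights become edge features, so a GNN can process them while respecting neuron-permutation symmetry. Below is a minimal illustrative sketch of that encoding for an MLP; it is not the repo's API, and `mlp_to_graph` is a hypothetical name.

```python
# Hypothetical sketch (not the repo's code): encode an MLP's parameters as a
# graph. Each neuron is a node carrying its bias as a feature; each weight is
# a directed edge feature from one neuron to the next.

def mlp_to_graph(weights, biases):
    """weights[l][i][j]: weight from neuron i of layer l to neuron j of layer l+1.
    biases[l][j]: bias of neuron j of layer l+1 (input neurons have no bias).
    Returns (node_features, edges) with edges as (src, dst, weight) triples."""
    layer_sizes = [len(weights[0])] + [len(b) for b in biases]
    # Offset of each layer in the flat, global node index.
    offsets, total = [], 0
    for size in layer_sizes:
        offsets.append(total)
        total += size
    # Input neurons get a zero "bias" feature.
    node_features = [0.0] * layer_sizes[0]
    for b in biases:
        node_features.extend(b)
    edges = []
    for l, W in enumerate(weights):
        for i, row in enumerate(W):
            for j, w in enumerate(row):
                edges.append((offsets[l] + i, offsets[l + 1] + j, w))
    return node_features, edges

# A 2-3-1 MLP: 6 nodes and 2*3 + 3*1 = 9 edges.
W1 = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]  # layer 0 -> 1
W2 = [[1.0], [2.0], [3.0]]               # layer 1 -> 2
b1, b2 = [0.01, 0.02, 0.03], [0.5]
nodes, edges = mlp_to_graph([W1, W2], [b1, b2])
print(len(nodes), len(edges))  # 6 9
```

Permuting the hidden neurons of the MLP only relabels nodes in this graph, which is exactly the symmetry a GNN is equivariant to.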
Setup environment
To run the experiments, first create a clean virtual environment and install the requirements.
conda create -n neural-graphs python=3.9
conda activate neural-graphs
conda install pytorch==2.0.1 torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
conda install pyg==2.3.0 pytorch-scatter -c pyg
pip install hydra-core einops opencv-python
Install the repo:
git clone https://github.com/mkofinas/neural-graphs.git
cd neural-graphs
pip install -e .
Introduction Notebook
An introduction notebook for INR classification with Neural Graphs:
Run experiments
To run a specific experiment, follow the instructions in the README file within each experiment folder. Each README provides full instructions for downloading the data and reproducing the results reported in the paper.
- INR classification: experiments/inr_classification
- INR style editing: experiments/style_editing
- CNN generalization: experiments/cnn_generalization
- Learning to optimize (coming soon): experiments/learning_to_optimize
Datasets
INR classification and style editing
For INR classification, we use MNIST and Fashion MNIST. The datasets are available here.
For INR style editing, we use MNIST. The dataset is available here.
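For background: an INR (implicit neural representation) here is a small MLP fitted to map normalized pixel coordinates to intensities, so each image is represented by a set of weights, and INR classification means classifying those weight sets. The sketch below, which is illustrative only and not part of the repo, builds the coordinate/value pairs such an INR would be fitted to.

```python
# Illustrative sketch (not the repo's code): turn an image into the
# (coordinate, intensity) pairs an INR is trained on.

def image_to_inr_dataset(image):
    """image: 2D list of intensities in [0, 1].
    Returns [((x, y), value), ...] with coordinates normalized to [-1, 1]."""
    h, w = len(image), len(image[0])
    pairs = []
    for r, row in enumerate(image):
        for c, v in enumerate(row):
            x = 2.0 * c / (w - 1) - 1.0 if w > 1 else 0.0
            y = 2.0 * r / (h - 1) - 1.0 if h > 1 else 0.0
            pairs.append(((x, y), v))
    return pairs

tiny = [[0.0, 1.0], [0.5, 0.25]]
pairs = image_to_inr_dataset(tiny)
print(pairs[0])  # ((-1.0, -1.0), 0.0)
```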
CNN generalization
For CNN generalization, we use the grayscale CIFAR-10 (CIFAR10-GS) from the Small CNN Zoo dataset. We also introduce CNN Wild Park, a dataset of CNNs with varying numbers of layers, kernel sizes, activation functions, and residual connections between arbitrary layers.
- CIFAR10-GS
- CNN Wild Park (coming soon)
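To make the variability of CNN Wild Park concrete, the sketch below samples a random architecture specification in that spirit: varying depth, kernel size, activation, and optional residual connections from arbitrary earlier layers. This is a hypothetical illustration, not the actual dataset generator.

```python
# Illustrative sketch (not the actual CNN Wild Park generator): sample a CNN
# architecture spec with varying depth, kernel sizes, activations, and
# residual connections between arbitrary layers.
import random

def sample_cnn_spec(rng, max_layers=6):
    n_layers = rng.randint(2, max_layers)
    layers = []
    for i in range(n_layers):
        layers.append({
            "kernel_size": rng.choice([3, 5, 7]),
            "activation": rng.choice(["relu", "gelu", "tanh"]),
            # Optional residual connection from any earlier layer.
            "residual_from": rng.choice([None] + list(range(i))) if i > 0 else None,
        })
    return layers

spec = sample_cnn_spec(random.Random(0))
print(len(spec), "layers")
```

Such heterogeneous architectures are awkward for fixed-shape weight-space models but map naturally onto graphs of varying size and connectivity.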
Citation
If you find our work or this code to be useful in your own research, please consider citing the following paper:
@inproceedings{kofinas2024graph,
  title={{G}raph {N}eural {N}etworks for {L}earning {E}quivariant {R}epresentations of {N}eural {N}etworks},
  author={Kofinas, Miltiadis and Knyazev, Boris and Zhang, Yan and Chen, Yunlu and Burghouts, Gertjan J. and Gavves, Efstratios and Snoek, Cees G. M. and Zhang, David W.},
  booktitle={12th International Conference on Learning Representations ({ICLR})},
  year={2024}
}
@inproceedings{zhang2023neural,
  title={{N}eural {N}etworks {A}re {G}raphs! {G}raph {N}eural {N}etworks for {E}quivariant {P}rocessing of {N}eural {N}etworks},
  author={Zhang, David W. and Kofinas, Miltiadis and Zhang, Yan and Chen, Yunlu and Burghouts, Gertjan J. and Snoek, Cees G. M.},
  booktitle={Workshop on Topology, Algebra, and Geometry in Machine Learning (TAG-ML), ICML},
  year={2023}
}
Acknowledgments
- This codebase builds on github.com/AvivNavon/DWSNets; the DWSNets implementation is copied from there
- The NFN implementation is copied and slightly adapted from github.com/AllanYangZhou/nfn
- We implemented the relational transformer in PyTorch following the JAX implementation at github.com/CameronDiao/relational-transformer. Our implementation has some differences that we describe in the paper.