escnn_jax
Equivariant Steerable CNNs library for Jax, ported from the PyTorch escnn library (documentation: https://quva-lab.github.io/escnn/)
E(n)-equivariant Steerable CNNs (escnn)
Documentation | escnn library |
:rocket: ~20% faster than PyTorch*
escnn_jax is a Jax port of the PyTorch escnn library for equivariant deep learning. escnn_jax supports steerable CNNs equivariant to both 2D and 3D isometries, as well as equivariant MLPs.
The library is structured into four subpackages with different high-level features:
| Component | Dependency | Description |
|---|---|---|
| `escnn_jax.group` | Pure Python | implements basic concepts of group and representation theory |
| `escnn_jax.gspaces` | Pure Python | defines the Euclidean spaces and their symmetries |
| `escnn_jax.kernels` | Jax | solves for spaces of equivariant convolution kernels |
| `escnn_jax.nn` | Equinox | contains equivariant modules to build deep neural networks |
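As a much simpler illustration of what "solving for equivariant kernels" means, the sketch below finds every linear map `W` satisfying the equivariance constraint `rho_out(g) @ W == W @ rho_in(g)` by computing the null space of the vectorised constraint. This is plain NumPy and not the algorithm used by the `kernels` subpackage, which additionally handles the spatial dependence of convolution kernels. It assumes a cyclic group, so imposing the constraint on a single generator `g` is enough; `cyclic_shift` and `intertwiner_basis` are hypothetical helper names.

```python
import numpy as np


def cyclic_shift(n: int) -> np.ndarray:
    """Regular representation of the generator of C_n: a cyclic permutation matrix."""
    return np.roll(np.eye(n), 1, axis=0)


def intertwiner_basis(rho_in: np.ndarray, rho_out: np.ndarray, tol: float = 1e-8) -> np.ndarray:
    """Basis of all W with rho_out @ W == W @ rho_in, for one group generator.

    Uses row-major vectorisation: vec(A @ W @ B) = kron(A, B.T) @ vec(W).
    Returns an array of shape (dim, d_out, d_in).
    """
    d_out, d_in = rho_out.shape[0], rho_in.shape[0]
    constraint = np.kron(rho_out, np.eye(d_in)) - np.kron(np.eye(d_out), rho_in.T)
    _, s, vh = np.linalg.svd(constraint)
    # right singular vectors with (numerically) zero singular value span the null space
    rank = int(np.sum(s > tol))
    return vh[rank:].reshape(-1, d_out, d_in)


P = cyclic_shift(4)  # regular representation of the 90-degree rotation in C4

# regular -> regular: the equivariant maps are exactly the 4x4 circulant matrices
basis = intertwiner_basis(P, P)
print(basis.shape)  # (4, 4, 4)
print(all(np.allclose(P @ W, W @ P) for W in basis))  # True

# trivial -> regular: only constant vectors are equivariant
print(intertwiner_basis(np.eye(1), P).shape)  # (1, 4, 1)
```

For the regular representation of C4 the solution space is 4-dimensional, while maps from one trivial (scalar) field to one regular field collapse to a single constant vector.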
TODOs
Priority
- reproduce examples and baselines
  - [x] `mlp.ipynb`
    - [ ] apart from the `IIDBatchNorm1d` module
  - [x] `introduction.ipynb`
  - [ ] `model.ipynb`
  - [ ] `octahedral_cnn.ipynb`
- [x] mimic `requires_grad=False` for 'buffer' variables to avoid including them in `opt_state` and `grads`
  - added in `EquivariantModule` the methods `set_buffer` and `get_buffer`, which wrap the variable in `lax.stop_gradient`
  - added in `EquivariantModule` the methods `set_parameter` and `get_parameter`, which wrap the Array in a custom type `escnn_jax.nn.ParameterArray` that can later be used to filter the parameters
- [ ] enhance `model.eval()` behaviour; make `EquivariantModule.eval` recursively call submodules?
- [ ] speed up modules' `__init__`, e.g. `nn.Linear` and `nn.R2Conv`
- [ ] speed up modules' `__call__` if possible?
- [ ] better `__repr__` for `EquivariantModule` and `eqx.nn.Module` more generally
- [ ] make sure that tests pass for implemented modules and kernels
- [ ] Bug? `InnerBatchNorm.eval()` without training returns high values
- [ ] add `export` method for layers
- [ ] properly measure the speed-up w.r.t. the PyTorch version

Nice to have
- [ ] add support for `haiku`/`flax` under `escnn.nn.haiku`/`escnn.nn.flax`
- [ ] `jax` linop for the `Representation` class akin to `emlp`, and more generally rewrite `escnn_jax.group` in `jax`?
- [ ] add missing modules, cf. `/nn/__init__.py`
Getting Started
escnn_jax is easy to use since it provides a high-level user interface which abstracts away most intricacies of group and representation theory. The following code snippet shows how to perform an equivariant convolution from an RGB image to 10 regular feature fields (corresponding to a group convolution).
```python
import jax
from escnn_jax import gspaces
from escnn_jax import nn

key = jax.random.PRNGKey(0)
key1, key2 = jax.random.split(key, 2)

# the group of 8 discrete rotations (C8) acting on the plane R^2
r2_act = gspaces.rot2dOnR2(N=8)
# input: an RGB image, i.e. 3 scalar (trivial) fields
feat_type_in = nn.FieldType(r2_act, 3*[r2_act.trivial_repr])
# output: 10 regular feature fields (8 channels each)
feat_type_out = nn.FieldType(r2_act, 10*[r2_act.regular_repr])

conv = nn.R2Conv(feat_type_in, feat_type_out, kernel_size=5, key=key1)
relu = nn.ReLU(feat_type_out)

# a batch of 16 random 32x32 RGB images, wrapped with its input field type
x = jax.random.normal(key2, (16, 3, 32, 32))
x = feat_type_in(x)

y = relu(conv(x))
```
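The snippet above notes that mapping to regular feature fields corresponds to a group convolution. To make that concrete, here is a minimal pure-JAX sketch, not the escnn_jax implementation: a lifting convolution for the four 90-degree rotations (C4 rather than the C8 used above, because only multiples of 90 degrees rotate a pixel grid exactly). It builds rotated copies of each filter with `jnp.rot90` instead of the steerable kernel bases escnn_jax solves for; the function name `lifting_conv_c4` is hypothetical.

```python
import jax
import jax.numpy as jnp


def lifting_conv_c4(x: jax.Array, w: jax.Array) -> jax.Array:
    """Lifting group convolution for C4 (illustrative sketch).

    x: (N, C_in, H, W) input with C_in scalar (trivial) channels
    w: (C_out, C_in, k, k) base filters
    returns: (N, C_out, 4, H', W'); axis 2 carries the regular representation
    of C4, i.e. one response per filter orientation.
    """
    # stack each filter rotated by 0, 90, 180 and 270 degrees
    rotated = jnp.stack([jnp.rot90(w, r, axes=(2, 3)) for r in range(4)], axis=1)
    c_out, _, c_in, k, _ = rotated.shape
    kernels = rotated.reshape(c_out * 4, c_in, k, k)
    y = jax.lax.conv_general_dilated(
        x, kernels, window_strides=(1, 1), padding="VALID",
        dimension_numbers=("NCHW", "OIHW", "NCHW"))
    return y.reshape(x.shape[0], c_out, 4, y.shape[-2], y.shape[-1])


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (2, 3, 16, 16))  # batch of 2 RGB images
w = jax.random.normal(key_w, (10, 3, 5, 5))   # 10 output fields
y = lifting_conv_c4(x, w)
print(y.shape)  # (2, 10, 4, 12, 12)

# Equivariance: rotating the input by 90 degrees rotates every output map
# and cyclically shifts the 4 orientation channels of each regular field.
y_rot = lifting_conv_c4(jnp.rot90(x, 1, axes=(2, 3)), w)
expected = jnp.rot90(jnp.roll(y, 1, axis=2), 1, axes=(3, 4))
print(jnp.allclose(y_rot, expected, atol=1e-3))  # True
```

The cyclic shift along the orientation axis is exactly the action of the regular representation, which is why regular feature fields are the natural output type of a group convolution.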
Dependencies
The library is based on Python 3.7 and requires:
- jax
- equinox
- jaxtyping
- numpy
- scipy
- lie_learn
- joblib
- py3nj

Optional:
- pymanopt>=1.0.0
- optax
- chex
WARNING:
`py3nj` enables a fast computation of Clebsch-Gordan coefficients. If this package is not installed, our library relies on a numerical method to estimate them. This numerical method is not guaranteed to return the same coefficients computed by `py3nj` (they can differ by a sign). For this reason, models built with and without `py3nj` might not be compatible.

To successfully install `py3nj` you may need a Fortran compiler installed in your environment.
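Because of this possible incompatibility, it can help to check that the environment used for training and the one used for inference agree on whether `py3nj` is available. A minimal sketch, assuming (as the warning implies) that the library uses `py3nj` exactly when it is importable; `has_py3nj` is a hypothetical helper, not part of escnn_jax.

```python
import importlib.util


def has_py3nj() -> bool:
    """Return True if py3nj can be imported in the current environment."""
    return importlib.util.find_spec("py3nj") is not None


if has_py3nj():
    print("py3nj found: Clebsch-Gordan coefficients are computed with py3nj")
else:
    print("py3nj missing: Clebsch-Gordan coefficients are estimated numerically")
```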
Installation
You can install the latest release as
```
pip install escnn_jax
```
or you can install the current version directly from this repository with
```
pip install git+https://github.com/QUVA-Lab/escnn_jax
```
Contributing
Would you like to contribute to escnn_jax? That's great!
Then, check the instructions in CONTRIBUTING.md and help us to improve the library!
Cite
The development of this library was part of the work done for our papers A Program to Build E(N)-Equivariant Steerable CNNs and General E(2)-Equivariant Steerable CNNs. Please cite these works if you use our code:
```bibtex
@inproceedings{cesa2022a,
    title={A Program to Build {E(N)}-Equivariant Steerable {CNN}s},
    author={Gabriele Cesa and Leon Lang and Maurice Weiler},
    booktitle={International Conference on Learning Representations},
    year={2022},
    url={https://openreview.net/forum?id=WE4qe9xlnQw}
}

@inproceedings{e2cnn,
    title={{General E(2)-Equivariant Steerable CNNs}},
    author={Weiler, Maurice and Cesa, Gabriele},
    booktitle={Conference on Neural Information Processing Systems (NeurIPS)},
    year={2019},
}
```
Feel free to contact us.
License
escnn_jax is distributed under the BSD Clear license. See the LICENSE file.