DeepEquilibriumNetworks.jl
Implicit Layer Machine Learning via Deep Equilibrium Networks, O(1) backpropagation with accelerated convergence.
DeepEquilibriumNetworks.jl is a framework built on top of DifferentialEquations.jl and Lux.jl that enables efficient training and inference of Deep Equilibrium Networks (infinitely deep neural networks).
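As background, a deep equilibrium layer does not stack a fixed number of explicit layers; it instead solves for the fixed point of a single parameterized layer. A sketch of the defining equation (standard DEQ formulation, not specific to this package's internals):

```latex
% The DEQ output z^* is the fixed point of a layer f with parameters \theta
% applied to the input x:
z^* = f_\theta(z^*, x)
% Gradients are obtained via the implicit function theorem, differentiating
% through the fixed point rather than through the solver iterations:
\frac{\partial z^*}{\partial \theta}
  = \left(I - \frac{\partial f_\theta}{\partial z}\Big|_{z^*}\right)^{-1}
    \frac{\partial f_\theta}{\partial \theta}\Big|_{z^*}
```

This is what makes backpropagation O(1) in depth: memory and compute for the backward pass do not depend on how many solver iterations were needed to reach the fixed point.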
Installation
using Pkg
Pkg.add("DeepEquilibriumNetworks")
Quickstart
using DeepEquilibriumNetworks, Lux, Random, NonlinearSolve, Zygote, SciMLSensitivity
# using LuxCUDA, LuxAMDGPU ## Install and Load for GPU Support. See https://lux.csail.mit.edu/dev/manual/gpu_management
seed = 0
rng = Random.default_rng()
Random.seed!(rng, seed)
model = Chain(Dense(2 => 2),
    DeepEquilibriumNetwork(
        Parallel(+, Dense(2 => 2; use_bias=false), Dense(2 => 2; use_bias=false)),
        NewtonRaphson()))
gdev = gpu_device()
cdev = cpu_device()
ps, st = Lux.setup(rng, model) |> gdev
x = rand(rng, Float32, 2, 3) |> gdev
y = rand(rng, Float32, 2, 3) |> gdev
model(x, ps, st)
gs = only(Zygote.gradient(p -> sum(abs2, first(model(x, p, st)) .- y), ps))
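The gradient above can be fed to any optimizer that understands Lux parameter structures. As a minimal sketch of a training loop, assuming Optimisers.jl (which is not loaded in the snippet above and is an illustration, not part of this package's API):

```julia
# Hypothetical training loop; assumes the model, ps, st, x, y from the
# Quickstart above and adds Optimisers.jl for parameter updates.
using Optimisers

opt_state = Optimisers.setup(Optimisers.Adam(0.001f0), ps)

for step in 1:100
    # Same loss as in the Quickstart: squared error against the target y.
    gs = only(Zygote.gradient(p -> sum(abs2, first(model(x, p, st)) .- y), ps))
    # Optimisers.update returns the new optimizer state and updated parameters.
    opt_state, ps = Optimisers.update(opt_state, ps, gs)
end
```

Note that `model(x, ps, st)` returns `(output, updated_state)`, which is why `first` is used to extract the prediction inside the loss.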
Citation
If you are using this project for research or other academic purposes, please consider citing our paper:
@inproceedings{pal2022continuous,
title={Continuous Deep Equilibrium Models: Training Neural ODEs Faster by Integrating Them to Infinity},
author={Pal, Avik and Edelman, Alan and Rackauckas, Christopher},
booktitle={2023 IEEE High Performance Extreme Computing Conference (HPEC)},
year={2023}
}
For specific algorithms, check the respective documentation and cite the corresponding papers.