# Lux.jl
Auto-compile Lux models to Reactant.
## Example Usage
This follows the same structure as the SimpleChains integration: the user requests a conversion and provides an input prototype.
```julia
using Reactant, Lux, Random

model = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)

# Convert the model, providing an input prototype so that the forward (and,
# here, also the backward) pass can be compiled ahead of time.
reactant_model = ToReactantAdaptor{true}(rand(Float32, 10, 3); force_compile_backward=true)(model)

ps, st = Lux.setup(Random.default_rng(), reactant_model)

x = randn(Float32, 10, 3)
reactant_model(x, ps, st)  # returns `(y, st)`, the standard Lux convention
```
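Under the hood, the adaptor presumably traces the model once on the input prototype and hands the trace to Reactant for XLA compilation. A minimal sketch of that raw workflow, assuming the early `Reactant.ConcreteRArray`/`Reactant.compile` API (`square_sum` is purely illustrative, not part of Lux or this extension):

```julia
using Reactant

square_sum(x) = sum(abs2, x)

# Move the data into Reactant's array type, trace/compile once, then call the
# compiled function like any other Julia function.
x = Reactant.ConcreteRArray(rand(Float32, 10, 3))
compiled_square_sum = Reactant.compile(square_sum, (x,))
compiled_square_sum(x)
```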
## Upstream Needs
- [x] https://github.com/EnzymeAD/Reactant.jl/pull/7
- [ ] https://github.com/EnzymeAD/Reactant.jl/pull/5 (for the backward pass)
- [x] (Nice to have) https://github.com/EnzymeAD/Reactant.jl/issues/8 for extending the support to `LuxCore.apply` instead of `LuxCore.stateless_apply`
## TODOs
- [x] Compile Forward Pass
  - [x] Compile Inference Pass
- [ ] Compile VJP (using Enzyme) (see the first sketch after this list)
  - [ ] Support the standard AD backends as well via ChainRules
  - [ ] Add Enzyme Rules to directly call the compiled function
- [ ] Compile JVP (using Enzyme)
  - [ ] Support ForwardDiff
  - [ ] Add Enzyme ForwardDiff Rules
- [ ] `__make_reactant_array` (see the second sketch after this list)
  - [ ] `CuArray`s: https://github.com/EnzymeAD/Reactant.jl/issues/16
  - [ ] `ComponentArray`s special handling
- [ ] Add documentation (make sure that users know that this is experimental)
- [ ] Add a tutorial similar to the SimpleChains one
  - [ ] Full training pipeline using the Training API
  - [ ] Partial compilation: the NN via XLA, the rest via LLVM (e.g., Neural ODEs)
- [ ] Add Reactant to our benchmark suite
- [ ] Add tests
  - [ ] CPU
  - [ ] CUDA
- [x] (Nice to have) Extend to all models, not just stateless ones
- [ ] Compile the training loop
  - [x] Fallback implementation for existing backends
  - [ ] Reactant backend
  - [ ] Add to documentation
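For the "Compile VJP" item, the reverse-mode primitive to be compiled is an ordinary Enzyme pullback. A minimal sketch with plain Enzyme.jl (the loss function is illustrative; making Reactant trace through `autodiff` is exactly what the linked backward-pass PR is about):

```julia
using Enzyme

loss(x) = sum(abs2, x)

x  = randn(Float32, 10)
dx = zero(x)  # shadow buffer; Enzyme accumulates the gradient into it
Enzyme.autodiff(Reverse, loss, Active, Duplicated(x, dx))
# dx now holds ∂loss/∂x == 2 .* x
```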
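For `__make_reactant_array`, the job is presumably to move every array leaf of the parameter/state trees onto Reactant, with `CuArray` and `ComponentArray` needing dedicated methods (hence the sub-items). A hypothetical sketch using `Functors.fmap`; the dispatch rules below are assumptions, not the actual implementation:

```julia
using Functors, Reactant

# Hypothetical methods for the helper named in the TODO item above.
__make_reactant_array(x::AbstractArray) = Reactant.ConcreteRArray(x)
__make_reactant_array(x) = x  # non-array leaves (functions, numbers, ...) pass through

# Apply it to a whole parameter tree:
ps = (; weight=rand(Float32, 5, 10), bias=zeros(Float32, 5))
ps_ra = fmap(__make_reactant_array, ps)
```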