
Auto-compile Lux models to Reactant

avik-pal opened this issue 8 months ago · 20 comments

Example Usage

This follows the same structure as SimpleChains: the user requests a conversion and provides an input prototype.

```julia
using Reactant, Lux, Random

model = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)

reactant_model = ToReactantAdaptor{true}(rand(Float32, 10, 3); force_compile_backward=true)(model)
ps, st = Lux.setup(Random.default_rng(), reactant_model)

x = randn(Float32, 10, 3)

reactant_model(x, ps, st)
```
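For context, here is a minimal sketch of what such an adaptor could look like. The name `ToReactantAdaptor` and the input-prototype/`force_compile_backward` interface follow the example above; everything else (the field layout, the cached-thunk wrapper) is hypothetical and not the actual implementation:

```julia
# Hypothetical sketch of the adaptor type. `FC` mirrors the type
# parameter used in the example (`ToReactantAdaptor{true}`).
struct ToReactantAdaptor{FC, I}
    input_prototype::I
    force_compile_backward::Bool
end

function ToReactantAdaptor{FC}(input_prototype;
        force_compile_backward::Bool=false) where {FC}
    return ToReactantAdaptor{FC, typeof(input_prototype)}(
        input_prototype, force_compile_backward)
end

# Calling the adaptor on a model would trace `model(x, ps, st)` with the
# stored prototype, compile it via Reactant/XLA once, and return a wrapper
# that dispatches to the cached compiled function on subsequent calls.
```

The key design point is that compilation happens eagerly at conversion time (the prototype fixes the input shapes), so the wrapped model pays the XLA compile cost once rather than on first call.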

Upstream Needs

  • [x] https://github.com/EnzymeAD/Reactant.jl/pull/7
  • [ ] https://github.com/EnzymeAD/Reactant.jl/pull/5 (for the backward pass)
  • [x] (Nice to have) https://github.com/EnzymeAD/Reactant.jl/issues/8 for extending the support to `LuxCore.apply`

TODOs

  • [x] Compile Forward Pass
  • [x] Compile Inference Pass
  • [ ] Compile VJP (using Enzyme)
    • [ ] Support the standard AD backends as well via ChainRules
    • [ ] Add Enzyme Rules to directly call the compiled function
  • [ ] Compile JVP (using Enzyme)
    • [ ] Support ForwardDiff
    • [ ] Add Enzyme ForwardDiff Rules
  • [ ] `__make_reactant_array`
    • [ ] CuArrays
    • [ ] https://github.com/EnzymeAD/Reactant.jl/issues/16
  • [ ] ComponentArrays Special Handling
  • [ ] Add documentation (make sure that users know that this is experimental)
  • [ ] Add a tutorial similar to the SimpleChains one
    • [ ] Full Training Pipeline using Training API
    • [ ] Partial compilation for NN using XLA and remaining via LLVM (like Neural ODEs)
  • [ ] Add Reactant to our benchmark suite
  • [ ] Add tests
    • [ ] CPU
    • [ ] CUDA
  • [x] (Nice to have) Extend to all models, not just stateless ones.
  • [ ] Compile the training loop
    • [x] Fallback implementation for existing backends
    • [ ] Reactant Backend
    • [ ] Add to documentation
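For the "Compile the training loop" item, a fallback on the existing Training API could look roughly like the sketch below. This is a hypothetical illustration: it assumes the current `Lux.Training` interface (`TrainState`, `single_train_step!`, `MSELoss`), whose exact names and signatures may differ across Lux versions, and a Reactant backend would slot in by specializing the train-step for compiled models:

```julia
using Lux, Optimisers, Random

model = Chain(Dense(10, 5, tanh), Dense(5, 2))
ps, st = Lux.setup(Random.default_rng(), model)

# Fallback path: a plain training loop over the Training API. Once a
# Reactant backend exists, the same loop would run XLA-compiled steps
# by swapping the AD backend / train state, with no user-facing changes.
tstate = Training.TrainState(model, ps, st, Adam(0.001f0))

x = randn(Float32, 10, 3)
y = randn(Float32, 2, 3)

for _ in 1:10
    _, loss, _, tstate = Training.single_train_step!(
        AutoZygote(), MSELoss(), (x, y), tstate)
end
```

Keeping the loop identical between the fallback and the Reactant backend is what makes the "Fallback implementation for existing backends" checkbox useful: users write one loop and opt into compilation per backend.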

avik-pal · May 27 '24 03:05