
Add more `Duplicated` methods for Enzyme.jl support

Status: Open · opened by mcabbott · 0 comments

This adds a method like `gradient(f, ::Duplicated)` which, like `train!(loss, model::Duplicated, data, opt)` from https://github.com/FluxML/Flux.jl/pull/2446, uses the `Duplicated` type to signal that you want to use Enzyme, not Zygote. It returns the gradient (for compatibility?) and also mutates the `Duplicated` object.
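To make the dispatch idea concrete, here is a minimal, self-contained sketch in plain Julia (no Flux, Zygote, or Enzyme required). Every name in it is hypothetical. `Dup` stands in for Enzyme's `Duplicated`: a primal value plus a shadow buffer that receives the gradient. `fd_gradient` is a finite-difference stand-in for a real AD backend. Both methods use it because the point is dispatch and mutation, not differentiation. The module owns its `gradient` function, so the method taking a `Dup` is an ordinary extension rather than piracy.

```julia
# Hypothetical sketch: none of these names are Flux, Zygote, or Enzyme API.
module DispatchSketch

# Stand-in for Enzyme's `Duplicated`: the primal value plus a shadow
# of the same shape that receives the gradient.
struct Dup{T}
    val::T
    dval::T
end

# Stand-in for a real AD backend: central finite differences of a
# scalar-valued function of a vector.
function fd_gradient(f, x::Vector{Float64}; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

# A function this module owns, so adding methods to it is not piracy.
# Default method: delegate to the "default backend" and return a 1-tuple,
# mimicking the shape of Zygote.gradient's result.
gradient(f, x) = (fd_gradient(f, x),)

# Passing a `Dup` selects the other path: write the gradient into the
# shadow (mutation) and also return it, for compatibility with the default.
function gradient(f, d::Dup)
    d.dval .= fd_gradient(f, d.val)
    return (d.dval,)
end

end # module
```

With `f(x) = sum(abs2, x)` and `d = DispatchSketch.Dup([1.0, 2.0], zeros(2))`, calling `DispatchSketch.gradient(f, d)` fills `d.dval` with approximately `[2.0, 4.0]` and returns that same array inside a 1-tuple.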

  • To avoid piracy, this creates a new function `Flux.gradient`, which by default calls `Zygote.gradient`. Unfortunately, that means every `using Flux, Zygote` now produces ambiguities, so it should probably not be exported. Dropping the export is breaking, which means 0.15.

  • There's also `withgradient`, but it does not yet allow `f` to return a tuple (the loss plus auxiliary outputs) the way Zygote's version does.

  • There's also a method of `update!`, which either needs to move to Optimisers.jl, or, again, Flux needs to own the function.

  • Finally, `@layer Chain` defines a one-argument `Duplicated(c::Chain)` method, so that you don't need to construct the dual by hand.
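The last two points (an `update!` that consumes the shadow, and a one-argument constructor that allocates it) can be sketched the same way. This is again plain Julia with made-up names: `ToyChain`, `zero_like`, `Dup`, and this `update!` are illustrations, not Flux or Optimisers.jl API. `Dup(c::ToyChain)` allocates a zeroed shadow with the same structure as the model, in the spirit of the `Duplicated(c::Chain)` method. `update!` then reads the gradient out of that shadow. For brevity it takes a bare learning rate `eta`, whereas Optimisers.jl's `update!` takes an optimiser state tree, the model, and the gradient.

```julia
# Hypothetical sketch: not Flux, Optimisers.jl, or Enzyme API.
module ConstructorSketch

# Stand-in for Enzyme's `Duplicated`.
struct Dup{T}
    val::T
    dval::T
end

# A toy "chain": just a list of weight vectors (purely illustrative).
struct ToyChain
    layers::Vector{Vector{Float64}}
end

# Zeroed copy with the same structure, used to hold the gradient.
zero_like(c::ToyChain) = ToyChain([zero(w) for w in c.layers])

# One-argument constructor: the shadow is allocated for you,
# so the caller never builds the dual by hand.
Dup(c::ToyChain) = Dup(c, zero_like(c))

# An `update!` owned by this module: plain gradient descent with step
# `eta`, reading gradients from the shadow and mutating the model in place.
function update!(d::Dup{ToyChain}, eta::Real)
    for (w, g) in zip(d.val.layers, d.dval.layers)
        w .-= eta .* g
    end
    return d
end

end # module
```

Because `Dup(m)` stores `m` itself as the primal, `update!` changes the caller's model directly. This matches the mutate-in-place style the PR describes for the `Duplicated` object.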

WIP, RFC?

Needs tests and docs.

PR Checklist

  • [ ] Tests are added
  • [ ] Entry in NEWS.md
  • [ ] Documentation, if applicable

mcabbott, Jul 25 '24 03:07