Flux.jl
Add more `Duplicated` methods for Enzyme.jl support
This adds a method like `gradient(f, ::Duplicated)` which, like `train!(loss, model::Duplicated, data, opt)` from https://github.com/FluxML/Flux.jl/pull/2446, uses the `Duplicated` type to signal that you want to use Enzyme, not Zygote. It returns the gradient (for compatibility?) and mutates the `Duplicated` object.
- To avoid piracy, this creates a new function `Flux.gradient` which by default calls `Zygote.gradient`. Unfortunately that's going to mean every `using Flux, Zygote` now produces ambiguities... so probably it should not be exported? Which means 0.15.
- There's also `withgradient`, but it doesn't allow you to return a tuple the way Zygote does, not yet.
- There's also a method of `update!` which either needs to move to Optimisers.jl, or again we need to let Flux own the function.
- Finally, `@layer Chain` defines a 1-argument `Duplicated(c::Chain)` method, so that you don't need to construct the dual by hand.
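
A rough sketch of how the two paths are meant to look side by side, assuming a Flux build that includes this PR and with Enzyme.jl loaded. The model, the input `x`, the `loss` closure and the wrapper function `demo` are made up for illustration; only `Flux.gradient`, the `Duplicated` dispatch and the 1-argument `Duplicated(::Chain)` constructor come from the description above.

```julia
using Flux, Enzyme

# Hypothetical helper, only here to keep the example out of global scope.
function demo()
    model = Chain(Dense(2 => 3, tanh), Dense(3 => 1))
    x = randn(Float32, 2, 5)
    loss(m) = sum(abs2, m(x))

    # Plain model: Flux.gradient falls through to Zygote.gradient.
    g_zygote = Flux.gradient(loss, model)

    # Wrapped model: the Duplicated type selects Enzyme.
    # The 1-argument constructor (defined via @layer in this PR) allocates
    # the shadow, so the dual does not have to be built by hand.
    dup = Duplicated(model)
    g_enzyme = Flux.gradient(loss, dup)

    # The call above also mutates `dup`: the shadow copy in `dup.dval`
    # now holds the accumulated gradient.
    return g_zygote, g_enzyme, dup
end

g_zygote, g_enzyme, dup = demo()
```

Both calls return a tuple with one entry per differentiated argument. Since `Flux.gradient` is a new function that Flux owns, the Enzyme method can be added without extending `Zygote.gradient` on types that belong to neither package.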
WIP, RFC?
Needs tests, and docs.
PR Checklist
- [ ] Tests are added
- [ ] Entry in NEWS.md
- [ ] Documentation, if applicable