ForwardDiff.jl
Automatic ChainRules compatibility
In a recent Slack discussion, @mohamed82008 posted this useful code snippet that shouldn't go to waste.
With a little bit of work, this could be turned into a macro that automatically translates a ChainRulesCore.frule into its ForwardDiff.Dual-compatible counterpart.
Since ChainRulesCore is a very light dependency, would it make sense to include such a thing in ForwardDiff? Judging by the reactions on the Slack #autodiff channel, lots of people would find it useful.
using ChainRulesCore, ForwardDiff
macro ForwardDiff_frule(f)
    quote
        function $(esc(f))(x::Vector{<:ForwardDiff.Dual{T}}) where {T}
            # Split the duals into primal values and a matrix of partials (one row per input).
            xv, Δx = ForwardDiff.value.(x), reduce(vcat, transpose.(ForwardDiff.partials.(x)))
            out, Δf = ChainRulesCore.frule((NoTangent(), Δx), $(esc(f)), xv)
            if out isa Real
                return ForwardDiff.Dual{T}(out, ForwardDiff.Partials(Tuple(Δf)))
            elseif out isa Vector
                return ForwardDiff.Dual{T}.(out, ForwardDiff.Partials.(Tuple.(eachrow(Δf))))
            else
                error("Unsupported output.")
            end
        end
    end
end
f1(x) = sum(x)
function ChainRulesCore.frule((_, Δx), ::typeof(f1), x::AbstractVector{<:Number})
    println("frule was used")
    return f1(x), sum(Δx, dims = 1)
end
f2(x) = x
function ChainRulesCore.frule((_, Δx), ::typeof(f2), x::AbstractVector{<:Number})
    println("frule was used")
    return f2(x), Δx
end
@ForwardDiff_frule f1
ForwardDiff.gradient(f1, rand(3))
# frule was used
# 3-element Vector{Float64}:
# 1.0
# 1.0
# 1.0
@ForwardDiff_frule f2
ForwardDiff.jacobian(f2, rand(3))
# frule was used
# 3×3 Matrix{Float64}:
# 1.0 0.0 0.0
# 0.0 1.0 0.0
# 0.0 0.0 1.0
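For scalar functions, the same unpack/frule/repack idea can be sketched by hand, without the macro. This is only a minimal illustration, not part of the original snippet: the function g and its frule are hypothetical, and unlike the macro above (which passes the whole partials matrix to the frule in one call), this sketch calls the frule once per partial direction.

```julia
using ChainRulesCore, ForwardDiff

# Hypothetical scalar function with a hand-written forward rule.
g(x) = sin(x)

function ChainRulesCore.frule((_, Δx), ::typeof(g), x::Real)
    return sin(x), cos(x) * Δx
end

# Dual method: unpack value and partials, push each partial through the
# frule, and repack the result as a Dual with the same tag T.
function g(x::ForwardDiff.Dual{T}) where {T}
    xv = ForwardDiff.value(x)
    ps = Tuple(ForwardDiff.partials(x))
    y = g(xv)
    Δy = map(p -> last(ChainRulesCore.frule((NoTangent(), p), g, xv)), ps)
    return ForwardDiff.Dual{T}(y, ForwardDiff.Partials(Δy))
end

ForwardDiff.derivative(g, 0.0)
# 1.0
```

Because the frule is linear in Δx, calling it per partial gives the same result as one batched call; the batched form in the macro additionally requires the frule to accept a matrix for Δx.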
+1
I think this kind of ForwardDiff interoperability with ChainRulesCore frules, and even opt-in access to the many frules defined in ChainRules.jl, would be a widely appreciated update if at all possible.
Is avoiding dependency on ChainRulesCore/ChainRules the issue preventing this?
For future reference, @mohamed82008 is on the move again: https://github.com/JuliaNonconvex/NonconvexUtils.jl/pull/6
One way to make this fully automatic would be to move the most basic definitions of Dual, Partials etc. from here to ChainRulesCore. Then @scalar_rule f ... could automatically add methods for f(::Dual), in addition to its existing methods rrule(f, x) and frule, etc.
This may not be very difficult to do, although I haven't tried.
I don't know how other people feel about entangling these two packages. This one has slowly acquired quite a few dependencies (including CRC, indirectly); adding CRC directly may in fact make it lighter-weight, as those packages could (and probably already all do) define rules themselves using CRC.
I am not keen on that particular solution. In general I want @scalar_rule to be less powerful, not more.
The solution from @mohamed82008 is a working macro now, and it is as easy as typing:
import NonconvexUtils
NonconvexUtils.@ForwardDiff_frule f1(x1::ForwardDiff.Dual, x2::ForwardDiff.Dual)
NonconvexUtils.@ForwardDiff_frule f1(x1::AbstractVector{<:ForwardDiff.Dual}, x2::AbstractVector{<:ForwardDiff.Dual})
NonconvexUtils.@ForwardDiff_frule f1(x1::AbstractMatrix{<:ForwardDiff.Dual}, x2::AbstractMatrix{<:ForwardDiff.Dual})
So you have ForwardDiff dispatches for scalars, vectors and matrices for your function f1, based on an existing frule; see the source code.
In my opinion this should definitely be added to ForwardDiff; it would be a really nice feature.
My implementation even works for structs if you pass in the constructor, see the tests. Struct support needs more infrastructure though compared to the simple vec/reshape needed for vector/matrix support. Even if the macro does not live in ForwardDiff, I would be happy if someone took it and put it in a separate light package and added a section in the ForwardDiff documentation pointing to the new package.
I mentioned to @gdalle before, every feature of NonconvexUtils should probably be its own package :)
I can separate the function and put it in a dedicated package if you wish.
Are there any suggestions for names?
It's basically an interface between ForwardDiff and ChainRulesCore. That is what I was searching for on Google: how can I make ForwardDiff use frules? Or simply ForwardDiffFRule.jl?
ForwardDiffChainRules?
Ok, mini-package is coming. I will post it here!
Find the repo here: ForwardDiffChainRules.jl
I tried to keep it lightweight, e.g. the SparseArrays and JuMP dispatches are only added if the corresponding libraries are loaded (via Requires.jl). The rest is basically copy-paste (with author attribution). The CI test currently fails (I will check this soon, or maybe @mohamed82008 has a clue). As soon as this is fixed and @mohamed82008 gives his OK, I can register the version.
Regards!
I can take a look in a few days. Sorry, a bit busy the coming couple of days.
@ThummeTo can I get an invite to the repo? I fixed it locally and used https://github.com/JuliaNonconvex/DifferentiableFlatten.jl. I can open a PR or push directly to master.
Of course, you should have an invitation now. Much appreciated!
So it's ready for a first release I guess @mohamed82008 ?
Yes
Since we have a sufficiently lightweight package that implements this feature now, should we close this issue @gdalle?
Can we document this in the docs here (and also in ChainRulesCore.jl)?
Let's close it once this is referenced in the FD and CRC docs?
Both links were added, closing this