ForwardDiff.jl

Automatic ChainRules compatibility

Open gdalle opened this issue 2 years ago • 4 comments

In a recent Slack discussion, @mohamed82008 posted this useful code snippet that shouldn't go to waste. With a little bit of work, this could be turned into a macro that automatically translates a ChainRulesCore.frule into its ForwardDiff.Dual-compatible counterpart. Since ChainRulesCore is a very light dependency, would it make sense to include such a thing in ForwardDiff? Judging by the reactions on the Slack #autodiff channel, lots of people would find it useful.

using ChainRulesCore, ForwardDiff

macro ForwardDiff_frule(f)
    quote
        # Overload `f` for vectors of Duals by forwarding to its ChainRulesCore.frule.
        function $(esc(f))(x::Vector{<:ForwardDiff.Dual{T}}) where {T}
            # Split the Duals into primal values and a matrix whose row i holds the partials of x[i].
            xv = ForwardDiff.value.(x)
            Δx = reduce(vcat, transpose.(ForwardDiff.partials.(x)))
            out, Δf = ChainRulesCore.frule((NoTangent(), Δx), $(esc(f)), xv)
            # Repackage the primal output and its tangent as Dual numbers.
            if out isa Real
                return ForwardDiff.Dual{T}(out, ForwardDiff.Partials(Tuple(Δf)))
            elseif out isa Vector
                return ForwardDiff.Dual{T}.(out, ForwardDiff.Partials.(Tuple.(eachrow(Δf))))
            else
                error("Unsupported output.")
            end
        end
    end
end

f1(x) = sum(x)
function ChainRulesCore.frule((_, Δx), ::typeof(f1), x::AbstractVector{<:Number})
  println("frule was used")
  return f1(x), sum(Δx, dims = 1)
end

f2(x) = x
function ChainRulesCore.frule((_, Δx), ::typeof(f2), x::AbstractVector{<:Number})
  println("frule was used")
  return f2(x), Δx
end

@ForwardDiff_frule f1
ForwardDiff.gradient(f1, rand(3))
# frule was used
# 3-element Vector{Float64}:
#  1.0
#  1.0
#  1.0

@ForwardDiff_frule f2
ForwardDiff.jacobian(f2, rand(3))
# frule was used
# 3×3 Matrix{Float64}:
#  1.0  0.0  0.0
#  0.0  1.0  0.0
#  0.0  0.0  1.0
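
To make the data layout in the macro concrete, here is a minimal sketch that builds two Dual numbers by hand and unpacks them the same way. The tag type `DemoTag` is hypothetical and exists only for this illustration; in a real `ForwardDiff.gradient` call the tag and the partials are set up by ForwardDiff itself.

```julia
using ForwardDiff

# Hypothetical tag type, used only to construct Duals by hand here.
struct DemoTag end

# Two inputs, each carrying two partials (one per seeded direction).
x = [ForwardDiff.Dual{DemoTag}(1.0, 1.0, 0.0),
     ForwardDiff.Dual{DemoTag}(2.0, 0.0, 1.0)]

xv = ForwardDiff.value.(x)                               # primal values: [1.0, 2.0]
Δx = reduce(vcat, transpose.(ForwardDiff.partials.(x)))  # row i = partials of x[i]

# Tangent that the frule for f1 = sum would return: one entry per direction.
Δf = sum(Δx, dims = 1)

# Repackage the scalar output and its tangent as a single Dual.
y = ForwardDiff.Dual{DemoTag}(sum(xv), ForwardDiff.Partials(Tuple(Δf)))
```

Each row of `Δx` belongs to one input and each column to one seeded direction, which is why the frule for `f1` sums over `dims = 1`.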

gdalle avatar Mar 06 '22 21:03 gdalle

+1 I think this kind of ForwardDiff inter-operability with ChainRulesCore frules and even opt-in access to the many frules defined in ChainRules.jl would be a widely-appreciated update if at all possible.

Is avoiding a dependency on ChainRulesCore/ChainRules the issue preventing this?

doddgray avatar Mar 07 '22 01:03 doddgray

For future reference, @mohamed82008 is on the move again: https://github.com/JuliaNonconvex/NonconvexUtils.jl/pull/6

gdalle avatar Mar 09 '22 10:03 gdalle

One way to make this fully automatic would be to move the most basic definitions of Dual, Partials etc. from here to ChainRulesCore. Then @scalar_rule f ... could automatically add methods for f(::Dual), in addition to its existing methods rrule(f, x) and frule, etc.

This may not be very difficult to do, although I haven't tried.

I don't know how other people feel about entangling these two packages. This one has slowly acquired quite a few dependencies (including CRC, indirectly); adding CRC directly may in fact make it lighter-weight, as those packages could (and probably already all do) define rules themselves using CRC.
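
For a scalar function, the kind of `f(::Dual)` method that such a rule could generate might look roughly like the hand-written sketch below. This is an illustration only, not what `@scalar_rule` currently emits; the function `cube` and its frule are hypothetical examples.

```julia
using ChainRulesCore, ForwardDiff

cube(x) = x^3  # hypothetical example function

# Forward rule: returns the primal output and the pushed-forward tangent.
function ChainRulesCore.frule((_, Δx), ::typeof(cube), x::Real)
    return cube(x), 3x^2 * Δx
end

# Hand-written Dual method that delegates to the frule.
function cube(x::ForwardDiff.Dual{T}) where {T}
    xv = ForwardDiff.value(x)
    # The frule is linear in the tangent, so a unit tangent yields the derivative.
    y, dydx = ChainRulesCore.frule((NoTangent(), one(xv)), cube, xv)
    # Scale every partial of the input by the derivative (chain rule).
    return ForwardDiff.Dual{T}(y, dydx * ForwardDiff.partials(x))
end

d = ForwardDiff.derivative(cube, 2.0)  # 12.0
```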

mcabbott avatar Jun 28 '22 19:06 mcabbott

I am not keen on that particular solution. In general, I want @scalar_rule to be less powerful, not more.

oxinabox avatar Aug 09 '22 18:08 oxinabox

The solution from @mohamed82008 is a working macro now, and it is as easy as typing:

import ForwardDiff, NonconvexUtils

NonconvexUtils.@ForwardDiff_frule f1(x1::ForwardDiff.Dual, x2::ForwardDiff.Dual)
NonconvexUtils.@ForwardDiff_frule f1(x1::AbstractVector{<:ForwardDiff.Dual}, x2::AbstractVector{<:ForwardDiff.Dual})
NonconvexUtils.@ForwardDiff_frule f1(x1::AbstractMatrix{<:ForwardDiff.Dual}, x2::AbstractMatrix{<:ForwardDiff.Dual})

This gives you ForwardDiff dispatches for scalars, vectors, and matrices for your function f1, based on an existing frule (see the source code).

In my opinion, this should definitely be added to ForwardDiff; it would be a really nice feature.
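
As a rough idea of what a two-argument scalar dispatch has to do, here is a hand-written sketch. It is not the macro's actual expansion: the function `h` and its frule are hypothetical, and the method assumes both arguments are Duals with the same tag, value type, and number of partials.

```julia
using ChainRulesCore, ForwardDiff

h(x1, x2) = x1 * x2  # hypothetical example function

function ChainRulesCore.frule((_, Δx1, Δx2), ::typeof(h), x1::Real, x2::Real)
    return h(x1, x2), Δx1 * x2 + x1 * Δx2
end

# Assumption: both arguments share the tag T, value type V, and N partials.
function h(x1::ForwardDiff.Dual{T,V,N}, x2::ForwardDiff.Dual{T,V,N}) where {T,V,N}
    v1, v2 = ForwardDiff.value(x1), ForwardDiff.value(x2)
    p1, p2 = ForwardDiff.partials(x1), ForwardDiff.partials(x2)
    # Call the frule once per seeded direction and keep only the output tangent.
    dirs = ntuple(i -> ChainRulesCore.frule((NoTangent(), p1[i], p2[i]), h, v1, v2)[2], N)
    return ForwardDiff.Dual{T}(h(v1, v2), ForwardDiff.Partials(dirs))
end

grad = ForwardDiff.gradient(v -> h(v[1], v[2]), [2.0, 3.0])  # [3.0, 2.0]
```

Mixed calls such as `h(::Dual, ::Real)` are not covered by this sketch and would fall back to the generic method.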

ThummeTo avatar Oct 07 '22 09:10 ThummeTo

My implementation even works for structs if you pass in the constructor, see the tests. Struct support needs more infrastructure though compared to the simple vec/reshape needed for vector/matrix support. Even if the macro does not live in ForwardDiff, I would be happy if someone took it and put it in a separate light package and added a section in the ForwardDiff documentation pointing to the new package.

mohamed82008 avatar Oct 07 '22 17:10 mohamed82008

I mentioned to @gdalle before, every feature of NonconvexUtils should probably be its own package :)

mohamed82008 avatar Oct 07 '22 17:10 mohamed82008

I can separate the function and put it in a dedicated package if you wish.

Are there any suggestions for names? It's basically an interface between ForwardDiff and ChainRulesCore. This is what I was searching for on Google: "How can I make ForwardDiff use frules?" Or simply ForwardDiffFRule.jl?

ThummeTo avatar Oct 10 '22 06:10 ThummeTo

ForwardDiffChainRules?

mohamed82008 avatar Oct 10 '22 08:10 mohamed82008

Ok, mini-package is coming. I will post it here!

ThummeTo avatar Oct 10 '22 08:10 ThummeTo

Find the repo here: ForwardDiffChainRules.jl

I tried to keep it lightweight, e.g. the SparseArrays and JuMP dispatches are only added if the corresponding libraries are loaded (via Requires.jl). The rest is basically copy-paste (with author attribution). The CI test currently fails (I will check this soon, or maybe @mohamed82008 has a clue). As soon as this is fixed and @mohamed82008 gives his OK, I can register the version.

Regards!

ThummeTo avatar Oct 10 '22 13:10 ThummeTo

I can take a look in a few days. Sorry, a bit busy the coming couple of days.

mohamed82008 avatar Oct 10 '22 16:10 mohamed82008

@ThummeTo can I get an invite to the repo? I fixed it locally and used https://github.com/JuliaNonconvex/DifferentiableFlatten.jl. I can open a PR or push directly to master.

mohamed82008 avatar Oct 17 '22 16:10 mohamed82008

Of course, you should have an invitation now. Much appreciated!

ThummeTo avatar Oct 18 '22 06:10 ThummeTo

So it's ready for a first release I guess @mohamed82008 ?

ThummeTo avatar Oct 18 '22 07:10 ThummeTo

Yes

mohamed82008 avatar Oct 18 '22 07:10 mohamed82008

Since we have a sufficiently lightweight package that implements this feature now, should we close this issue @gdalle?

mohamed82008 avatar Oct 18 '22 07:10 mohamed82008

Can we document this in the docs here (and also in ChainRulesCore.jl)?

oxinabox avatar Oct 18 '22 11:10 oxinabox

> Since we have a sufficiently lightweight package that implements this feature now, should we close this issue @gdalle?

Let's close it once this is referenced in the FD and CRC docs?

gdalle avatar Oct 18 '22 17:10 gdalle

Both links were added, so I'm closing this.

gdalle avatar Nov 17 '22 20:11 gdalle