ChainRulesCore.jl

Make `ProjectTo` convert `Tangent` back to `Diagonal`, etc., when safe

Open · mcabbott opened this pull request 4 years ago · 9 comments

The example from https://github.com/JuliaDiff/ChainRulesCore.jl/issues/441#issuecomment-902941440 is `x -> sqrt(Diagonal(x))`, whose implementation is `sqrt.(x.diag)`. At present, Zygote returns a "structural" tangent for this, i.e. a `NamedTuple`. When `.diag` is the very first operation performed, this `NamedTuple` is simply returned to the user; but if it occurs after other operations, then their gradients will tend not to understand it.

This PR proposes that there should be a method `ProjectTo{Diagonal}(::Tangent)` which converts this back to the "natural" form, i.e. to another `Diagonal`. To try it out:

```julia
# ] dev ChainRulesCore
using Zygote, ChainRulesCore, LinearAlgebra

gradient(x -> sum(sqrt.(x.diag)), Diagonal([1,2,3]))  # returns a NamedTuple containing a vector
gradient(x -> sum(sqrt.((cbrt.(x)).diag)), Diagonal([1,2,3]))  # gradient of cbrt.(x) fails

# Instead of hacking the gradient for getproperty, for now insert this:
function ChainRulesCore.rrule(::typeof(identity), x::AbstractArray)
    project = ProjectTo(x)  # build the projector once, in the forward pass
    identity_pullback(dx) = (NoTangent(), project(dx))
    return x, identity_pullback
end

gradient(x -> sum(sqrt.(identity(x).diag)), Diagonal([1,2,3]))[1]  # returns a Diagonal
gradient(x -> sum(sqrt.((identity(cbrt.(x))).diag)), Diagonal([1,2,3]))[1]  # works!
# Note these print T = Any, from (::Tangent{T}) where T = (@show T; ...)
```
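As a rough illustration of the conversion proposed here, without depending on ChainRulesCore or Zygote, here is a sketch in plain Julia. It assumes the structural tangent looks like the `NamedTuple` `(diag = v,)` that Zygote returns above; `to_natural` is a hypothetical helper, not ChainRulesCore API, whereas the real proposal is a `ProjectTo{Diagonal}` method acting on a `Tangent`.

```julia
using LinearAlgebra

# Hypothetical helper (not ChainRulesCore API): turn the structural tangent
# Zygote produces for a Diagonal, e.g. (diag = [0.5, 0.0, 0.0],), back into
# the "natural" form, i.e. another Diagonal.
to_natural(::Type{<:Diagonal}, dx::NamedTuple{(:diag,)}) = Diagonal(dx.diag)

# Anything else is passed through unchanged.
to_natural(::Type, dx) = dx

x = Diagonal([1.0, 2.0, 3.0])
structural = (diag = [0.5, 0.0, 0.0],)
natural = to_natural(typeof(x), structural)

# Unlike the NamedTuple, the Diagonal can flow through further broadcasting,
# which here keeps the Diagonal structure:
doubled = natural .* 2
```

The point of converting early is the last line: broadcasting over a `NamedTuple` is not supported in Julia, while a `Diagonal` behaves like any other matrix.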

There ~~are~~ were two immediate hurdles here. ~~One is that, to work with LinearAlgebra's `sqrt(::Diagonal)` method, you would have to insert a projection step into Zygote's `literal_getproperty` adjoint definition. It's not immediately obvious to me how to do that.~~ Done in https://github.com/FluxML/Zygote.jl/pull/1104.

~~The second is that Zygote makes a `Tangent{Any}` here. I thought that `dx::Tangent{T}` was supposed to always have `typeof(x) == T` exactly. If that's not true, then must we worry about getting a `Tangent` which doesn't come from a `Diagonal` at all?~~ See below; I think this needs https://github.com/FluxML/Zygote.jl/pull/1057.

I added a similar line for `UpperTriangular`. At present `x::UpperTriangular` accepts `dx::Diagonal` as a "natural" gradient. If it must accept `dx::Tangent{Diagonal}` too, we could write a method for that, but I don't know how generally that can work, and I'm not sure it's worth trying.

Beyond the immediate, this is still an easy case of the problems discussed in #441 (not 411, sorry!), in that the `Tangent` contains a `Vector` and can trivially be re-wrapped to form a `Diagonal`. (Before finding this `Any`, I was trying to make dispatch restrict it to this case.) It doesn't address what should happen if the content of the `Tangent{Diagonal}` is some other weird structural thing, which cannot itself be `ProjectTo`'d to an `AbstractVector`; that's what we don't have concrete examples of yet.
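A sketch of the dispatch restriction mentioned above, again in plain Julia with a hypothetical `rewrap` helper (not ChainRulesCore API): re-wrap only when the `diag` field already holds an `AbstractVector`, and otherwise leave the structural tangent alone, since there are no concrete examples of what the natural form should then be.

```julia
using LinearAlgebra

# Hypothetical helper (not ChainRulesCore API): re-wrap a structural tangent
# only in the "easy case", where its `diag` field is already an AbstractVector.
rewrap(::Type{<:Diagonal}, dx::NamedTuple{(:diag,), <:Tuple{AbstractVector}}) =
    Diagonal(dx.diag)

# Any other content (e.g. a nested NamedTuple) is returned unchanged,
# because it is unclear what the natural form should be.
rewrap(::Type, dx) = dx

T = Diagonal{Float64, Vector{Float64}}
easy = rewrap(T, (diag = [1.0, 2.0],))
weird = rewrap(T, (diag = (parent = [1.0, 2.0],),))
```

Here `easy` comes back as a `Diagonal`, while `weird` is passed through as the original nested `NamedTuple`.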

Edit -- this particular example is now handled explicitly by https://github.com/JuliaDiff/ChainRules.jl/pull/509.

Edit' -- the `Tangent{Any}` isn't an inference failure; it's explicitly constructed that way because the input `NamedTuple` doesn't carry the primal type, here: https://github.com/FluxML/Zygote.jl/blob/05d0c2ae04f334a2ec61e42decfe1172d0f2e6e8/src/compiler/chainrules.jl#L126-L129

mcabbott, Aug 22 '21

Codecov Report

Merging #446 (2cd6f89) into main (a9840b4) will decrease coverage by 2.87%. The diff coverage is 82.69%.


```diff
@@            Coverage Diff             @@
##             main     #446      +/-   ##
==========================================
- Coverage   93.03%   90.16%   -2.88%     
==========================================
  Files          15       15              
  Lines         862      925      +63     
==========================================
+ Hits          802      834      +32     
- Misses         60       91      +31     
```
| Impacted Files | Coverage Δ |
|---|---|
| src/tangent_types/abstract_zero.jl | 84.61% <79.41%> (-10.84%) :arrow_down: |
| src/projection.jl | 89.41% <84.28%> (-7.89%) :arrow_down: |

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data.

codecov-commenter, Aug 22 '21

[Edited!]

The simplest effect is like so, although right now this only works if the projection accepts `Tangent{Any}`:

```julia
julia> Zygote.gradient(x -> parent(x)[1], Diagonal([1,2,3]))[1]
3×3 Diagonal{Float64, Vector{Float64}}:
 1.0   ⋅    ⋅ 
  ⋅   0.0   ⋅ 
  ⋅    ⋅   0.0
```

Zygote is applying the projection at the last backward step above. But the point of this is really to apply it at the first backward step. ~~I think this can be done as follows:~~ This is now done by https://github.com/FluxML/Zygote.jl/pull/1104.

With that, the gradient can propagate through for instance further broadcasting steps:

```julia
julia> gradient(x -> sum(sqrt.((cbrt.(x)).diag)), Diagonal([1,2,3]))[1]
3×3 Diagonal{Float64, Vector{Float64}}:
 0.166667   ⋅          ⋅ 
  ⋅        0.0935385   ⋅ 
  ⋅         ⋅         0.0667187
```

Without this PR, Zygote would return a `NamedTuple` here, which the gradient of broadcasting cannot handle.

In fact the generic `(::ProjectTo{T})(dx::Tangent{<:T}) where {T} = dx` is not wide enough right now: projection gives an error on `Tangent{Any}`:

```julia
julia> pullback(x -> parent(x)[1], Diagonal([1,2,3]))[2](1.0)  # no projection
((diag = [1.0, 0.0, 0.0],),)

julia> gradient(x -> parent(x)[1], Diagonal([1,2,3]))[1]  # with projection
ERROR: MethodError: no method matching (::ChainRulesCore.ProjectTo{Diagonal, NamedTuple{(:diag,), Tuple{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}}})(::ChainRulesCore.Tangent{Any, NamedTuple{(:diag,), Tuple{Vector{Float64}}}})
```

That error is a problem with or without this PR. ~~Maybe it should be widened?~~ This should also be fixed by https://github.com/FluxML/Zygote.jl/pull/1104, which uses `zygote2differential` to produce `Tangent{T}` instead of `Tangent{Any}`.
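Why `Tangent{Any}` falls through can be seen with a toy model of the types involved. `MyTangent` and `MyProject` below are hypothetical stand-ins, not the real `Tangent` and `ProjectTo`; the call method mirrors the generic `(::ProjectTo{T})(dx::Tangent{<:T}) where {T} = dx` quoted above. Since `Any` is not a subtype of `Diagonal`, a tangent tagged with `Any` matches no method, which is the `MethodError` shown.

```julia
using LinearAlgebra

# Hypothetical toy stand-in for Tangent{P, T}: P records the primal type.
struct MyTangent{P, T}
    backing::T
end
MyTangent{P}(backing::T) where {P, T} = MyTangent{P, T}(backing)

# Hypothetical toy stand-in for ProjectTo{P}.
struct MyProject{P} end

# Mirrors the generic pass-through rule: only tangents tagged with a subtype of P.
(::MyProject{P})(dx::MyTangent{<:P}) where {P} = dx

proj = MyProject{Diagonal}()
typed = MyTangent{Diagonal{Float64, Vector{Float64}}}((diag = [1.0, 0.0, 0.0],))
untyped = MyTangent{Any}((diag = [1.0, 0.0, 0.0],))

proj(typed)                # fine: Diagonal{Float64, Vector{Float64}} <: Diagonal
applicable(proj, untyped)  # false: Any is not <: Diagonal
```

So either the projection has to be widened to accept the `Any` tag, or Zygote has to produce a correctly tagged tangent, which is the route taken by the Zygote PR linked above.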

mcabbott, Oct 16 '21

I am going to leave this to @willtebbutt to review, and unsubscribe. Ping me if I am needed.

oxinabox, Oct 18 '21

Spotted today: an example in the wild of what this PR wants to do:

https://github.com/SciML/DiffEqFlux.jl/blob/61cc51e1b63709d55655c53d50244cc3932bd60e/src/DiffEqFlux.jl#L71-L73

That's for `Tridiagonal`. Notice that it has to allocate two zero `Vector`s on every call, which seems a bit unfortunate, and is probably an argument for allowing the gradient of a `Tridiagonal` to be a `Diagonal` or `Bidiagonal`, i.e. to live in a subspace.
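A small sketch of the allocation point, assuming a gradient that only touches the main diagonal of a `Tridiagonal` input (the numbers are made up): projecting back onto `Tridiagonal` itself needs two zero off-diagonal vectors, whereas a `Diagonal` gradient would need none. Whether a `Diagonal` should be accepted as the gradient of a `Tridiagonal` is exactly the design question, so this is illustrative only.

```julia
using LinearAlgebra

# Assumed example: only the main diagonal `d` received a gradient;
# the off-diagonals `dl` and `du` received nothing.
d = [1.0, 2.0, 3.0]

# Staying within Tridiagonal forces two zero vectors to be allocated:
as_tridiagonal = Tridiagonal(zeros(2), d, zeros(2))

# If a Diagonal were an acceptable gradient for a Tridiagonal input,
# the same information would need no extra allocation:
as_diagonal = Diagonal(d)
```

Both represent the same matrix; only the storage differs.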

mcabbott, Nov 03 '21

OK, I think this is done. Plays well with Zygote 0.6.30:

```julia
julia> gradient(x -> sqrt(sum((x .^ 2).ev)), Bidiagonal([1,2,3], [4,5], :U))[1]
3×3 Bidiagonal{Float64, Vector{Float64}}:
 0.0  0.624695   ⋅ 
  ⋅   0.0       0.780869
  ⋅    ⋅        0.0
```
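Roughly what has to happen for the example above: the structural tangent only has an `ev` field, so the missing main diagonal `dv` has to be filled with zeros to rebuild a `Bidiagonal`. A sketch under that assumption; `bidiag_from_structural` is a hypothetical name, not the PR's code, and the input values are chosen to reproduce the output shown.

```julia
using LinearAlgebra

# Hypothetical helper (not the PR's implementation): rebuild a natural
# Bidiagonal tangent from a structural tangent that only has `ev`.
function bidiag_from_structural(x::Bidiagonal, dx::NamedTuple{(:ev,)})
    dv = zeros(eltype(dx.ev), size(x, 1))  # main diagonal received no gradient
    uplo = x.uplo == 'U' ? :U : :L         # keep the same triangle as the primal
    return Bidiagonal(dv, dx.ev, uplo)
end

x = Bidiagonal([1.0, 2.0, 3.0], [4.0, 5.0], :U)
# Gradient of sqrt(sum(abs2, ev)) with respect to ev is ev / norm(ev):
dx = (ev = x.ev ./ sqrt(sum(abs2, x.ev)),)
g = bidiag_from_structural(x, dx)
```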

This was wrong before (the diagonal of a `UnitUpperTriangular` is fixed at one, so its gradient should be zero):

```julia
julia> gradient(x -> sum(abs, x), UnitUpperTriangular(rand(3,3)))[1]
3×3 UpperTriangular{Float64, Matrix{Float64}}:
 0.0  1.0  1.0
  ⋅   0.0  1.0
  ⋅    ⋅   0.0
```
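The fix amounts to something like the following sketch: for a `UnitUpperTriangular` input the diagonal is not a free parameter, so its gradient entries must be zero, as in the output above. `project_unit_upper` is a hypothetical name and this is not the PR's implementation.

```julia
using LinearAlgebra

# Hypothetical projection for a UnitUpperTriangular input: keep only the
# strictly upper triangle, so the fixed unit diagonal gets zero gradient.
function project_unit_upper(dx::AbstractMatrix)
    up = triu(dx, 1)  # zero out the diagonal and everything below it
    return UpperTriangular(up)
end

dx = ones(3, 3)  # assumed dense gradient, e.g. of sum(abs, x) with positive entries
g = project_unit_upper(dx)
```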

Two bugs I give up on for now:

```julia
julia> UpperTriangular(Fill(3,3,3)) - I  # not my fault. Fixed in Julia 1.8
ERROR: ArgumentError: Cannot setindex! to 2 for an AbstractFill with value 3.

julia> pullback(x -> sqrt(sum((x .^ 2).dv)), Bidiagonal([1,2,3], [4,5], :U))[2](1)[1]  # My code makes a Diagonal; I have no idea where the Tridiagonal comes from
3×3 Tridiagonal{Float64, Vector{Float64}}:
 0.267261  0.0        ⋅ 
 0.0       0.534522  0.0
  ⋅        0.0       0.801784
```

mcabbott, Nov 08 '21

Bump.

mcabbott, Jan 03 '22

Sorry this is taking so long to review. It is just slightly too big (conceptually) for me to review in the time I normally set aside for reviews, so I keep starting to review it and not finishing.

oxinabox, Mar 02 '22

Bump?

This is one of the last pieces of the ProjectTo story worked out last summer, but somehow hasn't made it. We'd have heard if its absence was holding anyone back, so it can't be that important, but it does clarify what the design is, a bit.

mcabbott, Jun 07 '22

Sorry, this has been on my to-do list for a long time. It's not forgotten; it's just big and needs some thinking.

oxinabox, Jun 17 '22