ChainRulesCore.jl
Make `ProjectTo` convert `Tangent` back to `Diagonal`, etc, when safe
The example from https://github.com/JuliaDiff/ChainRulesCore.jl/issues/441#issuecomment-902941440 is x -> sqrt(Diagonal(x)), whose implementation is sqrt.(x.diag). At present, Zygote returns a "structural" tangent for this, i.e. a NamedTuple. When the .diag is the very first operation being performed, this NamedTuple is simply returned, but if it occurs after other operations, then their gradient rules will tend not to understand it.
This PR proposes that there should be a method ProjectTo{Diagonal}(::Tangent) which converts this back to the "natural" form, i.e. to another Diagonal. To try it out:
# ] dev ChainRulesCore
using Zygote, ChainRulesCore, LinearAlgebra
gradient(x -> sum(sqrt.(x.diag)), Diagonal([1,2,3])) # returns a NamedTuple containing a vector.
gradient(x -> sum(sqrt.((cbrt.(x)).diag)), Diagonal([1,2,3])) # gradient of cbrt.(x) fails
# Instead of hacking the gradient for getproperty, for now insert this:
function ChainRulesCore.rrule(::typeof(identity), x::AbstractArray)
    x, dx -> (NoTangent(), ProjectTo(x)(dx))
end
gradient(x -> sum(sqrt.(identity(x).diag)), Diagonal([1,2,3]))[1] # returns a Diagonal
gradient(x -> sum(sqrt.((identity(cbrt.(x))).diag)), Diagonal([1,2,3]))[1] # works!
# Note these print T = Any, from (::Tangent{T}) where T = (@show T; ...)
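The re-wrapping idea itself can be pictured without any AD machinery. The following is only a sketch using LinearAlgebra: the NamedTuple stands in for the structural tangent, and `rewrap_diagonal` is a hypothetical name, not a function from ChainRulesCore or this PR.

```julia
using LinearAlgebra

# Hypothetical helper (not part of ChainRulesCore): re-wrap a structural
# tangent, modelled here as a NamedTuple with a single `diag` field.
rewrap_diagonal(dx::NamedTuple{(:diag,)}) = Diagonal(dx.diag)

structural = (diag = [0.5, 0.25, 0.125],)   # field-wise ("structural") form
natural = rewrap_diagonal(structural)       # same numbers, as a Diagonal

# The natural form is an AbstractMatrix, so ordinary array code accepts it.
doubled = 2 .* natural
```

The point of the natural form is that downstream rules written for arrays can consume it, whereas they have no reason to know about the field layout of a particular struct.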
There ~~are~~ were two immediate hurdles here. ~~One is that, to work with Base's sqrt(::Diagonal) method, you would have to insert a projection step into Zygote's literal_getproperty adjoint definition. It's not immediately obvious to me how to do that.~~ Done in https://github.com/FluxML/Zygote.jl/pull/1104
~~The second is that Zygote makes a Tangent{Any} here. I thought that dx::Tangent{T} was supposed to always have typeof(x) == T exactly. If that's not true, then must we worry about getting a Tangent which doesn't come from a Diagonal at all?~~ See below, I guess this wants https://github.com/FluxML/Zygote.jl/pull/1057
I added a similar line for UpperTriangular. At present x::UpperTriangular accepts dx::Diagonal as a "natural" gradient. If it must accept dx::Tangent{Diagonal} too, well we could write a method for that. But how generally that can work I don't know. I'm not sure it's worth trying.
Beyond the immediate, this is still an easy case of the problems discussed in #441 (not 411, sorry!), in that the Tangent contains a Vector and can trivially be re-wrapped to form a Diagonal. (Before finding this Any, I was trying to make dispatch restrict it to this case.) It doesn't address what should happen if the content of the Tangent{Diagonal} is some other weird structural thing, which cannot itself be ProjectTo'd to an AbstractVector -- that's what we don't have concrete examples of yet.
Edit -- this particular example is now handled explicitly by https://github.com/JuliaDiff/ChainRules.jl/pull/509.
Edit' -- the Tangent{Any} isn't an inference failure, it's explicitly constructed that way because the input NamedTuple doesn't have the type, here:
https://github.com/FluxML/Zygote.jl/blob/05d0c2ae04f334a2ec61e42decfe1172d0f2e6e8/src/compiler/chainrules.jl#L126-L129
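To illustrate why a bare NamedTuple cannot say which primal type it came from, here is a small sketch. `OtherWrapper` and `fields_as_namedtuple` are hypothetical names invented for this example; they are not Zygote or ChainRulesCore code.

```julia
using LinearAlgebra

# Hypothetical type that happens to have a single field called `diag`.
struct OtherWrapper
    diag::Vector{Float64}
end

# Collect an object's fields into a NamedTuple keyed by field name.
function fields_as_namedtuple(x)
    names = fieldnames(typeof(x))
    return NamedTuple{names}(map(n -> getfield(x, n), names))
end

a = fields_as_namedtuple(Diagonal([1.0, 2.0]))
b = fields_as_namedtuple(OtherWrapper([1.0, 2.0]))

# Both are NamedTuple{(:diag,), Tuple{Vector{Float64}}}: the primal type is gone.
same_type = typeof(a) == typeof(b)
```

Once the fields are in a NamedTuple, nothing distinguishes the two sources, which is consistent with the type parameter having to be supplied separately.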
Codecov Report
Merging #446 (2cd6f89) into main (a9840b4) will decrease coverage by 2.87%. The diff coverage is 82.69%.
@@ Coverage Diff @@
## main #446 +/- ##
==========================================
- Coverage 93.03% 90.16% -2.88%
==========================================
Files 15 15
Lines 862 925 +63
==========================================
+ Hits 802 834 +32
- Misses 60 91 +31
| Impacted Files | Coverage Δ | |
|---|---|---|
| src/tangent_types/abstract_zero.jl | 84.61% <79.41%> (-10.84%) | :arrow_down: |
| src/projection.jl | 89.41% <84.28%> (-7.89%) | :arrow_down: |
[Edited!]
The simplest effect is like so, although right now this works only if it accepts Tangent{Any}:
julia> Zygote.gradient(x -> parent(x)[1], Diagonal([1,2,3]))[1]
3×3 Diagonal{Float64, Vector{Float64}}:
1.0 ⋅ ⋅
⋅ 0.0 ⋅
⋅ ⋅ 0.0
Zygote is applying the projection at the last backward step above. But the point of this is really to apply it at the first step. ~~I think this can be done as follows:~~ This is done by https://github.com/FluxML/Zygote.jl/pull/1104 now.
With that, the gradient can propagate through for instance further broadcasting steps:
julia> gradient(x -> sum(sqrt.((cbrt.(x)).diag)), Diagonal([1,2,3]))[1]
3×3 Diagonal{Float64, Vector{Float64}}:
0.166667 ⋅ ⋅
⋅ 0.0935385 ⋅
⋅ ⋅ 0.0667187
Without this PR, Zygote would like to return a NamedTuple, which cannot be handled by the gradient of broadcasting.
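As a sanity check on the numbers in that output, the composite function is sqrt(cbrt(x)) = x^(1/6) on each diagonal entry, so the expected derivative is x^(-5/6) / 6. A short sketch that evaluates the closed form directly, without any AD package:

```julia
using LinearAlgebra

d = [1.0, 2.0, 3.0]

# sqrt(cbrt(x)) == x^(1/6), so d/dx is x^(-5/6) / 6, applied elementwise.
analytic = @. d^(-5/6) / 6

# Wrapped as a Diagonal, to compare with the gradient printed above.
expected = Diagonal(analytic)
```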
In fact the generic (::ProjectTo{T})(dx::Tangent{<:T}) where {T} = dx is not wide enough right now -- projection gives an error on Tangent{Any}:
julia> pullback(x -> parent(x)[1], Diagonal([1,2,3]))[2](1.0) # no projection
((diag = [1.0, 0.0, 0.0],),)
julia> gradient(x -> parent(x)[1], Diagonal([1,2,3]))[1] # with projection
ERROR: MethodError: no method matching (::ChainRulesCore.ProjectTo{Diagonal, NamedTuple{(:diag,), Tuple{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}}})(::ChainRulesCore.Tangent{Any, NamedTuple{(:diag,), Tuple{Vector{Float64}}}})
That error is a problem with or without this PR. ~~Maybe it should be widened?~~ This should also be fixed by https://github.com/FluxML/Zygote.jl/pull/1104, which uses zygote2differential to produce Tangent{T} instead of Tangent{Any}.
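The dispatch behaviour behind that MethodError can be reproduced with a stand-in type. In this sketch `FakeTangent` and `accepts` are hypothetical names used only for illustration; the signature mirrors the `Tangent{<:T}` pattern quoted above, which matches only when the type parameter is a subtype of T.

```julia
using LinearAlgebra

# Hypothetical stand-in for Tangent{T}: only the type parameter matters here.
struct FakeTangent{T} end

# Mirrors the shape of a method restricted to Tangent{<:Diagonal}.
accepts(::FakeTangent{<:Diagonal}) = true

typed   = FakeTangent{Diagonal{Float64, Vector{Float64}}}()
untyped = FakeTangent{Any}()

typed_ok   = applicable(accepts, typed)     # parameter is a subtype of Diagonal
untyped_ok = applicable(accepts, untyped)   # Any is not a subtype of Diagonal
```

So a tangent labelled with `Any` either needs a wider method or, as in the linked Zygote PR, needs to be constructed with the right type parameter in the first place.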
I am going to leave this to @willtebbutt to review, and unsubscribe. Ping me if I am needed.
Spotted today, this is an example in the wild of what this PR wants to do:
https://github.com/SciML/DiffEqFlux.jl/blob/61cc51e1b63709d55655c53d50244cc3932bd60e/src/DiffEqFlux.jl#L71-L73
That's for Tridiagonal. Notice that it always has to make two zero Vectors on each call, which seems a bit unfortunate. And is probably an argument for allowing these to make Diagonal, Bidiagonal, subspaces.
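A rough sketch of the allocation being described, assuming the gradient really only has a diagonal part (this is not the DiffEqFlux code, just the same shape of problem):

```julia
using LinearAlgebra

d = [1.0, 2.0, 3.0]

# Expressing a diagonal-only gradient as a Tridiagonal needs two zero vectors.
as_tridiagonal = Tridiagonal(zeros(2), d, zeros(2))

# Expressing it as a Diagonal just wraps `d`.
as_diagonal = Diagonal(d)

same_matrix = Matrix(as_tridiagonal) == Matrix(as_diagonal)
```

Both represent the same matrix, so accepting the smaller structured type as a valid gradient for the larger one would avoid the extra zeros.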
OK I think this is done. Plays well with Zygote 6.30:
julia> gradient(x -> sqrt(sum((x .^ 2).ev)), Bidiagonal([1,2,3], [4,5], :U))[1]
3×3 Bidiagonal{Float64, Vector{Float64}}:
0.0 0.624695 ⋅
⋅ 0.0 0.780869
⋅ ⋅ 0.0
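The values in that output can be checked by hand: sqrt(sum(ev .^ 2)) is the Euclidean norm of the superdiagonal, whose gradient is ev / norm(ev), and the function does not depend on the main diagonal. A sketch of that check using only LinearAlgebra:

```julia
using LinearAlgebra

dv = [1.0, 2.0, 3.0]   # main diagonal, unused by the function
ev = [4.0, 5.0]        # superdiagonal

# Gradient of norm(ev) with respect to ev, and zero with respect to dv.
expected = Bidiagonal(zero(dv), ev ./ norm(ev), :U)
```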
This was wrong before:
julia> gradient(x -> sum(abs, x), UnitUpperTriangular(rand(3,3)))[1]
3×3 UpperTriangular{Float64, Matrix{Float64}}:
0.0 1.0 1.0
⋅ 0.0 1.0
⋅ ⋅ 0.0
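The zero diagonal in that output is what one would expect, since a UnitUpperTriangular always reads ones on its diagonal regardless of the parent array. A small sketch of that reasoning (the matrix values are arbitrary):

```julia
using LinearAlgebra

A = [2.0 3.0 4.0; 5.0 6.0 7.0; 8.0 9.0 10.0]
U = UnitUpperTriangular(A)

# The diagonal is fixed at one, whatever the parent holds there.
unit_diag = all(U[i, i] == 1 for i in 1:3)

# Perturbing the parent's diagonal therefore leaves sum(abs, U) unchanged,
# so the gradient with respect to those entries is zero.
unchanged = sum(abs, UnitUpperTriangular(A + 0.1I)) == sum(abs, U)

# Only the strictly upper entries are free; for positive entries each
# contributes a derivative of one.
expected = UpperTriangular(triu(ones(3, 3), 1))
```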
Two bugs I give up on for now:
julia> UpperTriangular(Fill(3,3,3)) - I # not my fault. Fixed in Julia 1.8
ERROR: ArgumentError: Cannot setindex! to 2 for an AbstractFill with value 3.
julia> pullback(x -> sqrt(sum((x .^ 2).dv)), Bidiagonal([1,2,3], [4,5], :U))[2](1)[1] # My code makes a Diagonal, I have no idea where the Tri comes from
3×3 Tridiagonal{Float64, Vector{Float64}}:
0.267261 0.0 ⋅
0.0 0.534522 0.0
⋅ 0.0 0.801784
Bump.
Sorry this is taking so long to review. It is just slightly too big (conceptually) for me to review during the time I normally have set aside to review things, so I keep starting to review it and not completing it.
Bump?
This is one of the last pieces of the ProjectTo story worked out last summer, but somehow hasn't made it. We'd have heard if its absence was holding anyone back, so it can't be that important, but it does clarify what the design is, a bit.
Sorry, this has been on my to-do list for a long time. It's not forgotten; it's just big and needs some thinking.