`f/rrules` should support receiving `ZeroTangent()`
Similarly to https://github.com/JuliaDiff/ChainRules.jl/issues/408, ZeroTangent() is a valid input to the pullback, and we need to make sure it is supported.
Using https://github.com/JuliaDiff/ChainRulesTestUtils.jl/pull/176, there are at least three kinds of errors:
- Pullbacks are written such that they do not support taking in `ZeroTangent()`, e.g. `MethodError: no method matching (::ChainRules.var"#transpose_pullback#1894")(::ZeroTangent)`. These just need to be fixed in ChainRules.jl.
- Places where we (I think?) have to project the `ZeroTangent()`: `TypeError: in Hermitian, in S, expected S<:(AbstractArray{var"#s832", 2} where var"#s832"<:T), got Type{ZeroTangent}` and `` MethodError: Cannot `convert` an object of type ZeroTangent to an object of type Matrix{T} where T ``.
- Errors which could be solved by projecting the `ZeroTangent()`, e.g. `MethodError: no method matching getindex(::ZeroTangent, ::Int64)`. The question is whether we actually want to project to an array, since that would allocate quite a bit. Alternatively, we could define `Base.getindex(::ZeroTangent, args...) = ZeroTangent()`. There might be quite a few of these functions to define, but it would be much faster. Some examples are:
  - `MethodError: no method matching Complex(::ZeroTangent)`
  - `MethodError: no method matching mapfoldl(::typeof(identity), ::typeof(Base.add_sum), ::ZeroTangent; dims=Colon())`
  - `MethodError: no method matching tr(::ZeroTangent)`
  - `MethodError: no method matching mul!(::ZeroTangent, ::Matrix{Float64}, ::ZeroTangent, ::Bool, ::Bool)`
  - `MethodError: no method matching trsyl!(::Char, ::Char, ::Matrix{ComplexF64}, ::Matrix{ComplexF64}, ::ZeroTangent)`
  - `MethodError: no method matching size(::ZeroTangent, ::Int64)`
  - `MethodError: no method matching LowerTriangular(::ZeroTangent)`
  - and the list goes on
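A minimal sketch of the "define the methods instead of projecting" option from the third bullet. To keep it runnable without any packages, and to avoid redefining methods on the real type, it uses a hypothetical stand-in type `StandInZero` in place of `ZeroTangent`; the choice of `getindex`, `transpose` and `tr` is only illustrative.

```julia
using LinearAlgebra

# Hypothetical stand-in for ChainRulesCore's ZeroTangent (an assumption made
# so that the sketch needs no packages and does not touch the real type).
struct StandInZero end

# Each of these operations is linear, so it maps zero to zero.
# Returning the same singleton allocates nothing.
Base.getindex(z::StandInZero, args...) = z
Base.transpose(z::StandInZero) = z
LinearAlgebra.tr(z::StandInZero) = z

z = StandInZero()
z[1]          # StandInZero()
z[2, :]       # StandInZero()
transpose(z)  # StandInZero()
tr(z)         # StandInZero()
```

Compared with projecting to a dense array of zeros first, nothing is allocated here, at the cost of one method per function.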
This is just a quick (incomplete) dump of observations and first thoughts. I may have missed kinds of errors, or said things which are untrue.
Altogether:
Test Summary: | Pass Error Broken Total
ChainRules | 24330 429 4 24763
I think most of the ones under the third kind should be solved by implementing the missing methods for `ZeroTangent`, without consideration for projecting.
While there are many, they are finite in number,
and the answers are mostly obvious: in general the result is `ZeroTangent()`, because linear operators map zero to zero.
In some cases it isn't, but we would be able to get through a lot of them pretty quickly.
And it is the same set as for #408.
Should every rrule have a method for this, or should it be handled at a higher level, by not calling the rule at all? The former seems like a lot of boilerplate. Maybe the answer depends on this:
> In some cases it isn't
What cases are there where a zero tangent should become something nonzero?
The case we have to worry about is functions with multiple inputs, some of which are zero and some of which are not.
Although, as I say that, I realize it makes no sense: pullbacks always have one input, because Julia functions all have a single output (it is just that sometimes that output is an iterator). So yes, I think this can be handled at a higher level by not calling the pullback at all.
Yes, I guess if your function returns a tuple, then it may have to worry about getting back (Zero(), something). Maybe this one would work?
julia> y, b = pullback(findmax, rand(3));
julia> b((1,2))
([0.0, 1.0, 0.0],)
julia> b((pi,3))
([0.0, 3.141592653589793, 0.0],)
julia> b((pi,"nothing"))
ERROR: MethodError: no method matching +(::Int64, ::String)
[1] accum(x::Int64, y::String)
@ Zygote ~/.julia/packages/Zygote/0da6K/src/lib/lib.jl:17
[2] macro expansion
@ ~/.julia/packages/Zygote/0da6K/src/lib/lib.jl:27 [inlined]
[3] accum(x::NamedTuple{(:first, :second), Tuple{Int64, Irrational{:π}}}, y::NamedTuple{(:first, :second), Tuple{String, Float64}})
@ Zygote ~/.julia/packages/Zygote/0da6K/src/lib/lib.jl:27
[4] accum
@ ~/.julia/packages/Zygote/0da6K/src/lib/lib.jl:17 [inlined]
[5] (::typeof(∂(#250)))(Δ::Tuple{Irrational{:π}, String})
@ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface2.jl:0
[6] Pullback
@ ./reduce.jl:95 [inlined]
[7] (::typeof(∂(Base.MappingRF{Base.var"#250#251"{typeof(identity)}, Base.BottomRF{typeof(Base._rf_findmax)}}(Base.var"#250#251"{typeof(identity)}(identity), Base.BottomRF{typeof(Base._rf_findmax)}(Base._rf_findmax)))))(Δ::Tuple{Irrational{:π}, String})
@ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface2.jl:0
`(Zero(), something)` should be `Tangent{Tuple{Int, Float64}}(ZeroTangent(), something)`,
since `Tuple` is not a valid tangent type: it doesn't support `zero` or `+`.
Though a lot of our methods do let you pass in any iterator, when you really should only be allowed to pass in a `Tangent{Tuple}`
(or some other iterator that overloads `zero`, `+`, etc.).
What do you mean by "handled at a higher level"? Like handled automatically by the AD system, or making `rrule` a macro which adds a line for treating `ZeroTangent` automatically? Or is there another way?
Handled in the AD system, before calling the pullback. Like this line in Zygote: https://github.com/FluxML/Zygote.jl/blob/1082ebd3aced63b99c4b6c2956a122ce6a37f97d/src/compiler/chainrules.jl#L94. And this is where we would change Nabla: https://github.com/invenia/Nabla.jl/pull/189/files#r662148541
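For illustration, handling it before calling the pullback could look roughly like the sketch below. This is not the code from Zygote or Nabla: `StandInZero` is a hypothetical stand-in for `ZeroTangent`, and `guarded`, `n_tangents` and `transpose_pullback` are names invented for the example.

```julia
# Hypothetical stand-in for ChainRulesCore's ZeroTangent (an assumption, so
# the sketch runs without packages).
struct StandInZero end

# Hypothetical wrapper an AD system could put around a rule's pullback.
# If the incoming cotangent is zero, the pullback is never called and every
# returned tangent is zero; `n_tangents` is how many values the pullback returns.
function guarded(pullback, n_tangents::Int)
    return function (dy)
        dy isa StandInZero && return ntuple(_ -> StandInZero(), n_tangents)
        return pullback(dy)
    end
end

# A pullback written only for matrices; it has no method for StandInZero.
transpose_pullback(dy::AbstractMatrix) = (permutedims(dy),)

safe = guarded(transpose_pullback, 1)
safe(StandInZero())  # (StandInZero(),) without calling transpose_pullback
safe([1 2; 3 4])     # ([1 3; 2 4],)
```

With a guard like this, individual rules would not need a `ZeroTangent` method for the pullback input itself. It does not cover a structured cotangent in which only some fields are zero.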