ChainRules.jl

`f/rrules` should support receiving `ZeroTangent()`

Open mzgubic opened this issue 4 years ago • 7 comments

Similarly to https://github.com/JuliaDiff/ChainRules.jl/issues/408, ZeroTangent() is a valid input to the pullback, and we need to make sure it is supported.

Using https://github.com/JuliaDiff/ChainRulesTestUtils.jl/pull/176, there are at least three kinds of errors:

  1. Pullbacks are written such that they do not support taking in ZeroTangent(), e.g. MethodError: no method matching (::ChainRules.var"#transpose_pullback#1894")(::ZeroTangent). These just need to be fixed in ChainRules.jl.
  2. Places where we (I think?) have to project the ZeroTangent(), e.g. TypeError: in Hermitian, in S, expected S<:(AbstractArray{var"#s832", 2} where var"#s832"<:T), got Type{ZeroTangent} and MethodError: Cannot `convert` an object of type ZeroTangent to an object of type Matrix{T} where T.
  3. Errors which could be solved by projecting the ZeroTangent(), e.g. MethodError: no method matching getindex(::ZeroTangent, ::Int64). The question is whether we actually want to project to an array, since that would allocate quite a bit. Alternatively, we could define Base.getindex(::ZeroTangent, args...) = ZeroTangent(). There might be quite a few of these functions to define, but it would be much faster. Some examples are:
  • MethodError: no method matching Complex(::ZeroTangent)
  • MethodError: no method matching mapfoldl(::typeof(identity), ::typeof(Base.add_sum), ::ZeroTangent; dims=Colon())
  • MethodError: no method matching tr(::ZeroTangent)
  • MethodError: no method matching mul!(::ZeroTangent, ::Matrix{Float64}, ::ZeroTangent, ::Bool, ::Bool)
  • MethodError: no method matching trsyl!(::Char, ::Char, ::Matrix{ComplexF64}, ::Matrix{ComplexF64}, ::ZeroTangent)
  • MethodError: no method matching size(::ZeroTangent, ::Int64)
  • MethodError: no method matching LowerTriangular(::ZeroTangent)

and the list goes on.
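The first kind of error can be illustrated with a minimal sketch. Everything here is an assumption for illustration: `MyZero` is a hypothetical stand-in for `ZeroTangent` (so the snippet runs without ChainRulesCore), and `my_transpose_rule` is a simplified rule that is not the actual ChainRules.jl definition. It only shows how a pullback closure can be given an extra method for the zero case.

```julia
# Hypothetical stand-in for ChainRulesCore.ZeroTangent.
struct MyZero end

# Simplified rule for transpose: returns the primal result and a pullback.
# (A real rrule also returns a tangent for the function itself; omitted here.)
function my_transpose_rule(A::AbstractMatrix)
    # Usual case: the cotangent is a matrix.
    transpose_pullback(ȳ::AbstractMatrix) = (transpose(ȳ),)
    # Extra method: a zero cotangent passes straight through, no allocation.
    transpose_pullback(z::MyZero) = (z,)
    return transpose(A), transpose_pullback
end

y, pb = my_transpose_rule([1.0 2.0; 3.0 4.0])
pb(MyZero())              # does not throw a MethodError
pb([1.0 0.0; 0.0 1.0])    # ordinary matrix cotangent still works
```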

This is just a quick (incomplete) dump of observations and first thoughts. I may have missed kinds of errors, or said things which are untrue.

Altogether:

Test Summary: | Pass   Error  Broken  Total
ChainRules    | 24330  429    4       24763

mzgubic avatar Jun 11 '21 16:06 mzgubic

I think most of the ones under 3 should be solved by implementing those functions for ZeroTangent, without considering projection. There are many of them, but they are finite in number, and the answers are mostly obvious: in general the result is ZeroTangent(), because linear operators map zero to zero. In some cases it isn't, but we would be able to get through a lot of them pretty quickly. And it is the same set as for #408.
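A rough sketch of what such definitions could look like. `MyZero` is a hypothetical stand-in for `ZeroTangent`, used so the snippet runs without ChainRulesCore and does not add methods to a type it does not own; the choice of functions just mirrors a few of the errors listed above.

```julia
using LinearAlgebra

# Hypothetical stand-in for ChainRulesCore.ZeroTangent.
struct MyZero end

# These operations are linear, so zero maps to zero and nothing is allocated.
Base.getindex(z::MyZero, args...) = z
Base.sum(z::MyZero; dims=:) = z
LinearAlgebra.tr(z::MyZero) = z

MyZero()[1, 2]          # a zero, rather than a MethodError
tr(MyZero())            # likewise
sum(MyZero(); dims=1)   # likewise
```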

oxinabox avatar Jun 11 '21 16:06 oxinabox

Should every rrule have a method for this, or should it be handled at a higher level, by not calling the rule at all? The former seems like a lot of boilerplate. Maybe the answer depends on this:

> In some cases it isn't

What cases are there where a zero tangent should become something nonzero?

mcabbott avatar Jun 30 '21 15:06 mcabbott

The case we have to worry about is functions with multiple inputs, some of which are zero, and some of which are not.

Which, now that I say it, I realize makes no sense: pullbacks always have one input, because Julia functions all have a single output (it is just that sometimes that output is an iterator). So yes, I think this can be handled at a higher level by not calling the pullback at all.

oxinabox avatar Jun 30 '21 19:06 oxinabox

Yes, I guess if your function returns a tuple, then it may have to worry about getting back (Zero(), something). Maybe this one would work?

julia> y, b = pullback(findmax, rand(3));

julia> b((1,2))
([0.0, 1.0, 0.0],)

julia> b((pi,3))
([0.0, 3.141592653589793, 0.0],)

julia> b((pi,"nothing"))
ERROR: MethodError: no method matching +(::Int64, ::String)
  [1] accum(x::Int64, y::String)
    @ Zygote ~/.julia/packages/Zygote/0da6K/src/lib/lib.jl:17
  [2] macro expansion
    @ ~/.julia/packages/Zygote/0da6K/src/lib/lib.jl:27 [inlined]
  [3] accum(x::NamedTuple{(:first, :second), Tuple{Int64, Irrational{:π}}}, y::NamedTuple{(:first, :second), Tuple{String, Float64}})
    @ Zygote ~/.julia/packages/Zygote/0da6K/src/lib/lib.jl:27
  [4] accum
    @ ~/.julia/packages/Zygote/0da6K/src/lib/lib.jl:17 [inlined]
  [5] (::typeof(∂(#250)))(Δ::Tuple{Irrational{:π}, String})
    @ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface2.jl:0
  [6] Pullback
    @ ./reduce.jl:95 [inlined]
  [7] (::typeof(∂(Base.MappingRF{Base.var"#250#251"{typeof(identity)}, Base.BottomRF{typeof(Base._rf_findmax)}}(Base.var"#250#251"{typeof(identity)}(identity), Base.BottomRF{typeof(Base._rf_findmax)}(Base._rf_findmax)))))(Δ::Tuple{Irrational{:π}, String})
    @ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface2.jl:0

mcabbott avatar Jun 30 '21 19:06 mcabbott

(Zero(), something) should be Tangent{Tuple{Int, Float64}}(ZeroTangent(), something), since a plain Tuple is not a valid tangent type: it doesn't support zero or +. That said, a lot of our methods do let you pass in any iterator, when really you should only be allowed to pass in a Tangent{Tuple} (or some other iterator that overloads zero, +, etc.).
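To make the point about + concrete, here is a small sketch. `MyZero` and `TupleTangent` are hypothetical stand-ins for `ZeroTangent` and `Tangent{<:Tuple}`, not the ChainRulesCore types; the snippet only shows why a wrapper type that defines + and zero can act as a tangent where a bare Tuple cannot.

```julia
# Hypothetical stand-in for ZeroTangent: an additive identity.
struct MyZero end
Base.:+(z::MyZero, ::MyZero) = z
Base.:+(::MyZero, x) = x
Base.:+(x, ::MyZero) = x

# Hypothetical stand-in for Tangent{<:Tuple}: a tuple of per-element tangents.
struct TupleTangent{T<:Tuple}
    parts::T
end
# Tangents are added elementwise; a zero for the whole thing is just MyZero().
Base.:+(a::TupleTangent, b::TupleTangent) = TupleTangent(map(+, a.parts, b.parts))
Base.zero(::TupleTangent) = MyZero()

# A bare Tuple has no + method, so it cannot be accumulated.
plain_tuple_fails = try
    (1, 2) + (3, 4)
    false
catch err
    err isa MethodError
end

# The wrapper accumulates fine, including a zero in one slot.
total = TupleTangent((MyZero(), 2.0)) + TupleTangent((1, 3.0))
```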

oxinabox avatar Jun 30 '21 19:06 oxinabox

What do you mean by "handled at a higher level"? Like handled automatically by the AD system, or making rrule a macro which adds a line for treating ZeroTangent automatically? Or is there another way?

mzgubic avatar Jul 01 '21 09:07 mzgubic

Handled in the AD system before calling the pullback, like this line in Zygote: https://github.com/FluxML/Zygote.jl/blob/1082ebd3aced63b99c4b6c2956a122ce6a37f97d/src/compiler/chainrules.jl#L94. And this is where we would change Nabla: https://github.com/invenia/Nabla.jl/pull/189/files#r662148541
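As a minimal sketch of the idea (not Zygote's or Nabla's actual code): the AD system can wrap each pullback so that a zero cotangent never reaches it. `MyZero` is a hypothetical stand-in for `ZeroTangent`, and `skip_zero` and its `n_inputs` argument are names made up for this illustration.

```julia
# Hypothetical stand-in for ChainRulesCore.ZeroTangent.
struct MyZero end

# Hypothetical helper: wrap a pullback so it is not called for a zero cotangent.
# n_inputs is the number of gradients the wrapped pullback would return.
function skip_zero(pullback, n_inputs::Int)
    # Zero in, zeros out: the rule's own pullback is never invoked.
    wrapped(z::MyZero) = ntuple(_ -> z, n_inputs)
    # Anything else is forwarded unchanged.
    wrapped(ȳ) = pullback(ȳ)
    return wrapped
end

# A pullback that would throw on MyZero, since broadcasting over it is undefined.
double_pullback(ȳ) = (2 .* ȳ,)
safe_pullback = skip_zero(double_pullback, 1)

safe_pullback(MyZero())      # short-circuits
safe_pullback([1.0, 2.0])    # calls double_pullback as usual
```

With this approach individual rules need no zero-specific methods, at the cost of the AD system having to know how many gradients each pullback returns.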

oxinabox avatar Jul 01 '21 09:07 oxinabox