Allow for `unthunk` to return `nothing`
Aims to fix https://github.com/FluxML/Zygote.jl/issues/1567
In rules defined by `@adjoint`, there is always a second method `back(::Nothing) = nothing`, so that the method you write need not handle `nothing`. However, the way https://github.com/FluxML/ZygoteRules.jl/pull/17 added `unthunk` means that if `unthunk` returns `nothing`, this `back(::Nothing)` shortcut is never reached.
Instead, defining a separate method `back(Δ::AbstractThunk) = back(unthunk_tangent(Δ))` should avoid that. It assumes that `unthunk_tangent` eventually returns a non-thunk.
Cc @oschulz and @pxl-th, for work on https://github.com/FluxML/Zygote.jl/pull/966
Sounds sensible to me, but I'm not sure I can judge all the implications across Zygote's codebase.
Besides inference failures, the one failing test is this:
julia> gradient([2 3; 4 5]) do xs
sum([x ^ 2 + y for x in xs, y in xs])
end
([20.0 28.0; 36.0 44.0],)
julia> gradient([2 3; 4 5]) do xs
sum([x ^ i for (i, x) in enumerate(xs)])
end
([1.0 27.0; 8.0 500.0],)
julia> gradient([2 3; 4 5]) do xs
sum([x ^ i + y for (i, x) in enumerate(xs), y in xs])
end == ([8 112; 36 2004],)
ERROR: MethodError: Cannot `convert` an object of type Float64 to an object of type ChainRulesCore.ZeroTangent
The function `convert` exists, but no method is defined for this combination of argument types.
Closest candidates are:
convert(::Type{T}, ::T) where T
@ Base Base.jl:126
Stacktrace:
[1] cvt1
@ ./essentials.jl:612 [inlined]
[2] ntuple
@ ./ntuple.jl:49 [inlined]
[3] convert(::Type{Tuple{ChainRulesCore.ZeroTangent, Float64}}, x::Tuple{Float64, Float64})
@ Base ./essentials.jl:614
[4] setindex!
@ ./array.jl:994 [inlined]
[5] setindex!
@ ./multidimensional.jl:704 [inlined]
[6] macro expansion
@ ./reducedim.jl:289 [inlined]
[7] macro expansion
@ ./simdloop.jl:77 [inlined]
[8] _mapreducedim!(f::Zygote.StaticGetter{1}, op::typeof(Zygote.accum), R::Array{Tuple{…}, 4}, A::Array{Tuple{…}, 4})
@ Base ./reducedim.jl:287
[9] mapreducedim!
@ ./reducedim.jl:296 [inlined]
[10] _mapreduce_dim
@ ./reducedim.jl:340 [inlined]
[11] mapreduce
@ ./reducedim.jl:329 [inlined]
[12] #742
@ ~/.julia/dev/Zygote/src/lib/array.jl:287 [inlined]
[13] map
@ ./tuple.jl:406 [inlined]
[14] productfunc(xs::Tuple{Base.Iterators.Enumerate{Matrix{…}}, Matrix{Int64}}, dy::Array{Tuple{Tuple{…}, Float64}, 4})
@ Zygote ~/.julia/dev/Zygote/src/lib/array.jl:282
[15] product_pullback
@ ~/.julia/dev/Zygote/src/lib/array.jl:295 [inlined]
[16] #3284#back
@ ~/.julia/dev/ZygoteRules/src/adjoint.jl:73 [inlined]
[17] #17
@ ./REPL[8]:2 [inlined]
[18] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Int64)
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
[19] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Int64)
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:97
[20] gradient(f::Function, args::Matrix{Int64})
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:154
[21] top-level scope
@ REPL[8]:1
Some type information was truncated. Use `show(err)` to see complete types.
julia> show(err)
1-element ExceptionStack:
MethodError: Cannot `convert` an object of type Float64 to an object of type ChainRulesCore.ZeroTangent
The function `convert` exists, but no method is defined for this combination of argument types.
Closest candidates are:
convert(::Type{T}, ::T) where T
@ Base Base.jl:126
Stacktrace:
[1] cvt1
@ ./essentials.jl:612 [inlined]
[2] ntuple
@ ./ntuple.jl:49 [inlined]
[3] convert(::Type{Tuple{ChainRulesCore.ZeroTangent, Float64}}, x::Tuple{Float64, Float64})
@ Base ./essentials.jl:614
[4] setindex!
@ ./array.jl:994 [inlined]
[5] setindex!
@ ./multidimensional.jl:704 [inlined]
[6] macro expansion
@ ./reducedim.jl:289 [inlined]
[7] macro expansion
@ ./simdloop.jl:77 [inlined]
[8] _mapreducedim!(f::Zygote.StaticGetter{1}, op::typeof(Zygote.accum), R::Array{Tuple{ChainRulesCore.ZeroTangent, Float64}, 4}, A::Array{Tuple{Tuple{ChainRulesCore.Thunk{ChainRules.var"#382#416"{Float64, Int64, Int64, ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}, Int64}}, Float64}, Float64}, 4})
@ Base ./reducedim.jl:287
[9] mapreducedim!
@ ./reducedim.jl:296 [inlined]
[10] _mapreduce_dim
@ ./reducedim.jl:340 [inlined]
[11] mapreduce
@ ./reducedim.jl:329 [inlined]
[12] #742
@ ~/.julia/dev/Zygote/src/lib/array.jl:287 [inlined]
[13] map
@ ./tuple.jl:406 [inlined]
[14] productfunc(xs::Tuple{Base.Iterators.Enumerate{Matrix{Int64}}, Matrix{Int64}}, dy::Array{Tuple{Tuple{ChainRulesCore.Thunk{ChainRules.var"#382#416"{Float64, Int64, Int64, ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}, Int64}}, Float64}, Float64}, 4})
@ Zygote ~/.julia/dev/Zygote/src/lib/array.jl:282
[15] product_pullback
@ ~/.julia/dev/Zygote/src/lib/array.jl:295 [inlined]
[16] #3284#back
@ ~/.julia/dev/ZygoteRules/src/adjoint.jl:73 [inlined]
...
The offending code is here:
https://github.com/FluxML/Zygote.jl/blob/1b914d994aea236bcb6d3d0cd6c099d86cede101/src/lib/array.jl#L286-L287
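The problem is that `zero(::Thunk)` is a `ZeroTangent`. As a result, the reduction's output array gets element type `Tuple{ZeroTangent, Float64}` (see `R` in frame [8] above), and the accumulated `Tuple{Float64, Float64}` cannot be converted into it: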
julia> using ChainRulesCore
julia> @thunk 1+1
Thunk(var"#21#22"())
julia> zero(ans)
ZeroTangent()
However, it's not clear to me why this PR exposes that problem.