ZygoteRules.jl

Allow for `unthunk` to return `nothing`

Open mcabbott opened this issue 1 year ago • 2 comments

Aims to fix https://github.com/FluxML/Zygote.jl/issues/1567

Rules defined by `@adjoint` always get a second method, `back(::Nothing) = nothing`, so that the pullback you write need not handle `nothing`. However, https://github.com/FluxML/ZygoteRules.jl/pull/17 added `unthunk` in such a way that when unthunking returns `nothing`, this shortcut is not taken.

Defining a separate method `back(Δ::AbstractThunk) = back(unthunk_tangent(Δ))` instead should avoid that, because the unthunked value is dispatched again and a `nothing` then hits the shortcut. This assumes that `unthunk_tangent` must (eventually) give us a non-thunk.
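A minimal sketch of this dispatch pattern, using toy stand-ins rather than the real ChainRulesCore types. `ToyAbstractThunk`, `ToyThunk`, `toy_unthunk` and `back_inline` are hypothetical names for illustration only; `toy_unthunk` plays the role of `unthunk_tangent`, and the code that `@adjoint` actually generates may look different.

```julia
# Toy stand-ins (hypothetical); the real thunk types live in ChainRulesCore.
abstract type ToyAbstractThunk end

struct ToyThunk{F} <: ToyAbstractThunk
    f::F
end

# Stand-in for `unthunk_tangent`: run the delayed computation, pass anything else through.
toy_unthunk(t::ToyThunk) = t.f()
toy_unthunk(x) = x

# Proposed pattern: the user-written body, the `nothing` shortcut,
# and a separate thunk method that re-dispatches after unthunking.
back(Δ) = (2 * Δ,)                                   # body that does not handle `nothing`
back(::Nothing) = nothing                            # the shortcut
back(Δ::ToyAbstractThunk) = back(toy_unthunk(Δ))     # re-dispatch on the unthunked value

# For contrast: unthunking inside the generic method skips the shortcut.
back_inline(Δ) = (2 * toy_unthunk(Δ),)
back_inline(::Nothing) = nothing

back(ToyThunk(() -> 3))                     # (6,)
back(ToyThunk(() -> nothing))               # nothing, via the shortcut
back(ToyThunk(() -> ToyThunk(() -> 3)))     # (6,), nested thunks unwrap one level per call
```

In this toy version, `back_inline(ToyThunk(() -> nothing))` throws a `MethodError` from `2 * nothing`, which is the kind of failure the separate method is meant to avoid.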

Cc @oschulz and @pxl-th, for work on https://github.com/FluxML/Zygote.jl/pull/966

mcabbott, Apr 09 '25 03:04

Sounds sensible to me, but I'm not sure I can judge all implications across Zygote's code.

oschulz, Apr 09 '25 09:04

Besides inference failures, the one failing test is this:

julia> gradient([2 3; 4 5]) do xs
           sum([x ^ 2 + y for x in xs, y in xs])
       end
([20.0 28.0; 36.0 44.0],)

julia> gradient([2 3; 4 5]) do xs
           sum([x ^ i for (i, x) in enumerate(xs)])
       end
([1.0 27.0; 8.0 500.0],)

julia> gradient([2 3; 4 5]) do xs
           sum([x ^ i + y for (i, x) in enumerate(xs), y in xs])
       end == ([8 112; 36 2004],)
ERROR: MethodError: Cannot `convert` an object of type Float64 to an object of type ChainRulesCore.ZeroTangent
The function `convert` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  convert(::Type{T}, ::T) where T
   @ Base Base.jl:126

Stacktrace:
  [1] cvt1
    @ ./essentials.jl:612 [inlined]
  [2] ntuple
    @ ./ntuple.jl:49 [inlined]
  [3] convert(::Type{Tuple{ChainRulesCore.ZeroTangent, Float64}}, x::Tuple{Float64, Float64})
    @ Base ./essentials.jl:614
  [4] setindex!
    @ ./array.jl:994 [inlined]
  [5] setindex!
    @ ./multidimensional.jl:704 [inlined]
  [6] macro expansion
    @ ./reducedim.jl:289 [inlined]
  [7] macro expansion
    @ ./simdloop.jl:77 [inlined]
  [8] _mapreducedim!(f::Zygote.StaticGetter{1}, op::typeof(Zygote.accum), R::Array{Tuple{…}, 4}, A::Array{Tuple{…}, 4})
    @ Base ./reducedim.jl:287
  [9] mapreducedim!
    @ ./reducedim.jl:296 [inlined]
 [10] _mapreduce_dim
    @ ./reducedim.jl:340 [inlined]
 [11] mapreduce
    @ ./reducedim.jl:329 [inlined]
 [12] #742
    @ ~/.julia/dev/Zygote/src/lib/array.jl:287 [inlined]
 [13] map
    @ ./tuple.jl:406 [inlined]
 [14] productfunc(xs::Tuple{Base.Iterators.Enumerate{Matrix{…}}, Matrix{Int64}}, dy::Array{Tuple{Tuple{…}, Float64}, 4})
    @ Zygote ~/.julia/dev/Zygote/src/lib/array.jl:282
 [15] product_pullback
    @ ~/.julia/dev/Zygote/src/lib/array.jl:295 [inlined]
 [16] #3284#back
    @ ~/.julia/dev/ZygoteRules/src/adjoint.jl:73 [inlined]
 [17] #17
    @ ./REPL[8]:2 [inlined]
 [18] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Int64)
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [19] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Int64)
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:97
 [20] gradient(f::Function, args::Matrix{Int64})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:154
 [21] top-level scope
    @ REPL[8]:1
Some type information was truncated. Use `show(err)` to see complete types.

julia> show(err)
1-element ExceptionStack:
MethodError: Cannot `convert` an object of type Float64 to an object of type ChainRulesCore.ZeroTangent
The function `convert` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  convert(::Type{T}, ::T) where T
   @ Base Base.jl:126

Stacktrace:
  [1] cvt1
    @ ./essentials.jl:612 [inlined]
  [2] ntuple
    @ ./ntuple.jl:49 [inlined]
  [3] convert(::Type{Tuple{ChainRulesCore.ZeroTangent, Float64}}, x::Tuple{Float64, Float64})
    @ Base ./essentials.jl:614
  [4] setindex!
    @ ./array.jl:994 [inlined]
  [5] setindex!
    @ ./multidimensional.jl:704 [inlined]
  [6] macro expansion
    @ ./reducedim.jl:289 [inlined]
  [7] macro expansion
    @ ./simdloop.jl:77 [inlined]
  [8] _mapreducedim!(f::Zygote.StaticGetter{1}, op::typeof(Zygote.accum), R::Array{Tuple{ChainRulesCore.ZeroTangent, Float64}, 4}, A::Array{Tuple{Tuple{ChainRulesCore.Thunk{ChainRules.var"#382#416"{Float64, Int64, Int64, ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}, Int64}}, Float64}, Float64}, 4})
    @ Base ./reducedim.jl:287
  [9] mapreducedim!
    @ ./reducedim.jl:296 [inlined]
 [10] _mapreduce_dim
    @ ./reducedim.jl:340 [inlined]
 [11] mapreduce
    @ ./reducedim.jl:329 [inlined]
 [12] #742
    @ ~/.julia/dev/Zygote/src/lib/array.jl:287 [inlined]
 [13] map
    @ ./tuple.jl:406 [inlined]
 [14] productfunc(xs::Tuple{Base.Iterators.Enumerate{Matrix{Int64}}, Matrix{Int64}}, dy::Array{Tuple{Tuple{ChainRulesCore.Thunk{ChainRules.var"#382#416"{Float64, Int64, Int64, ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}, Int64}}, Float64}, Float64}, 4})
    @ Zygote ~/.julia/dev/Zygote/src/lib/array.jl:282
 [15] product_pullback
    @ ~/.julia/dev/Zygote/src/lib/array.jl:295 [inlined]
 [16] #3284#back
    @ ~/.julia/dev/ZygoteRules/src/adjoint.jl:73 [inlined]
...

The offending code is here:

https://github.com/FluxML/Zygote.jl/blob/1b914d994aea236bcb6d3d0cd6c099d86cede101/src/lib/array.jl#L286-L287

And the problem is that zero(::Thunk) isa ZeroTangent:

julia> using ChainRulesCore

julia> @thunk 1+1
Thunk(var"#21#22"())

julia> zero(ans)
ZeroTangent()

although it's not clear to me why this PR exposes that.
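A rough sketch of the failure mechanism, going only by what the trace shows: the reduction buffer `R` has element type `Tuple{ZeroTangent, Float64}`, and storing a `Tuple{Float64, Float64}` into it needs a `convert` that does not exist. `ToyZero` is a hypothetical stand-in for `ChainRulesCore.ZeroTangent`, and the way the buffer is seeded here is an assumption, not a copy of the Zygote code.

```julia
# Hypothetical stand-in for ChainRulesCore.ZeroTangent.
struct ToyZero end

# A buffer whose element type is fixed by a "zero" seed,
# like R::Array{Tuple{ZeroTangent, Float64}, 4} in the trace (assumed seeding).
R = fill((ToyZero(), 0.0), 2)     # Vector{Tuple{ToyZero, Float64}}

# Once the thunk has been evaluated, the accumulated value is a plain tuple of floats.
# Writing it into R calls convert(Tuple{ToyZero, Float64}, (1.0, 2.0)), which fails.
err = try
    R[1] = (1.0, 2.0)
    nothing
catch e
    e
end

err isa MethodError   # true
```

If the seed had the element type of the unthunked value instead, the store would go through; whether that is the right fix for Zygote is not something this sketch settles.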

mcabbott, Apr 09 '25 14:04