size error when unbroadcasting arrays with generic eltypes
I've run into a DimensionMismatch error when broadcasting arrays with generic element types. The general flavor of calculation I want to do is like this:
using Zygote
function okay(a, b)
    sum(a .+ b)
end
a = rand(2,5,3)
b = rand(2,5,1,4)
@show Zygote.withgradient(okay, a, b)
which produces the expected output:
Zygote.withgradient(okay, a, b) = (val = 105.52328765180295, grad = ([4.0 4.0 4.0 4.0 4.0; 4.0 4.0 4.0 4.0 4.0;;; 4.0 4.0 4.0 4.0 4.0; 4.0 4.0 4.0 4.0 4.0;;; 4.0 4.0 4.0 4.0 4.0; 4.0 4.0 4.0 4.0 4.0], [3.0 3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0 3.0;;;; 3.0 3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0 3.0;;;; 3.0 3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0 3.0;;;; 3.0 3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0 3.0]))
However, when I wrap each array element in a named tuple as follows:
p(a, b) = a.x + b.x
function mwe(a, b)
    sum(p.(a, b))
end
ag = NamedTuple{(:x,)}.(tuple.(a))
bg = NamedTuple{(:x,)}.(tuple.(b))
@show Zygote.withgradient(mwe, ag, bg)
I get
ERROR: LoadError: DimensionMismatch: variable with size(x) == (2, 5, 3) cannot have a gradient with size(dx) == (30,)
Stacktrace:
[1] (::ChainRulesCore.ProjectTo{AbstractArray, @NamedTuple{elements::Array{ChainRulesCore.ProjectTo{…}, 3}, axes::Tuple{Base.OneTo{…}, Base.OneTo{…}, Base.OneTo{…}}}})(dx::Vector{ChainRulesCore.Tangent{@NamedTuple{x::Float64}, @NamedTuple{x::Float64}}})
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/XAgYn/src/projection.jl:229
[2] _project
@ /local/home/lxvm/projects/contrib/dev/Zygote/src/compiler/chainrules.jl:200 [inlined]
[3] unbroadcast(x::Array{@NamedTuple{x::Float64}, 3}, maybethunked_x̄::Array{@NamedTuple{x::Float64}, 4})
@ Zygote /local/home/lxvm/projects/contrib/dev/Zygote/src/lib/broadcast.jl:63
[4] map
@ ./tuple.jl:383 [inlined]
[5] ∇broadcasted
@ /local/home/lxvm/projects/contrib/dev/Zygote/src/lib/broadcast.jl:222 [inlined]
[6] (::Zygote.var"#4145#back#1372"{Zygote.var"#∇broadcasted#1383"{Tuple{…}, Array{…}, Val{…}}})(Δ::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#719#722"{…}}, ChainRules.var"#718#721"{Float64, Colon}})
@ Zygote ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72
[7] #305
@ /local/home/lxvm/projects/contrib/dev/Zygote/src/lib/lib.jl:214 [inlined]
[8] #2189#back
@ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72 [inlined]
[9] broadcasted
@ ./broadcast.jl:1331 [inlined]
[10] mwe
@ /local/home/lxvm/projects/contrib/zygote_mwe2.jl:13 [inlined]
[11] (::Zygote.Pullback{Tuple{typeof(mwe), Array{@NamedTuple{…}, 3}, Array{@NamedTuple{…}, 4}}, Tuple{Zygote.Pullback{Tuple{…}, Tuple{}}, Zygote.ZBack{ChainRules.var"#sum_pullback#720"{…}}, Zygote.Pullback{Tuple{…}, Tuple{…}}}})(Δ::Float64)
@ Zygote /local/home/lxvm/projects/contrib/dev/Zygote/src/compiler/interface2.jl:0
[12] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{typeof(mwe), Array{@NamedTuple{…}, 3}, Array{@NamedTuple{…}, 4}}, Tuple{Zygote.Pullback{Tuple{…}, Tuple{}}, Zygote.ZBack{ChainRules.var"#sum_pullback#720"{…}}, Zygote.Pullback{Tuple{…}, Tuple{…}}}}})(Δ::Float64)
@ Zygote /local/home/lxvm/projects/contrib/dev/Zygote/src/compiler/interface.jl:97
[13] withgradient(::Function, ::Array{@NamedTuple{x::Float64}, 3}, ::Vararg{Any})
@ Zygote /local/home/lxvm/projects/contrib/dev/Zygote/src/compiler/interface.jl:219
[14] macro expansion
@ show.jl:1232 [inlined]
[15] top-level scope
@ /local/home/lxvm/projects/contrib/zygote_mwe2.jl:18
[16] include(fname::String)
@ Main ./sysimg.jl:38
[17] top-level scope
@ REPL[3]:1
in expression starting at /local/home/lxvm/projects/contrib/zygote_mwe2.jl:18
Some type information was truncated. Use `show(err)` to see complete types.
I've worked out that the unintended flattening of the arrays happens on the following line:
https://github.com/FluxML/Zygote.jl/blob/e0af1a814b9c1b652861eda5db27ddec13a28d16/src/compiler/chainrules.jl#L297
Here, map flattens the arrays when I believe they are meant to be broadcast, at least in the context of unbroadcast. Perhaps unbroadcast should drop trailing singleton dimensions? I am not sure what the correct fix is.
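For reference, the reduction that unbroadcast is expected to perform can be sketched in plain Julia. This is only an illustration, assuming the incoming gradient has the full shape of the broadcast result; unbroadcast_sketch is a hypothetical helper, not Zygote's implementation, and it skips the projection step entirely.

```julia
# Hypothetical helper for illustration; not Zygote's actual unbroadcast.
# Assumes dx has the shape of the broadcast result that x took part in.
function unbroadcast_sketch(x::AbstractArray, dx::AbstractArray)
    # Same number of elements: under the assumption above, only trailing
    # singleton dimensions can differ, so a reshape is enough.
    length(x) == length(dx) && return reshape(dx, size(x))
    # size(x, d) is 1 for d > ndims(x), so missing trailing dims count as expanded.
    dims = Tuple(d for d in 1:ndims(dx) if size(x, d) == 1 && size(dx, d) != 1)
    return reshape(sum(dx; dims=dims), size(x))
end

a = rand(2, 5, 3)
b = rand(2, 5, 1, 4)
dx = ones(2, 5, 3, 4)   # gradient of sum(a .+ b) with respect to a .+ b

da = unbroadcast_sketch(a, dx)   # size (2, 5, 3), every entry 4.0
db = unbroadcast_sketch(b, dx)   # size (2, 5, 1, 4), every entry 3.0
```

These values match the 4.0 and 3.0 entries in the output of okay above.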
I think I agree with your diagnosis. ProjectTo{AbstractArray} should sort out the trailing dimensions, but it doesn't get a chance because map switches from its N-dimensional form to its plain-iteration form when the arguments have different numbers of dimensions:
julia> Zygote.z2d(fill((; x=1.0), 2, 3), fill((; x=rand()), 2, 3))
2×3 Matrix{ChainRulesCore.Tangent{@NamedTuple{x::Float64}, @NamedTuple{x::Float64}}}:
Tangent{@NamedTuple{x::Float64}}(x = 1.0,) … Tangent{@NamedTuple{x::Float64}}(x = 1.0,)
Tangent{@NamedTuple{x::Float64}}(x = 1.0,) Tangent{@NamedTuple{x::Float64}}(x = 1.0,)
julia> Zygote.z2d(fill((; x=1.0), 2, 3, 1), fill((; x=rand()), 2, 3))
6-element Vector{ChainRulesCore.Tangent{@NamedTuple{x::Float64}, @NamedTuple{x::Float64}}}:
Tangent{@NamedTuple{x::Float64}}(x = 1.0,)
Tangent{@NamedTuple{x::Float64}}(x = 1.0,)
Tangent{@NamedTuple{x::Float64}}(x = 1.0,)
Tangent{@NamedTuple{x::Float64}}(x = 1.0,)
Tangent{@NamedTuple{x::Float64}}(x = 1.0,)
Tangent{@NamedTuple{x::Float64}}(x = 1.0,)
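The same switch can be reproduced with plain arrays, independent of Zygote. The sketch below assumes only Base Julia: when the arguments have the same length but a different number of dimensions, map returns a flat Vector, while broadcast keeps the N-dimensional shape.

```julia
A = fill(1.0, 2, 3, 1)   # trailing singleton dimension
B = fill(2.0, 2, 3)

m = map(+, A, B)         # arguments differ in ndims, so map just iterates
c = broadcast(+, A, B)   # broadcast extends B along the third dimension

@show size(m)   # (6,)
@show size(c)   # (2, 3, 1)
```

The element values agree in both cases; only the container shape differs, which is what the size check in ProjectTo then rejects.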
The fix might be to just avoid map and use broadcast there:
julia> Zygote.z2d(dx::AbstractArray, primal::AbstractArray) = broadcast(Zygote.z2d, dx, primal)
julia> Zygote.withgradient(mwe, ag, bg)[2][1] # previously the error above
2×5×3 Array{@NamedTuple{x::Float64}, 3}:
[:, :, 1] =
(x = 4.0,) (x = 4.0,) (x = 4.0,) (x = 4.0,) (x = 4.0,)
(x = 4.0,) (x = 4.0,) (x = 4.0,) (x = 4.0,) (x = 4.0,)
[:, :, 2] =
(x = 4.0,) (x = 4.0,) (x = 4.0,) (x = 4.0,) (x = 4.0,)