Zygote.jl

size error when unbroadcasting arrays with generic eltypes

Open: lxvm opened this issue 3 months ago • 1 comment

I've run into a DimensionMismatch error when broadcasting over arrays with generic element types. The general kind of calculation I want to do looks like this:

using Zygote

function okay(a, b)
    sum(a .+ b)
end

a = rand(2,5,3)
b = rand(2,5,1,4)
@show Zygote.withgradient(okay, a, b)

which produces the expected output:

Zygote.withgradient(okay, a, b) = (val = 105.52328765180295, grad = ([4.0 4.0 4.0 4.0 4.0; 4.0 4.0 4.0 4.0 4.0;;; 4.0 4.0 4.0 4.0 4.0; 4.0 4.0 4.0 4.0 4.0;;; 4.0 4.0 4.0 4.0 4.0; 4.0 4.0 4.0 4.0 4.0], [3.0 3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0 3.0;;;; 3.0 3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0 3.0;;;; 3.0 3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0 3.0;;;; 3.0 3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0 3.0]))
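The 4.0 and 3.0 entries come from unbroadcasting. a .+ b has size (2, 5, 3, 4), so every element of a is used 4 times and every element of b is used 3 times, and the gradient of each input is the all-ones output gradient summed over the dimensions that input was broadcast along. The sketch below reproduces this in plain Julia. sum_to_shape is a hypothetical helper written for illustration, not Zygote's unbroadcast:

```julia
# Hypothetical helper (not Zygote's API): reduce a gradient `dy` that has the
# broadcast shape back to the shape of input `x` by summing over every dimension
# that `x` was broadcast along (size 1 in `x`, or beyond ndims(x)).
function sum_to_shape(dy::AbstractArray, x::AbstractArray)
    dims = Tuple(d for d in 1:ndims(dy) if size(x, d) == 1 && size(dy, d) != 1)
    s = isempty(dims) ? dy : sum(dy; dims=dims)
    # `sum` keeps reduced dimensions with size 1; reshape drops the trailing ones.
    return reshape(s, size(x))
end

a = rand(2, 5, 3)
b = rand(2, 5, 1, 4)
dy = ones(size(a .+ b))   # gradient of sum(a .+ b) w.r.t. each output element

da = sum_to_shape(dy, a)  # summed over dim 4: every entry is 4.0
db = sum_to_shape(dy, b)  # summed over dim 3: every entry is 3.0
```

Note that sum(dy; dims=(4,)) returns a (2, 5, 3, 1) array. That trailing singleton dimension is exactly what has to be removed before the gradient matches a.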

However, when I wrap each array element in a named tuple as follows,

p(a, b) = a.x + b.x
function mwe(a, b)
    sum(p.(a, b))
end

ag = NamedTuple{(:x,)}.(tuple.(a))
bg = NamedTuple{(:x,)}.(tuple.(b))
@show Zygote.withgradient(mwe, ag, bg)

I get the following error:

ERROR: LoadError: DimensionMismatch: variable with size(x) == (2, 5, 3) cannot have a gradient with size(dx) == (30,)
Stacktrace:
  [1] (::ChainRulesCore.ProjectTo{AbstractArray, @NamedTuple{elements::Array{ChainRulesCore.ProjectTo{…}, 3}, axes::Tuple{Base.OneTo{…}, Base.OneTo{…}, Base.OneTo{…}}}})(dx::Vector{ChainRulesCore.Tangent{@NamedTuple{x::Float64}, @NamedTuple{x::Float64}}})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/XAgYn/src/projection.jl:229
  [2] _project
    @ /local/home/lxvm/projects/contrib/dev/Zygote/src/compiler/chainrules.jl:200 [inlined]
  [3] unbroadcast(x::Array{@NamedTuple{x::Float64}, 3}, maybethunked_x̄::Array{@NamedTuple{x::Float64}, 4})
    @ Zygote /local/home/lxvm/projects/contrib/dev/Zygote/src/lib/broadcast.jl:63
  [4] map
    @ ./tuple.jl:383 [inlined]
  [5] ∇broadcasted
    @ /local/home/lxvm/projects/contrib/dev/Zygote/src/lib/broadcast.jl:222 [inlined]
  [6] (::Zygote.var"#4145#back#1372"{Zygote.var"#∇broadcasted#1383"{Tuple{…}, Array{…}, Val{…}}})(Δ::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#719#722"{…}}, ChainRules.var"#718#721"{Float64, Colon}})
    @ Zygote ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72
  [7] #305
    @ /local/home/lxvm/projects/contrib/dev/Zygote/src/lib/lib.jl:214 [inlined]
  [8] #2189#back
    @ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72 [inlined]
  [9] broadcasted
    @ ./broadcast.jl:1331 [inlined]
 [10] mwe
    @ /local/home/lxvm/projects/contrib/zygote_mwe2.jl:13 [inlined]
 [11] (::Zygote.Pullback{Tuple{typeof(mwe), Array{@NamedTuple{…}, 3}, Array{@NamedTuple{…}, 4}}, Tuple{Zygote.Pullback{Tuple{…}, Tuple{}}, Zygote.ZBack{ChainRules.var"#sum_pullback#720"{…}}, Zygote.Pullback{Tuple{…}, Tuple{…}}}})(Δ::Float64)
    @ Zygote /local/home/lxvm/projects/contrib/dev/Zygote/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{typeof(mwe), Array{@NamedTuple{…}, 3}, Array{@NamedTuple{…}, 4}}, Tuple{Zygote.Pullback{Tuple{…}, Tuple{}}, Zygote.ZBack{ChainRules.var"#sum_pullback#720"{…}}, Zygote.Pullback{Tuple{…}, Tuple{…}}}}})(Δ::Float64)
    @ Zygote /local/home/lxvm/projects/contrib/dev/Zygote/src/compiler/interface.jl:97
 [13] withgradient(::Function, ::Array{@NamedTuple{x::Float64}, 3}, ::Vararg{Any})
    @ Zygote /local/home/lxvm/projects/contrib/dev/Zygote/src/compiler/interface.jl:219
 [14] macro expansion
    @ show.jl:1232 [inlined]
 [15] top-level scope
    @ /local/home/lxvm/projects/contrib/zygote_mwe2.jl:18
 [16] include(fname::String)
    @ Main ./sysimg.jl:38
 [17] top-level scope
    @ REPL[3]:1
in expression starting at /local/home/lxvm/projects/contrib/zygote_mwe2.jl:18
Some type information was truncated. Use `show(err)` to see complete types.

I've traced the problem to an unintended flattening of arrays at this line: https://github.com/FluxML/Zygote.jl/blob/e0af1a814b9c1b652861eda5db27ddec13a28d16/src/compiler/chainrules.jl#L297. There, map flattens the arrays when the gradient and the primal have different numbers of dimensions, whereas I believe they are intended to be broadcast, at least in the context of unbroadcast. Perhaps unbroadcast should drop trailing singleton dimensions? I'm not sure what the correct fix is.
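The flattening can be reproduced with plain arrays, without Zygote. This sketch assumes only Base behaviour: map over arrays of differing ndims falls back to iterating them in lockstep, while broadcast combines their axes. The (2, 3, 1) and (2, 3) sizes are chosen to mirror a gradient with a trailing singleton dimension meeting its primal:

```julia
# Same number of elements, different dimensionality: like a summed gradient
# of size (2, 3, 1) meeting a primal of size (2, 3).
dx = ones(2, 3, 1)
primal = ones(2, 3)

m = map(+, dx, primal)        # lockstep iteration: the shape is lost
bc = broadcast(+, dx, primal) # broadcast merges the axes of both arguments

println(size(m))   # (6,)
println(size(bc))  # (2, 3, 1)
```

Both results hold the same values in the same order; only broadcast keeps an N-dimensional shape that a later projection step can reshape.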

lxvm · Aug 26 '25 19:08

I think I agree with your diagnosis. ProjectTo{AbstractArray} should sort out the trailing dimensions, but it doesn't get a chance, because map switches from its N-dimensional form to its just-iterate form when the arrays have different numbers of dimensions:

julia> Zygote.z2d(fill((; x=1.0), 2, 3), fill((; x=rand()), 2, 3))
2×3 Matrix{ChainRulesCore.Tangent{@NamedTuple{x::Float64}, @NamedTuple{x::Float64}}}:
 Tangent{@NamedTuple{x::Float64}}(x = 1.0,)  …  Tangent{@NamedTuple{x::Float64}}(x = 1.0,)
 Tangent{@NamedTuple{x::Float64}}(x = 1.0,)     Tangent{@NamedTuple{x::Float64}}(x = 1.0,)

julia> Zygote.z2d(fill((; x=1.0), 2, 3, 1), fill((; x=rand()), 2, 3))
6-element Vector{ChainRulesCore.Tangent{@NamedTuple{x::Float64}, @NamedTuple{x::Float64}}}:
 Tangent{@NamedTuple{x::Float64}}(x = 1.0,)
 Tangent{@NamedTuple{x::Float64}}(x = 1.0,)
 Tangent{@NamedTuple{x::Float64}}(x = 1.0,)
 Tangent{@NamedTuple{x::Float64}}(x = 1.0,)
 Tangent{@NamedTuple{x::Float64}}(x = 1.0,)
 Tangent{@NamedTuple{x::Float64}}(x = 1.0,)
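The switch between the two forms can be seen in Base's iterator traits. map(f, A, B) collects a Generator over zip(A, B), and collect only builds an N-dimensional array when the zip reports a HasShape size. The sketch below inspects that trait; the exact non-HasShape trait returned for mismatched dimensions is treated as an implementation detail here:

```julia
same  = zip(ones(2, 3), ones(2, 3))     # both inputs have 2 dimensions
mixed = zip(ones(2, 3, 1), ones(2, 3))  # 3 dimensions vs 2 dimensions

# Matching ndims: the zip keeps a 2-dimensional shape, so collect returns a Matrix.
println(Base.IteratorSize(same))

# Mismatched ndims: no common shape trait, so collect falls back to a Vector.
println(Base.IteratorSize(mixed))
println(size(collect(mixed)))
```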

The fix might be to just avoid map and use broadcast there:

julia> Zygote.z2d(dx::AbstractArray, primal::AbstractArray) = broadcast(Zygote.z2d, dx, primal)

julia> Zygote.withgradient(mwe, ag, bg)[2][1]  # previously the error above
2×5×3 Array{@NamedTuple{x::Float64}, 3}:
[:, :, 1] =
 (x = 4.0,)  (x = 4.0,)  (x = 4.0,)  (x = 4.0,)  (x = 4.0,)
 (x = 4.0,)  (x = 4.0,)  (x = 4.0,)  (x = 4.0,)  (x = 4.0,)

[:, :, 2] =
 (x = 4.0,)  (x = 4.0,)  (x = 4.0,)  (x = 4.0,)  (x = 4.0,)

mcabbott · Sep 09 '25 18:09