Zygote.jl
combining loops and addition causes a dimension mismatch
Here is the minimal reproducer I came up with
    import Flux
    import Zygote
    using Functors

    struct Test
        a
        b
    end
    @functor Test

    function (m::Test)(x)
        a = x
        for f=m.a
            a = f(a)
        end

        b = x
        for f=m.b
            b = f(b)
        end
        a + b
    end

    t = Test([Flux.Dense(10=>5)], [Flux.Dense(10=>5)])
    x = rand(10)
    Zygote.gradient(() -> sum(t(x)), Flux.params(t))
The error, reported on the line `for f=m.b`, is:
    ERROR: LoadError: DimensionMismatch("arrays could not be broadcast to a common size; got a dimension with lengths 5 and 10")
    Stacktrace:
      [1] _bcs1
        @ ./broadcast.jl:516 [inlined]
      [2] _bcs
        @ ./broadcast.jl:510 [inlined]
      [3] broadcast_shape
        @ ./broadcast.jl:504 [inlined]
      [4] combine_axes
        @ ./broadcast.jl:499 [inlined]
      [5] instantiate
        @ ./broadcast.jl:281 [inlined]
      [6] materialize
        @ ./broadcast.jl:860 [inlined]
      [7] accum(x::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, ys::Vector{Float64})
        @ Zygote ~/.julia/packages/Zygote/IoW2g/src/lib/lib.jl:25
      [8] Pullback
        @ repro.jl:18 [inlined]
      [9] (::typeof(∂(λ)))(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
        @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
     [10] Pullback
        @ repro.jl:26 [inlined]
     [11] (::typeof(∂(#3)))(Δ::Float64)
        @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
     [12] (::Zygote.var"#97#98"{Zygote.Params{Zygote.Buffer{Any, Vector{Any}}}, typeof(∂(#3)), Zygote.Context})(Δ::Float64)
        @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface.jl:357
     [13] gradient(f::Function, args::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
        @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface.jl:76
     [14] top-level scope
        @ repro.jl:26
    in expression starting at repro.jl:26
This happens even though all the dimensions match up correctly; running `t(x)` on its own works fine.
The error goes away with any one of the following:
- Replace `a + b` with just `a` or just `b`
- Replace `a = x` with `a = copy(x)` (this is the workaround I'm using in my actual code right now)
- Replace both for loops with explicit indexing
This is with Zygote version 0.6.41.
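For reference, the `copy` workaround listed above can be sketched as follows. This assumes the same Flux and Zygote setup as the reproducer. The struct name `TestCopy` is a hypothetical rename for illustration, and the only functional change is `a = copy(x)`.

```julia
import Flux
import Zygote
using Functors

# Same model as the reproducer; TestCopy is a hypothetical name so this
# sketch can be loaded next to the original Test definition.
struct TestCopy
    a
    b
end
@functor TestCopy

function (m::TestCopy)(x)
    # Workaround reported above: copy x so that a and b do not both
    # start out bound to the same array.
    a = copy(x)
    for f = m.a
        a = f(a)
    end

    b = x
    for f = m.b
        b = f(b)
    end
    a + b
end

t = TestCopy([Flux.Dense(10 => 5)], [Flux.Dense(10 => 5)])
x = rand(10)

# With the copy in place, this call is reported above to no longer
# raise the DimensionMismatch.
grads = Zygote.gradient(() -> sum(t(x)), Flux.params(t))
```

The copy costs one extra allocation per forward pass, which is negligible for a small input like this.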
Even `a = identity(x)` is enough to stop this. Zygote seems to sometimes get confused about the fact that assignment does not permanently identify two variables (here `a` and `b` both start out bound to `x`).
Perhaps similar to #1236 and https://github.com/FluxML/Zygote.jl/issues/1198.
I have the same error and I don't know how to solve it.