Zygote.jl
Error with control flow
We've come across a strange bug involving control flow:
julia> using FiniteDifferences
julia> using Zygote
julia> function f(x)
y = [[x]', [x]]
r = 0.0
o = 1.0
for n in 1:2
o *= y[n]
if n < 2
proj_o = o * [1.0]
# Error
proj_o = o
# Fix
# proj_o = o * 1.0
r += proj_o
return r
f (generic function with 1 method)
julia> x = 1.2
julia> f(x)
julia> central_fdm(5, 1)(f, x)
julia> f'(x)
ERROR: MethodError: no method matching +(::Float64, ::LinearAlgebra.Adjoint{Float64, Vector{Float64}})
For element-wise addition, use broadcasting with dot syntax: scalar .+ array
Closest candidates are:
+(::Any, ::Any, ::Any, ::Any...) at ~/software/julia-1.7.3/share/julia/base/operators.jl:655
+(::Union{Float16, Float32, Float64}, ::BigFloat) at ~/software/julia-1.7.3/share/julia/base/mpfr.jl:413
+(::ChainRulesCore.Tangent{P}, ::P) where P at ~/.julia/packages/ChainRulesCore/GUvJT/src/tangent_arithmetic.jl:146
[1] accum(x::Float64, y::LinearAlgebra.Adjoint{Float64, Vector{Float64}})
@ Zygote ~/.julia/packages/Zygote/DkIUK/src/lib/lib.jl:17
[2] Pullback
@ ./REPL[16]:15 [inlined]
[3] (::typeof(∂(f)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
[4] (::Zygote.var"#52#53"{typeof(∂(f))})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:41
[5] (::Zygote.var"#54#55"{typeof(f)})(x::Float64)
@ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:83
[6] top-level scope
@ REPL[20]:1
Changing the line:
proj_o = o
to:
proj_o = o * 1.0
fixes the issue and outputs:
julia> f(x)
julia> central_fdm(5, 1)(f, x)
julia> f'(x)
Version information:
julia> versioninfo()
Julia Version 1.7.3
Commit 742b9abb4d (2022-05-06 12:58 UTC)
Platform Info:
OS: Linux (x86_64-pc-linux-gnu)
CPU: Intel(R) Xeon(R) E-2176M CPU @ 2.70GHz
LIBM: libopenlibm
LLVM: libLLVM-12.0.1 (ORCJIT, skylake)
julia> using Pkg
julia> Pkg.status("Zygote")
Status `~/.julia/environments/v1.7/Project.toml`
[e88e6eb3] Zygote v0.6.40
julia> Pkg.status("ChainRules")
Status `~/.julia/environments/v1.7/Project.toml`
[082447d4] ChainRules v1.35.1
Original issue is here: https://github.com/ITensor/ITensors.jl/issues/927
IR dump for debugging:
julia> adj = @code_adjoint f(1.2);
julia> adj.primal
1: (%3, %4 :: Zygote.Context, %1, %2)
%5 = Zygote._pullback(%4, Base.vect, %2)
%6 = Base.getindex(%5, 1)
%7 = Base.getindex(%5, 2)
%8 = Zygote._pullback(%4, Main.:var"'", %6)
%9 = Base.getindex(%8, 1)
%10 = Base.getindex(%8, 2)
%11 = Zygote._pullback(%4, Base.vect, %2)
%12 = Base.getindex(%11, 1)
%13 = Base.getindex(%11, 2)
%14 = Zygote._pullback(%4, Base.vect, %9, %12)
%15 = Base.getindex(%14, 1)
%16 = Base.getindex(%14, 2)
%17 = Zygote._pullback(%4, Main.:(:), 1, 2)
%18 = Base.getindex(%17, 1)
%19 = Base.getindex(%17, 2)
%20 = Zygote._pullback(%4, Base.iterate, %18)
%21 = Base.getindex(%20, 1)
%22 = Base.getindex(%20, 2)
%23 = %21 === nothing
%24 = Base.not_int(%23)
br 6 (0.0, 1) unless %24
br 2 (%21, 1.0, 0.0, 1)
2: (%25, %26, %27, %59 :: UInt8)
%28 = Zygote._pullback(%4, Zygote.literal_getfield, %25, Val{1}())
%29 = Base.getindex(%28, 1)
%30 = Base.getindex(%28, 2)
%31 = Zygote._pullback(%4, Zygote.literal_getfield, %25, Val{2}())
%32 = Base.getindex(%31, 1)
%33 = Base.getindex(%31, 2)
%34 = Zygote._pullback(%4, Base.getindex, %15, %29)
%35 = Base.getindex(%34, 1)
%36 = Base.getindex(%34, 2)
%37 = Zygote._pullback(%4, Main.:*, %26, %35)
%38 = Base.getindex(%37, 1)
%39 = Base.getindex(%37, 2)
%40 = Zygote._pullback(%4, Main.:<, %29, 2)
%41 = Base.getindex(%40, 1)
%42 = Base.getindex(%40, 2)
br 4 unless %41
br 3
%43 = Zygote._pullback(%4, Base.vect, 1.0)
%44 = Base.getindex(%43, 1)
%45 = Base.getindex(%43, 2)
%46 = Zygote._pullback(%4, Main.:*, %38, %44)
%47 = Base.getindex(%46, 1)
%48 = Base.getindex(%46, 2)
br 5 (%47, 1)
br 5 (%38, 2)
5: (%49, %60 :: UInt8)
%50 = Zygote._pullback(%4, Main.:+, %27, %49)
%51 = Base.getindex(%50, 1)
%52 = Base.getindex(%50, 2)
%53 = Zygote._pullback(%4, Base.iterate, %18, %32)
%54 = Base.getindex(%53, 1)
%55 = Base.getindex(%53, 2)
%56 = %54 === nothing
%57 = Base.not_int(%56)
br 6 (%51, 2) unless %57
br 2 (%54, %38, %51, 2)
6: (%58, %61 :: UInt8)
return %58
julia> adj.adjoint
1: (%1)
%2 = @61 !== 0x01
br 6 (nothing, nothing, nothing) unless %2
br 2 (%1, nothing, nothing, nothing, nothing)
2: (%3, %4, %5, %6, %7)
%8 = @60 !== 0x01
%9 = (@55)(%4)
%10 = Zygote.gradindex(%9, 2)
%11 = Zygote.gradindex(%9, 3)
%12 = (@52)(%3)
%13 = Zygote.gradindex(%12, 2)
%14 = Zygote.gradindex(%12, 3)
%15 = Zygote.accum(%6, %10)
%16 = Zygote.accum(%14, %5)
br 4 unless %8
br 3
br 5 (%16)
%17 = (@48)(%14)
%18 = Zygote.gradindex(%17, 2)
%19 = Zygote.gradindex(%17, 3)
%20 = (@45)(%19)
%21 = Zygote.accum(%5, %18)
br 5 (%21)
5: (%22)
%23 = @59 !== 0x01
%24 = (@42)(nothing)
%25 = Zygote.gradindex(%24, 2)
%26 = (@39)(%22)
%27 = Zygote.gradindex(%26, 2)
%28 = Zygote.gradindex(%26, 3)
%29 = (@36)(%28)
%30 = Zygote.gradindex(%29, 2)
%31 = Zygote.gradindex(%29, 3)
%32 = (@33)(%11)
%33 = Zygote.gradindex(%32, 2)
%34 = Zygote.accum(%25, %31)
%35 = (@30)(%34)
%36 = Zygote.gradindex(%35, 2)
%37 = Zygote.accum(%33, %36)
%38 = Zygote.accum(%7, %30)
br 6 (%37, %15, %38) unless %23
br 2 (%13, %37, %27, %15, %38)
6: (%39, %40, %41)
%42 = (@22)(%39)
%43 = Zygote.gradindex(%42, 2)
%44 = Zygote.accum(%40, %43)
%45 = (@19)(%44)
%46 = (@16)(%41)
%47 = Zygote.gradindex(%46, 2)
%48 = Zygote.gradindex(%46, 3)
%49 = (@13)(%48)
%50 = Zygote.gradindex(%49, 2)
%51 = (@10)(%47)
%52 = Zygote.gradindex(%51, 2)
%53 = (@7)(%52)
%54 = Zygote.gradindex(%53, 2)
%55 = Zygote.accum(%50, %54)
%56 = Zygote.tuple(nothing, %55)
return %56
With pullbacks filled in:
1: (Δ)
%2 = @61 !== 0x01
br 6 (nothing, nothing, nothing) unless %2
br 2 (Δ, nothing, nothing, nothing, nothing)
2: (%3, %4, %5, %6, %7)
%8 = @60 !== 0x01
%9 = ∂(Base.iterate)(%4)
%10 = Zygote.gradindex(%9, 2)
%11 = Zygote.gradindex(%9, 3)
%12 = ∂(Main.:+)(%3)
%13 = Zygote.gradindex(%12, 2)
%14 = Zygote.gradindex(%12, 3)
%15 = Zygote.accum(%6, %10)
%16 = Zygote.accum(%14, %5)
br 4 unless %8
br 3
br 5 (%16)
%17 = ∂(Main.:*)(%14)
%18 = Zygote.gradindex(%17, 2)
%19 = Zygote.gradindex(%17, 3)
%20 = ∂(Base.vect)(%19)
%21 = Zygote.accum(%5, %18)
br 5 (%21)
5: (%22)
%23 = @59 !== 0x01
%24 = ∂(Main.:<)(nothing)
%25 = Zygote.gradindex(%24, 2)
%26 = ∂(Main.:*)(%22)
%27 = Zygote.gradindex(%26, 2)
%28 = Zygote.gradindex(%26, 3)
%29 = ∂(Base.getindex)(%28)
%30 = Zygote.gradindex(%29, 2)
%31 = Zygote.gradindex(%29, 3)
%32 = ∂(Zygote.literal_getfield)(%11)
%33 = Zygote.gradindex(%32, 2)
%34 = Zygote.accum(%25, %31)
%35 = ∂(Zygote.literal_getfield)(%34)
%36 = Zygote.gradindex(%35, 2)
%37 = Zygote.accum(%33, %36)
%38 = Zygote.accum(%7, %30)
br 6 (%37, %15, %38) unless %23
br 2 (%13, %37, %27, %15, %38)
6: (%39, %40, %41)
%42 = ∂(Base.iterate)(%39)
%43 = Zygote.gradindex(%42, 2)
%44 = Zygote.accum(%40, %43)
%45 = ∂(Main.:(:))(%44)
%46 = ∂(Base.vect)(%41)
%47 = Zygote.gradindex(%46, 2)
%48 = Zygote.gradindex(%46, 3)
%49 = ∂(Base.vect)(%48)
%50 = Zygote.gradindex(%49, 2)
%51 = ∂(Main.:var"'")(%47)
%52 = Zygote.gradindex(%51, 2)
%53 = ∂(Base.vect)(%52)
%54 = Zygote.gradindex(%53, 2)
%55 = Zygote.accum(%50, %54)
%56 = Zygote.tuple(nothing, %55)
return %56
Another interesting bit: adding any operation around `o` in the problematic line resolves the issue:
julia> function f(x)
y = [[x]', [x]]
r = 0.0
o = 1.0
for n in 1:2
o *= y[n]
if n < 2
proj_o = o * [1.0]
proj_o = identity(o) # @showgrad also works
r += proj_o
return r
f (generic function with 1 method)
julia> gradient(f, 1.2)
And the pullback IR:
1: (%1)
%2 = @64 !== 0x01
br 6 (nothing, nothing, nothing) unless %2
br 2 (%1, nothing, nothing, nothing, nothing)
2: (%3, %4, %5, %6, %7)
%8 = @63 !== 0x01
%9 = (@58)(%4)
%10 = Zygote.gradindex(%9, 2)
%11 = Zygote.gradindex(%9, 3)
%12 = (@55)(%3)
%13 = Zygote.gradindex(%12, 2)
%14 = Zygote.gradindex(%12, 3)
%15 = Zygote.accum(%6, %10)
br 4 unless %8
br 3
%16 = (@51)(%14)
%17 = Zygote.gradindex(%16, 2)
%18 = Zygote.accum(%5, %17)
br 5 (%18)
%19 = (@48)(%14)
%20 = Zygote.gradindex(%19, 2)
%21 = Zygote.gradindex(%19, 3)
%22 = (@45)(%21)
%23 = Zygote.accum(%5, %20)
br 5 (%23)
5: (%24)
%25 = @62 !== 0x01
%26 = (@42)(nothing)
%27 = Zygote.gradindex(%26, 2)
%28 = (@39)(%24)
%29 = Zygote.gradindex(%28, 2)
%30 = Zygote.gradindex(%28, 3)
%31 = (@36)(%30)
%32 = Zygote.gradindex(%31, 2)
%33 = Zygote.gradindex(%31, 3)
%34 = (@33)(%11)
%35 = Zygote.gradindex(%34, 2)
%36 = Zygote.accum(%27, %33)
%37 = (@30)(%36)
%38 = Zygote.gradindex(%37, 2)
%39 = Zygote.accum(%35, %38)
%40 = Zygote.accum(%7, %32)
br 6 (%39, %15, %40) unless %25
br 2 (%13, %39, %29, %15, %40)
6: (%41, %42, %43)
%44 = (@22)(%41)
%45 = Zygote.gradindex(%44, 2)
%46 = Zygote.accum(%42, %45)
%47 = (@19)(%46)
%48 = (@16)(%43)
%49 = Zygote.gradindex(%48, 2)
%50 = Zygote.gradindex(%48, 3)
%51 = (@13)(%50)
%52 = Zygote.gradindex(%51, 2)
%53 = (@10)(%49)
%54 = Zygote.gradindex(%53, 2)
%55 = (@7)(%54)
%56 = Zygote.gradindex(%55, 2)
%57 = Zygote.accum(%52, %56)
%58 = Zygote.tuple(nothing, %57)
return %58
At first glance, the movement of an `accum` from block 2 to block 3 (nominally the else branch of the `if n < 2`) seems like the biggest culprit. I was hoping this wouldn't be a compiler issue, but it is increasingly looking like it might be one.
Thanks for looking into it! Interesting to see that adding any operation to that line circumvents the bug.
This looks similar to https://github.com/FluxML/Zygote.jl/issues/937#issuecomment-849267874, which seems to have fixed itself somehow (on Zygote v0.6.40).