Zygote.jl
Zygote.jl copied to clipboard
Error with control flow
We've come across a strange bug involving control flow:
julia> using FiniteDifferences
julia> using Zygote
julia> function f(x)
y = [[x]', [x]]
r = 0.0
o = 1.0
for n in 1:2
o *= y[n]
if n < 2
proj_o = o * [1.0]
else
# Error
proj_o = o
# Fix
# proj_o = o * 1.0
end
r += proj_o
end
return r
end
f (generic function with 1 method)
julia> x = 1.2
1.2
julia> f(x)
2.6399999999999997
julia> central_fdm(5, 1)(f, x)
3.4000000000000967
julia> f'(x)
ERROR: MethodError: no method matching +(::Float64, ::LinearAlgebra.Adjoint{Float64, Vector{Float64}})
For element-wise addition, use broadcasting with dot syntax: scalar .+ array
Closest candidates are:
+(::Any, ::Any, ::Any, ::Any...) at ~/software/julia-1.7.3/share/julia/base/operators.jl:655
+(::Union{Float16, Float32, Float64}, ::BigFloat) at ~/software/julia-1.7.3/share/julia/base/mpfr.jl:413
+(::ChainRulesCore.Tangent{P}, ::P) where P at ~/.julia/packages/ChainRulesCore/GUvJT/src/tangent_arithmetic.jl:146
...
Stacktrace:
[1] accum(x::Float64, y::LinearAlgebra.Adjoint{Float64, Vector{Float64}})
@ Zygote ~/.julia/packages/Zygote/DkIUK/src/lib/lib.jl:17
[2] Pullback
@ ./REPL[16]:15 [inlined]
[3] (::typeof(∂(f)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
[4] (::Zygote.var"#52#53"{typeof(∂(f))})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:41
[5] (::Zygote.var"#54#55"{typeof(f)})(x::Float64)
@ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:83
[6] top-level scope
@ REPL[20]:1
Changing the line:
proj_o = o
to:
proj_o = o * 1.0
fixes the issue and outputs:
julia> f(x)
2.6399999999999997
julia> central_fdm(5, 1)(f, x)
3.4000000000000967
julia> f'(x)
3.4000000000000004
Version information:
julia> versioninfo()
Julia Version 1.7.3
Commit 742b9abb4d (2022-05-06 12:58 UTC)
Platform Info:
OS: Linux (x86_64-pc-linux-gnu)
CPU: Intel(R) Xeon(R) E-2176M CPU @ 2.70GHz
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-12.0.1 (ORCJIT, skylake)
Environment:
JULIA_EDITOR = vim
julia> using Pkg
julia> Pkg.status("Zygote")
Status `~/.julia/environments/v1.7/Project.toml`
[e88e6eb3] Zygote v0.6.40
julia> Pkg.status("ChainRules")
Status `~/.julia/environments/v1.7/Project.toml`
[082447d4] ChainRules v1.35.1
Original issue is here: https://github.com/ITensor/ITensors.jl/issues/927
IR dump for debugging:
julia> adj = @code_adjoint f(1.2);
julia> adj.primal
1: (%3, %4 :: Zygote.Context, %1, %2)
%5 = Zygote._pullback(%4, Base.vect, %2)
%6 = Base.getindex(%5, 1)
%7 = Base.getindex(%5, 2)
%8 = Zygote._pullback(%4, Main.:var"'", %6)
%9 = Base.getindex(%8, 1)
%10 = Base.getindex(%8, 2)
%11 = Zygote._pullback(%4, Base.vect, %2)
%12 = Base.getindex(%11, 1)
%13 = Base.getindex(%11, 2)
%14 = Zygote._pullback(%4, Base.vect, %9, %12)
%15 = Base.getindex(%14, 1)
%16 = Base.getindex(%14, 2)
%17 = Zygote._pullback(%4, Main.:(:), 1, 2)
%18 = Base.getindex(%17, 1)
%19 = Base.getindex(%17, 2)
%20 = Zygote._pullback(%4, Base.iterate, %18)
%21 = Base.getindex(%20, 1)
%22 = Base.getindex(%20, 2)
%23 = %21 === nothing
%24 = Base.not_int(%23)
br 6 (0.0, 1) unless %24
br 2 (%21, 1.0, 0.0, 1)
2: (%25, %26, %27, %59 :: UInt8)
%28 = Zygote._pullback(%4, Zygote.literal_getfield, %25, Val{1}())
%29 = Base.getindex(%28, 1)
%30 = Base.getindex(%28, 2)
%31 = Zygote._pullback(%4, Zygote.literal_getfield, %25, Val{2}())
%32 = Base.getindex(%31, 1)
%33 = Base.getindex(%31, 2)
%34 = Zygote._pullback(%4, Base.getindex, %15, %29)
%35 = Base.getindex(%34, 1)
%36 = Base.getindex(%34, 2)
%37 = Zygote._pullback(%4, Main.:*, %26, %35)
%38 = Base.getindex(%37, 1)
%39 = Base.getindex(%37, 2)
%40 = Zygote._pullback(%4, Main.:<, %29, 2)
%41 = Base.getindex(%40, 1)
%42 = Base.getindex(%40, 2)
br 4 unless %41
br 3
3:
%43 = Zygote._pullback(%4, Base.vect, 1.0)
%44 = Base.getindex(%43, 1)
%45 = Base.getindex(%43, 2)
%46 = Zygote._pullback(%4, Main.:*, %38, %44)
%47 = Base.getindex(%46, 1)
%48 = Base.getindex(%46, 2)
br 5 (%47, 1)
4:
br 5 (%38, 2)
5: (%49, %60 :: UInt8)
%50 = Zygote._pullback(%4, Main.:+, %27, %49)
%51 = Base.getindex(%50, 1)
%52 = Base.getindex(%50, 2)
%53 = Zygote._pullback(%4, Base.iterate, %18, %32)
%54 = Base.getindex(%53, 1)
%55 = Base.getindex(%53, 2)
%56 = %54 === nothing
%57 = Base.not_int(%56)
br 6 (%51, 2) unless %57
br 2 (%54, %38, %51, 2)
6: (%58, %61 :: UInt8)
return %58
julia> adj.adjoint
1: (%1)
%2 = @61 !== 0x01
br 6 (nothing, nothing, nothing) unless %2
br 2 (%1, nothing, nothing, nothing, nothing)
2: (%3, %4, %5, %6, %7)
%8 = @60 !== 0x01
%9 = (@55)(%4)
%10 = Zygote.gradindex(%9, 2)
%11 = Zygote.gradindex(%9, 3)
%12 = (@52)(%3)
%13 = Zygote.gradindex(%12, 2)
%14 = Zygote.gradindex(%12, 3)
%15 = Zygote.accum(%6, %10)
%16 = Zygote.accum(%14, %5)
br 4 unless %8
br 3
3:
br 5 (%16)
4:
%17 = (@48)(%14)
%18 = Zygote.gradindex(%17, 2)
%19 = Zygote.gradindex(%17, 3)
%20 = (@45)(%19)
%21 = Zygote.accum(%5, %18)
br 5 (%21)
5: (%22)
%23 = @59 !== 0x01
%24 = (@42)(nothing)
%25 = Zygote.gradindex(%24, 2)
%26 = (@39)(%22)
%27 = Zygote.gradindex(%26, 2)
%28 = Zygote.gradindex(%26, 3)
%29 = (@36)(%28)
%30 = Zygote.gradindex(%29, 2)
%31 = Zygote.gradindex(%29, 3)
%32 = (@33)(%11)
%33 = Zygote.gradindex(%32, 2)
%34 = Zygote.accum(%25, %31)
%35 = (@30)(%34)
%36 = Zygote.gradindex(%35, 2)
%37 = Zygote.accum(%33, %36)
%38 = Zygote.accum(%7, %30)
br 6 (%37, %15, %38) unless %23
br 2 (%13, %37, %27, %15, %38)
6: (%39, %40, %41)
%42 = (@22)(%39)
%43 = Zygote.gradindex(%42, 2)
%44 = Zygote.accum(%40, %43)
%45 = (@19)(%44)
%46 = (@16)(%41)
%47 = Zygote.gradindex(%46, 2)
%48 = Zygote.gradindex(%46, 3)
%49 = (@13)(%48)
%50 = Zygote.gradindex(%49, 2)
%51 = (@10)(%47)
%52 = Zygote.gradindex(%51, 2)
%53 = (@7)(%52)
%54 = Zygote.gradindex(%53, 2)
%55 = Zygote.accum(%50, %54)
%56 = Zygote.tuple(nothing, %55)
return %56
With pullbacks filled in:
1: (Δ)
%2 = @61 !== 0x01
br 6 (nothing, nothing, nothing) unless %2
br 2 (Δ, nothing, nothing, nothing, nothing)
2: (%3, %4, %5, %6, %7)
%8 = @60 !== 0x01
%9 = ∂(Base.iterate)(%4)
%10 = Zygote.gradindex(%9, 2)
%11 = Zygote.gradindex(%9, 3)
%12 = ∂(Main.:+)(%3)
%13 = Zygote.gradindex(%12, 2)
%14 = Zygote.gradindex(%12, 3)
%15 = Zygote.accum(%6, %10)
%16 = Zygote.accum(%14, %5)
br 4 unless %8
br 3
3:
br 5 (%16)
4:
%17 = ∂(Main.:*)(%14)
%18 = Zygote.gradindex(%17, 2)
%19 = Zygote.gradindex(%17, 3)
%20 = ∂(Base.vect)(%19)
%21 = Zygote.accum(%5, %18)
br 5 (%21)
5: (%22)
%23 = @59 !== 0x01
%24 = ∂(Main.:<)(nothing)
%25 = Zygote.gradindex(%24, 2)
%26 = ∂(Main.:*)(%22)
%27 = Zygote.gradindex(%26, 2)
%28 = Zygote.gradindex(%26, 3)
%29 = ∂(Base.getindex)(%28)
%30 = Zygote.gradindex(%29, 2)
%31 = Zygote.gradindex(%29, 3)
%32 = ∂(Zygote.literal_getfield)(%11)
%33 = Zygote.gradindex(%32, 2)
%34 = Zygote.accum(%25, %31)
%35 = ∂(Zygote.literal_getfield)(%34)
%36 = Zygote.gradindex(%35, 2)
%37 = Zygote.accum(%33, %36)
%38 = Zygote.accum(%7, %30)
br 6 (%37, %15, %38) unless %23
br 2 (%13, %37, %27, %15, %38)
6: (%39, %40, %41)
%42 = ∂(Base.iterate)(%39)
%43 = Zygote.gradindex(%42, 2)
%44 = Zygote.accum(%40, %43)
%45 = ∂(Main.:(:))(%44)
%46 = ∂(Base.vect)(%41)
%47 = Zygote.gradindex(%46, 2)
%48 = Zygote.gradindex(%46, 3)
%49 = ∂(Base.vect)(%48)
%50 = Zygote.gradindex(%49, 2)
%51 = ∂(Main.:var"'")(%47)
%52 = Zygote.gradindex(%51, 2)
%53 = ∂(Base.vect)(%52)
%54 = Zygote.gradindex(%53, 2)
%55 = Zygote.accum(%50, %54)
%56 = Zygote.tuple(nothing, %55)
return %56
Another interesting bit: wrapping `o` in any operation on the problematic line resolves the issue:
julia> function f(x)
y = [[x]', [x]]
r = 0.0
o = 1.0
for n in 1:2
o *= y[n]
if n < 2
proj_o = o * [1.0]
else
proj_o = identity(o) # @showgrad also works
end
r += proj_o
end
return r
end
f (generic function with 1 method)
julia> gradient(f, 1.2)
(3.4000000000000004,)
And the pullback IR:
1: (%1)
%2 = @64 !== 0x01
br 6 (nothing, nothing, nothing) unless %2
br 2 (%1, nothing, nothing, nothing, nothing)
2: (%3, %4, %5, %6, %7)
%8 = @63 !== 0x01
%9 = (@58)(%4)
%10 = Zygote.gradindex(%9, 2)
%11 = Zygote.gradindex(%9, 3)
%12 = (@55)(%3)
%13 = Zygote.gradindex(%12, 2)
%14 = Zygote.gradindex(%12, 3)
%15 = Zygote.accum(%6, %10)
br 4 unless %8
br 3
3:
%16 = (@51)(%14)
%17 = Zygote.gradindex(%16, 2)
%18 = Zygote.accum(%5, %17)
br 5 (%18)
4:
%19 = (@48)(%14)
%20 = Zygote.gradindex(%19, 2)
%21 = Zygote.gradindex(%19, 3)
%22 = (@45)(%21)
%23 = Zygote.accum(%5, %20)
br 5 (%23)
5: (%24)
%25 = @62 !== 0x01
%26 = (@42)(nothing)
%27 = Zygote.gradindex(%26, 2)
%28 = (@39)(%24)
%29 = Zygote.gradindex(%28, 2)
%30 = Zygote.gradindex(%28, 3)
%31 = (@36)(%30)
%32 = Zygote.gradindex(%31, 2)
%33 = Zygote.gradindex(%31, 3)
%34 = (@33)(%11)
%35 = Zygote.gradindex(%34, 2)
%36 = Zygote.accum(%27, %33)
%37 = (@30)(%36)
%38 = Zygote.gradindex(%37, 2)
%39 = Zygote.accum(%35, %38)
%40 = Zygote.accum(%7, %32)
br 6 (%39, %15, %40) unless %25
br 2 (%13, %39, %29, %15, %40)
6: (%41, %42, %43)
%44 = (@22)(%41)
%45 = Zygote.gradindex(%44, 2)
%46 = Zygote.accum(%42, %45)
%47 = (@19)(%46)
%48 = (@16)(%43)
%49 = Zygote.gradindex(%48, 2)
%50 = Zygote.gradindex(%48, 3)
%51 = (@13)(%50)
%52 = Zygote.gradindex(%51, 2)
%53 = (@10)(%49)
%54 = Zygote.gradindex(%53, 2)
%55 = (@7)(%54)
%56 = Zygote.gradindex(%55, 2)
%57 = Zygote.accum(%52, %56)
%58 = Zygote.tuple(nothing, %57)
return %58
At first glance, the movement of an `accum` from block 2 to block 3 (nominally the else branch of the `if n < 2`) seems like the most likely culprit. I was hoping this wouldn't be a compiler issue, but it is increasingly looking like it might be one.
Thanks for looking into it! Interesting to see that adding any operation to that line circumvents the bug.
This looks similar to https://github.com/FluxML/Zygote.jl/issues/937#issuecomment-849267874, which seems to have fixed itself somehow (on Zygote v0.6.40).