Zygote.jl

Error with control flow

Status: Open · mtfishman opened this issue 2 years ago · 4 comments

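We've come across a strange bug involving control flow. The function below evaluates fine, and finite differencing gives a sensible derivative. However, taking its gradient with Zygote throws a `MethodError`: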

julia> using FiniteDifferences

julia> using Zygote

julia> function f(x)
         y = [[x]', [x]]
         r = 0.0
         o = 1.0
         for n in 1:2
           o *= y[n]
           if n < 2
             proj_o = o * [1.0]
           else
             # Error
             proj_o = o
             # Fix
             # proj_o = o * 1.0
           end
           r += proj_o
         end
         return r
       end
f (generic function with 1 method)

julia> x = 1.2
1.2

julia> f(x)
2.6399999999999997

julia> central_fdm(5, 1)(f, x)
3.4000000000000967

julia> f'(x)
ERROR: MethodError: no method matching +(::Float64, ::LinearAlgebra.Adjoint{Float64, Vector{Float64}})
For element-wise addition, use broadcasting with dot syntax: scalar .+ array
Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...) at ~/software/julia-1.7.3/share/julia/base/operators.jl:655
  +(::Union{Float16, Float32, Float64}, ::BigFloat) at ~/software/julia-1.7.3/share/julia/base/mpfr.jl:413
  +(::ChainRulesCore.Tangent{P}, ::P) where P at ~/.julia/packages/ChainRulesCore/GUvJT/src/tangent_arithmetic.jl:146
  ...
Stacktrace:
 [1] accum(x::Float64, y::LinearAlgebra.Adjoint{Float64, Vector{Float64}})
   @ Zygote ~/.julia/packages/Zygote/DkIUK/src/lib/lib.jl:17
 [2] Pullback
   @ ./REPL[16]:15 [inlined]
 [3] (::typeof(∂(f)))(Δ::Float64)
   @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [4] (::Zygote.var"#52#53"{typeof(∂(f))})(Δ::Float64)
   @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:41
 [5] (::Zygote.var"#54#55"{typeof(f)})(x::Float64)
   @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:83
 [6] top-level scope
   @ REPL[20]:1

Changing the line:

proj_o = o

to:

proj_o = o * 1.0

fixes the issue, and the Zygote gradient then matches the finite-difference estimate:

julia> f(x)
2.6399999999999997

julia> central_fdm(5, 1)(f, x)
3.4000000000000967

julia> f'(x)
3.4000000000000004

Version information:

julia> versioninfo()
Julia Version 1.7.3
Commit 742b9abb4d (2022-05-06 12:58 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Xeon(R) E-2176M  CPU @ 2.70GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-12.0.1 (ORCJIT, skylake)
Environment:
  JULIA_EDITOR = vim

julia> using Pkg

julia> Pkg.status("Zygote")
      Status `~/.julia/environments/v1.7/Project.toml`
  [e88e6eb3] Zygote v0.6.40

julia> Pkg.status("ChainRules")
      Status `~/.julia/environments/v1.7/Project.toml`
  [082447d4] ChainRules v1.35.1

Original issue is here: https://github.com/ITensor/ITensors.jl/issues/927

— mtfishman, Jun 03 '22 20:06

IR dump for debugging:

julia> adj = @code_adjoint f(1.2);

julia> adj.primal
1: (%3, %4 :: Zygote.Context, %1, %2)
  %5 = Zygote._pullback(%4, Base.vect, %2)
  %6 = Base.getindex(%5, 1)
  %7 = Base.getindex(%5, 2)
  %8 = Zygote._pullback(%4, Main.:var"'", %6)
  %9 = Base.getindex(%8, 1)
  %10 = Base.getindex(%8, 2)
  %11 = Zygote._pullback(%4, Base.vect, %2)
  %12 = Base.getindex(%11, 1)
  %13 = Base.getindex(%11, 2)
  %14 = Zygote._pullback(%4, Base.vect, %9, %12)
  %15 = Base.getindex(%14, 1)
  %16 = Base.getindex(%14, 2)
  %17 = Zygote._pullback(%4, Main.:(:), 1, 2)
  %18 = Base.getindex(%17, 1)
  %19 = Base.getindex(%17, 2)
  %20 = Zygote._pullback(%4, Base.iterate, %18)
  %21 = Base.getindex(%20, 1)
  %22 = Base.getindex(%20, 2)
  %23 = %21 === nothing
  %24 = Base.not_int(%23)
  br 6 (0.0, 1) unless %24
  br 2 (%21, 1.0, 0.0, 1)
2: (%25, %26, %27, %59 :: UInt8)
  %28 = Zygote._pullback(%4, Zygote.literal_getfield, %25, Val{1}())
  %29 = Base.getindex(%28, 1)
  %30 = Base.getindex(%28, 2)
  %31 = Zygote._pullback(%4, Zygote.literal_getfield, %25, Val{2}())
  %32 = Base.getindex(%31, 1)
  %33 = Base.getindex(%31, 2)
  %34 = Zygote._pullback(%4, Base.getindex, %15, %29)
  %35 = Base.getindex(%34, 1)
  %36 = Base.getindex(%34, 2)
  %37 = Zygote._pullback(%4, Main.:*, %26, %35)
  %38 = Base.getindex(%37, 1)
  %39 = Base.getindex(%37, 2)
  %40 = Zygote._pullback(%4, Main.:<, %29, 2)
  %41 = Base.getindex(%40, 1)
  %42 = Base.getindex(%40, 2)
  br 4 unless %41
  br 3
3:
  %43 = Zygote._pullback(%4, Base.vect, 1.0)
  %44 = Base.getindex(%43, 1)
  %45 = Base.getindex(%43, 2)
  %46 = Zygote._pullback(%4, Main.:*, %38, %44)
  %47 = Base.getindex(%46, 1)
  %48 = Base.getindex(%46, 2)
  br 5 (%47, 1)
4:
  br 5 (%38, 2)
5: (%49, %60 :: UInt8)
  %50 = Zygote._pullback(%4, Main.:+, %27, %49)
  %51 = Base.getindex(%50, 1)
  %52 = Base.getindex(%50, 2)
  %53 = Zygote._pullback(%4, Base.iterate, %18, %32)
  %54 = Base.getindex(%53, 1)
  %55 = Base.getindex(%53, 2)
  %56 = %54 === nothing
  %57 = Base.not_int(%56)
  br 6 (%51, 2) unless %57
  br 2 (%54, %38, %51, 2)
6: (%58, %61 :: UInt8)
  return %58

julia> adj.adjoint
1: (%1)
  %2 = @61 !== 0x01
  br 6 (nothing, nothing, nothing) unless %2
  br 2 (%1, nothing, nothing, nothing, nothing)
2: (%3, %4, %5, %6, %7)
  %8 = @60 !== 0x01
  %9 = (@55)(%4)
  %10 = Zygote.gradindex(%9, 2)
  %11 = Zygote.gradindex(%9, 3)
  %12 = (@52)(%3)
  %13 = Zygote.gradindex(%12, 2)
  %14 = Zygote.gradindex(%12, 3)
  %15 = Zygote.accum(%6, %10)
  %16 = Zygote.accum(%14, %5)
  br 4 unless %8
  br 3
3:
  br 5 (%16)
4:
  %17 = (@48)(%14)
  %18 = Zygote.gradindex(%17, 2)
  %19 = Zygote.gradindex(%17, 3)
  %20 = (@45)(%19)
  %21 = Zygote.accum(%5, %18)
  br 5 (%21)
5: (%22)
  %23 = @59 !== 0x01
  %24 = (@42)(nothing)
  %25 = Zygote.gradindex(%24, 2)
  %26 = (@39)(%22)
  %27 = Zygote.gradindex(%26, 2)
  %28 = Zygote.gradindex(%26, 3)
  %29 = (@36)(%28)
  %30 = Zygote.gradindex(%29, 2)
  %31 = Zygote.gradindex(%29, 3)
  %32 = (@33)(%11)
  %33 = Zygote.gradindex(%32, 2)
  %34 = Zygote.accum(%25, %31)
  %35 = (@30)(%34)
  %36 = Zygote.gradindex(%35, 2)
  %37 = Zygote.accum(%33, %36)
  %38 = Zygote.accum(%7, %30)
  br 6 (%37, %15, %38) unless %23
  br 2 (%13, %37, %27, %15, %38)
6: (%39, %40, %41)
  %42 = (@22)(%39)
  %43 = Zygote.gradindex(%42, 2)
  %44 = Zygote.accum(%40, %43)
  %45 = (@19)(%44)
  %46 = (@16)(%41)
  %47 = Zygote.gradindex(%46, 2)
  %48 = Zygote.gradindex(%46, 3)
  %49 = (@13)(%48)
  %50 = Zygote.gradindex(%49, 2)
  %51 = (@10)(%47)
  %52 = Zygote.gradindex(%51, 2)
  %53 = (@7)(%52)
  %54 = Zygote.gradindex(%53, 2)
  %55 = Zygote.accum(%50, %54)
  %56 = Zygote.tuple(nothing, %55)
  return %56

The same adjoint IR, with each `@N` pullback reference replaced by ∂ of the function whose pullback it is:

1: (Δ)
  %2 = @61 !== 0x01
  br 6 (nothing, nothing, nothing) unless %2
  br 2 (Δ, nothing, nothing, nothing, nothing)
2: (%3, %4, %5, %6, %7)
  %8 = @60 !== 0x01
  %9 = ∂(Base.iterate)(%4)
  %10 = Zygote.gradindex(%9, 2)
  %11 = Zygote.gradindex(%9, 3)
  %12 = ∂(Main.:+)(%3)
  %13 = Zygote.gradindex(%12, 2)
  %14 = Zygote.gradindex(%12, 3)
  %15 = Zygote.accum(%6, %10)
  %16 = Zygote.accum(%14, %5)
  br 4 unless %8
  br 3
3:
  br 5 (%16)
4:
  %17 = ∂(Main.:*)(%14)
  %18 = Zygote.gradindex(%17, 2)
  %19 = Zygote.gradindex(%17, 3)
  %20 = ∂(Base.vect)(%19)
  %21 = Zygote.accum(%5, %18)
  br 5 (%21)
5: (%22)
  %23 = @59 !== 0x01
  %24 = ∂(Main.:<)(nothing)
  %25 = Zygote.gradindex(%24, 2)
  %26 = ∂(Main.:*)(%22)
  %27 = Zygote.gradindex(%26, 2)
  %28 = Zygote.gradindex(%26, 3)
  %29 = ∂(Base.getindex)(%28)
  %30 = Zygote.gradindex(%29, 2)
  %31 = Zygote.gradindex(%29, 3)
  %32 = ∂(Zygote.literal_getfield)(%11)
  %33 = Zygote.gradindex(%32, 2)
  %34 = Zygote.accum(%25, %31)
  %35 = ∂(Zygote.literal_getfield)(%34)
  %36 = Zygote.gradindex(%35, 2)
  %37 = Zygote.accum(%33, %36)
  %38 = Zygote.accum(%7, %30)
  br 6 (%37, %15, %38) unless %23
  br 2 (%13, %37, %27, %15, %38)
6: (%39, %40, %41)
  %42 = ∂(Base.iterate)(%39)
  %43 = Zygote.gradindex(%42, 2)
  %44 = Zygote.accum(%40, %43)
  %45 = ∂(Main.:(:))(%44)
  %46 = ∂(Base.vect)(%41)
  %47 = Zygote.gradindex(%46, 2)
  %48 = Zygote.gradindex(%46, 3)
  %49 = ∂(Base.vect)(%48)
  %50 = Zygote.gradindex(%49, 2)
  %51 = ∂(Main.:var"'")(%47)
  %52 = Zygote.gradindex(%51, 2)
  %53 = ∂(Base.vect)(%52)
  %54 = Zygote.gradindex(%53, 2)
  %55 = Zygote.accum(%50, %54)
  %56 = Zygote.tuple(nothing, %55)
  return %56

— ToucheSir, Jun 05 '22 00:06

Another interesting observation: wrapping `o` in any operation on the problematic line resolves the issue:

julia> function f(x)
         y = [[x]', [x]]
         r = 0.0
         o = 1.0
         for n in 1:2
           o *= y[n]
           if n < 2
             proj_o = o * [1.0]
           else
             proj_o = identity(o) # @showgrad also works
           end
           r += proj_o
         end
         return r
       end
f (generic function with 1 method)

julia> gradient(f, 1.2)
(3.4000000000000004,)

And the corresponding adjoint (pullback) IR for this working version:

1: (%1)
  %2 = @64 !== 0x01
  br 6 (nothing, nothing, nothing) unless %2
  br 2 (%1, nothing, nothing, nothing, nothing)
2: (%3, %4, %5, %6, %7)
  %8 = @63 !== 0x01
  %9 = (@58)(%4)
  %10 = Zygote.gradindex(%9, 2)
  %11 = Zygote.gradindex(%9, 3)
  %12 = (@55)(%3)
  %13 = Zygote.gradindex(%12, 2)
  %14 = Zygote.gradindex(%12, 3)
  %15 = Zygote.accum(%6, %10)
  br 4 unless %8
  br 3
3:
  %16 = (@51)(%14)
  %17 = Zygote.gradindex(%16, 2)
  %18 = Zygote.accum(%5, %17)
  br 5 (%18)
4:
  %19 = (@48)(%14)
  %20 = Zygote.gradindex(%19, 2)
  %21 = Zygote.gradindex(%19, 3)
  %22 = (@45)(%21)
  %23 = Zygote.accum(%5, %20)
  br 5 (%23)
5: (%24)
  %25 = @62 !== 0x01
  %26 = (@42)(nothing)
  %27 = Zygote.gradindex(%26, 2)
  %28 = (@39)(%24)
  %29 = Zygote.gradindex(%28, 2)
  %30 = Zygote.gradindex(%28, 3)
  %31 = (@36)(%30)
  %32 = Zygote.gradindex(%31, 2)
  %33 = Zygote.gradindex(%31, 3)
  %34 = (@33)(%11)
  %35 = Zygote.gradindex(%34, 2)
  %36 = Zygote.accum(%27, %33)
  %37 = (@30)(%36)
  %38 = Zygote.gradindex(%37, 2)
  %39 = Zygote.accum(%35, %38)
  %40 = Zygote.accum(%7, %32)
  br 6 (%39, %15, %40) unless %25
  br 2 (%13, %39, %29, %15, %40)
6: (%41, %42, %43)
  %44 = (@22)(%41)
  %45 = Zygote.gradindex(%44, 2)
  %46 = Zygote.accum(%42, %45)
  %47 = (@19)(%46)
  %48 = (@16)(%43)
  %49 = Zygote.gradindex(%48, 2)
  %50 = Zygote.gradindex(%48, 3)
  %51 = (@13)(%50)
  %52 = Zygote.gradindex(%51, 2)
  %53 = (@10)(%49)
  %54 = Zygote.gradindex(%53, 2)
  %55 = (@7)(%54)
  %56 = Zygote.gradindex(%55, 2)
  %57 = Zygote.accum(%52, %56)
  %58 = Zygote.tuple(nothing, %57)
  return %58

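At first glance, the biggest culprit seems to be where one `accum` ends up. In the working version, the `accum` for the else branch of `if n < 2` lives in block 3 and only runs when that branch was taken. In the failing version it sits in block 2 instead (`%16 = Zygote.accum(%14, %5)`), before the branch, so it runs on every iteration of the reversed loop. That includes the `n < 2` iteration, which matches the failing `accum(::Float64, ::Adjoint)` call in the stack trace. I was hoping this wouldn't be a compiler issue, but it increasingly looks like it might be one.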

— ToucheSir, Jun 05 '22 02:06

Thanks for looking into it! Interesting to see that adding any operation to that line circumvents the bug.

— mtfishman, Jun 06 '22 22:06

This looks similar to https://github.com/FluxML/Zygote.jl/issues/937#issuecomment-849267874, which seems to have fixed itself somehow (on Zygote v0.6.40).

— mcabbott, Jun 06 '22 22:06