Zygote.jl

Type instability with conditional return

Open • sethaxen opened this issue 5 years ago • 3 comments

In the example below, which contains a conditional, the types should be easy to infer. However, none of the types in the backward pass (the pullback) are inferred. Why is this?

```julia
using Zygote

function foo(x)
  if x > 0
    a, b = 2x, 3x
  else
    a, b = -2x, -3x
  end
  return a, b
end
```

```julia
julia> back = Zygote._pullback(foo, 1.0)[2]
julia> @code_warntype back((1.0, 2.0))
Variables
  #self#::typeof(∂(foo))
  Δ::Tuple{Float64,Float64}
  phi_4_1::Any

Body::Tuple{Nothing,Any}
1 ─       $(Expr(:meta, :inline))
│   %2  = Base.getfield(#self#, :t)::Any
│   %3  = Base.getindex(%2, 7)::Any
│   %4  = Base.getindex(%2, 6)::Any
│   %5  = Base.getindex(%2, 1)::Any
│   %6  = (%3 !== 0x01)::Bool
│   %7  = (%4)(Δ)::Any
│   %8  = Zygote.gradindex(%7, 2)::Any
│   %9  = Zygote.gradindex(%7, 3)::Any
│   %10 = Base.getindex(%2, 2)::Any
│   %11 = Zygote.Stack(%10)::Zygote.Stack{_A} where _A
│   %12 = Base.getindex(%2, 3)::Any
│   %13 = Zygote.Stack(%12)::Zygote.Stack{_A} where _A
│   %14 = Base.getindex(%2, 4)::Any
│   %15 = Zygote.Stack(%14)::Zygote.Stack{_A} where _A
│   %16 = Base.getindex(%2, 5)::Any
│   %17 = Zygote.Stack(%16)::Zygote.Stack{_A} where _A
└──       goto #4 if not %6
2 ─       goto #3
3 ─ %20 = Base.pop!(%17)::Any
│   %21 = Base.pop!(%15)::Any
│   %22 = (%20)(%9)::Any
│   %23 = Zygote.gradindex(%22, 3)::Any
│   %24 = (%21)(%8)::Any
│   %25 = Zygote.gradindex(%24, 3)::Any
│   %26 = Zygote.accum(%23, %25)::Any
│         (phi_4_1 = %26)
└──       goto #5
4 ─ %29 = Base.pop!(%13)::Any
│   %30 = Base.pop!(%11)::Any
│   %31 = (%30)(%9)::Any
│   %32 = Zygote.gradindex(%31, 3)::Any
│   %33 = (%29)(%8)::Any
│   %34 = Zygote.gradindex(%33, 3)::Any
│   %35 = Zygote.accum(%32, %34)::Any
│         (phi_4_1 = %35)
└──       goto #5
5 ┄       nothing
│   %39 = (%5)(nothing)::Any
│   %40 = Zygote.gradindex(%39, 2)::Any
│   %41 = Zygote.accum(phi_4_1, %40)::Any
│   %42 = Zygote.tuple(nothing, %41)::Core.Compiler.PartialStruct(Tuple{Nothing,Any}, Any[Core.Compiler.Const(nothing, false), Any])
└──       return %42
```
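For comparison, the forward (primal) function on its own infers concretely, which is why the `Any`s in the pullback stand out. A quick check using only Base reflection (`@code_warntype foo(1.0)` from InteractiveUtils shows the same thing in more detail):

```julia
# The original function from this issue.
function foo(x)
    if x > 0
        a, b = 2x, 3x
    else
        a, b = -2x, -3x
    end
    return a, b
end

# Ask inference for foo's return type given a Float64 argument.
# Both branches produce two Float64 values, so the result is a concrete tuple type.
Base.return_types(foo, (Float64,))
```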

Version info:

```
julia> versioninfo()
Julia Version 1.2.0
Commit c6da87ff4b (2019-08-20 00:03 UTC)
Platform Info:
  OS: macOS (x86_64-apple-darwin18.6.0)
  CPU: Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-6.0.1 (ORCJIT, haswell)

(v1.2) pkg> status Zygote
    Status `~/.julia/environments/v1.2/Project.toml`
  [b552c78f] DiffRules v0.0.10
  [f6369f11] ForwardDiff v0.10.3
  [7869d1d1] IRTools v0.2.3 #master (https://github.com/MikeInnes/IRTools.jl.git)
  [276daf66] SpecialFunctions v0.7.2
  [e88e6eb3] Zygote v0.3.4 #master (https://github.com/FluxML/Zygote.jl.git)
  [700de1a5] ZygoteRules v0.2.0 #master (https://github.com/FluxML/ZygoteRules.jl.git)
```

sethaxen, Oct 14 '19 08:10

This is currently expected: to handle control flow, Zygote stores the pullbacks for `a` and `b` on a `Vector{Any}` stack, which loses their type information.
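A simplified, hypothetical illustration of why a `Vector{Any}` stack defeats inference. This is not Zygote's actual `Stack` type; it is a plain `Vector{Any}` standing in for it, and the function names are made up for the example:

```julia
# Pops a stored pullback and applies it to the incoming gradient Δ.
# Because the stack's element type is Any, inference only knows `pb::Any`,
# so the result of `pb(Δ)` is also inferred as Any.
function pop_and_apply(stack::Vector{Any}, Δ::Float64)
    pb = pop!(stack)
    return pb(Δ)
end

# For contrast: a stack with a concrete element type keeps its type information.
# (It stores plain scale factors rather than closures, so it is only an analogy.)
function pop_and_scale(stack::Vector{Float64}, Δ::Float64)
    scale = pop!(stack)
    return scale * Δ
end

Base.return_types(pop_and_apply, (Vector{Any}, Float64))      # element inferred as Any
Base.return_types(pop_and_scale, (Vector{Float64}, Float64))  # element inferred as Float64
```

In Julia every closure has its own concrete type, so a single stack that holds pullbacks from different call sites generally needs an abstract element type, and everything popped from it is then opaque to inference.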

We're hoping that new compiler features will help with this in the future, but for now it's not a high priority.

MikeInnes, Nov 04 '19 12:11

Have there been any updates on this? I am using some hand-written special functions, and the type-unstable code caused by this issue is much slower than ForwardDiff.

kaandocal, Feb 09 '23 09:02

There have not. You could try https://github.com/FluxML/Zygote.jl/pull/1195 to see whether it performs better for you, but generally the better options are to hide the branching behind a function with a custom rule, or to use branchless control flow such as `ifelse`.
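A minimal sketch of the `ifelse` suggestion applied to `foo` from the original report. `foo_branchless` is a hypothetical name; the point is only that the function body no longer contains an `if` block for Zygote to record as control flow. Whether the resulting pullback is actually type-stable depends on how your Zygote version differentiates `ifelse`, so it is worth re-checking with `@code_warntype` as in the original report:

```julia
# Original, branching version from this issue (for comparison).
function foo(x)
    if x > 0
        a, b = 2x, 3x
    else
        a, b = -2x, -3x
    end
    return a, b
end

# Branchless rewrite: `ifelse` evaluates both value arguments and selects one,
# so the sign is chosen without any control flow in the function body.
function foo_branchless(x)
    s = ifelse(x > 0, one(x), -one(x))  # +1 or -1, same type as x
    return 2s * x, 3s * x
end

foo_branchless(2.0)
```

Because `ifelse` always evaluates both of its value arguments, this pattern suits cheap branches like these; for expensive or error-prone branches the custom-rule route is usually the better fit.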

ToucheSir, Feb 09 '23 16:02