Zygote.jl
Type instability with conditional return
In the example below, which contains a single conditional, the types should be easy to infer. However, none of the types in the backward pass are inferred. Why would this be?
using Zygote
function foo(x)
    if x > 0
        a, b = 2x, 3x
    else
        a, b = -2x, -3x
    end
    return a, b
end
julia> back = Zygote._pullback(foo, 1.0)[2]
julia> @code_warntype back((1.0, 2.0))
Variables
#self#::typeof(∂(foo))
Δ::Tuple{Float64,Float64}
phi_4_1::Any
Body::Tuple{Nothing,Any}
1 ─ $(Expr(:meta, :inline))
│ %2 = Base.getfield(#self#, :t)::Any
│ %3 = Base.getindex(%2, 7)::Any
│ %4 = Base.getindex(%2, 6)::Any
│ %5 = Base.getindex(%2, 1)::Any
│ %6 = (%3 !== 0x01)::Bool
│ %7 = (%4)(Δ)::Any
│ %8 = Zygote.gradindex(%7, 2)::Any
│ %9 = Zygote.gradindex(%7, 3)::Any
│ %10 = Base.getindex(%2, 2)::Any
│ %11 = Zygote.Stack(%10)::Zygote.Stack{_A} where _A
│ %12 = Base.getindex(%2, 3)::Any
│ %13 = Zygote.Stack(%12)::Zygote.Stack{_A} where _A
│ %14 = Base.getindex(%2, 4)::Any
│ %15 = Zygote.Stack(%14)::Zygote.Stack{_A} where _A
│ %16 = Base.getindex(%2, 5)::Any
│ %17 = Zygote.Stack(%16)::Zygote.Stack{_A} where _A
└── goto #4 if not %6
2 ─ goto #3
3 ─ %20 = Base.pop!(%17)::Any
│ %21 = Base.pop!(%15)::Any
│ %22 = (%20)(%9)::Any
│ %23 = Zygote.gradindex(%22, 3)::Any
│ %24 = (%21)(%8)::Any
│ %25 = Zygote.gradindex(%24, 3)::Any
│ %26 = Zygote.accum(%23, %25)::Any
│ (phi_4_1 = %26)
└── goto #5
4 ─ %29 = Base.pop!(%13)::Any
│ %30 = Base.pop!(%11)::Any
│ %31 = (%30)(%9)::Any
│ %32 = Zygote.gradindex(%31, 3)::Any
│ %33 = (%29)(%8)::Any
│ %34 = Zygote.gradindex(%33, 3)::Any
│ %35 = Zygote.accum(%32, %34)::Any
│ (phi_4_1 = %35)
└── goto #5
5 ┄ nothing
│ %39 = (%5)(nothing)::Any
│ %40 = Zygote.gradindex(%39, 2)::Any
│ %41 = Zygote.accum(phi_4_1, %40)::Any
│ %42 = Zygote.tuple(nothing, %41)::Core.Compiler.PartialStruct(Tuple{Nothing,Any}, Any[Core.Compiler.Const(nothing, false), Any])
└── return %42
Version info:
julia> versioninfo()
Julia Version 1.2.0
Commit c6da87ff4b (2019-08-20 00:03 UTC)
Platform Info:
OS: macOS (x86_64-apple-darwin18.6.0)
CPU: Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-6.0.1 (ORCJIT, haswell)
(v1.2) pkg> status Zygote
Status `~/.julia/environments/v1.2/Project.toml`
[b552c78f] DiffRules v0.0.10
[f6369f11] ForwardDiff v0.10.3
[7869d1d1] IRTools v0.2.3 #master (https://github.com/MikeInnes/IRTools.jl.git)
[276daf66] SpecialFunctions v0.7.2
[e88e6eb3] Zygote v0.3.4 #master (https://github.com/FluxML/Zygote.jl.git)
[700de1a5] ZygoteRules v0.2.0 #master (https://github.com/FluxML/ZygoteRules.jl.git)
Currently, this is expected; as part of the control flow handling, the pullbacks for a and b will get stored on a Vector{Any} stack, which will lose type information.
We're hoping that new compiler features will help with this in future, but for now it's not that high a priority.
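To illustrate the point about the Vector{Any} stack, here is a minimal sketch in plain Julia, with no Zygote involved. The function names are hypothetical and the vectors only stand in for the stacks Zygote builds internally. The idea is that a closure popped from a Vector{Any} can only be inferred as Any, while the same closure popped from a concretely typed vector keeps its type.

```julia
# Hypothetical stand-in for a pullback stack without element type information.
function call_via_untyped_stack(x)
    stack = Any[]            # Vector{Any}: element types are erased
    push!(stack, y -> 2y)    # store a closure, as a pullback would be stored
    f = pop!(stack)          # the compiler only knows that f is Any
    return f(x)              # so this call is dispatched at run time
end

# The same computation with a concretely typed container.
function call_via_typed_stack(x)
    f = y -> 2y
    stack = [f]              # Vector{typeof(f)}: the closure type is kept
    g = pop!(stack)          # inferred as typeof(f)
    return g(x)
end

# Ask inference for the return type of each variant given a Float64 argument.
untyped = Base.return_types(call_via_untyped_stack, (Float64,))
typed = Base.return_types(call_via_typed_stack, (Float64,))
```

On the Julia versions I would expect this to apply to, `untyped` should contain Any and `typed` should contain Float64, which mirrors the `::Any` annotations after each `Base.pop!` in the `@code_warntype` output above.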
Have there been any updates on this? I am using some hand-written special functions and the type-unstable code caused by this issue is very slow in comparison to ForwardDiff...
There have not. You could try https://github.com/FluxML/Zygote.jl/pull/1195 to see if it performs any better for you, but generally the better options are to hide the branching behind a function with a custom rule, or to use branchless control flow such as ifelse.
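For reference, here is a rough sketch of both suggestions applied to the foo from the original post. The names foo_ifelse, signed_scale, foo_rule and total are made up for illustration, and I have not checked how much of the resulting pullback is inferred on any particular Zygote version, so it is worth confirming with `@code_warntype` on your own setup.

```julia
using Zygote

# Option 1: branchless control flow. Both candidates are computed and
# `ifelse` selects one, so the function has no `if`/`else` for Zygote
# to handle.
function foo_ifelse(x)
    a = ifelse(x > 0, 2x, -2x)
    b = ifelse(x > 0, 3x, -3x)
    return a, b
end

# Option 2: hide the branch behind a helper (hypothetical name) and give
# the helper a hand-written rule, so Zygote never transforms the branch.
signed_scale(x, k) = x > 0 ? k * x : -k * x

Zygote.@adjoint function signed_scale(x, k)
    s = x > 0 ? k : -k                    # derivative with respect to x
    # The derivative with respect to k is x for x > 0 and -x otherwise,
    # which is abs(x).
    return s * x, Δ -> (Δ * s, Δ * abs(x))
end

foo_rule(x) = (signed_scale(x, 2), signed_scale(x, 3))

# Scalar wrapper so that `Zygote.gradient` can be used for a quick check.
function total(f, x)
    a, b = f(x)
    return a + b
end

g_ifelse = Zygote.gradient(x -> total(foo_ifelse, x), 1.0)[1]
g_rule = Zygote.gradient(x -> total(foo_rule, x), -2.0)[1]
```

Note that the ifelse version evaluates both candidates on every call, which is cheap here but may matter if either side is expensive or is only valid on its own part of the domain. The custom-rule version avoids that at the cost of writing the derivative by hand.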