Zygote.jl
Elide stack generation outside of looping control flow
This PR ports @Keno's work on https://github.com/FluxML/Zygote.jl/pull/78 to 2022 Zygote.
Because IRTools and base Julia have slightly different IR representations, some tweaks were necessary for the core algorithm:
- Instead of inserting phi nodes, we need to add block arguments. This is a bit more tedious because it requires updating multiple blocks.
- On the bright side, we don't need to calculate an iterated dominance frontier for each block. Whether any savings from that are wiped out by calling `IRTools.dominators`, I'm not sure.
- Blocks are iterated over in reverse order. This allows us to iteratively narrow down the number of unaccounted alpha vars. Although `forward_stacks!` now theoretically runs in O(blocks * alphas) instead of O(max(blocks, alphas)), in practice the vast majority of alphas will be eliminated very quickly (if not in the first loop iteration).
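To make the reverse-iteration idea above concrete, here is a minimal sketch (in Python, purely for illustration; the names `narrow_alphas` and `alphas_by_block` are hypothetical and not Zygote's). Walking blocks last-to-first, each alpha var is dropped from the unaccounted set once its defining block is reached, and the loop exits early when nothing remains, which is why the worst-case O(blocks * alphas) rarely bites in practice.

```python
def narrow_alphas(blocks, alphas_by_block):
    """Sketch of iteratively narrowing unaccounted alpha vars.

    blocks: block ids in forward order.
    alphas_by_block: dict mapping block id -> set of alpha vars it defines.
    Returns the still-unaccounted alphas and the number of blocks visited.
    """
    unaccounted = set().union(*alphas_by_block.values()) if alphas_by_block else set()
    visited = 0
    for b in reversed(blocks):
        if not unaccounted:  # early exit: most alphas vanish quickly
            break
        unaccounted -= alphas_by_block.get(b, set())
        visited += 1
    return unaccounted, visited
```

The early exit is the key: once every alpha is accounted for, no further blocks are scanned, so the quadratic bound is only hit when alphas survive all the way to the entry block.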
Performance Comparison
```julia
using Zygote, BenchmarkTools

function qux(a, b, x) # Simple control flow
    aa = a ? sin(x) : cos(x)
    bb = b ? sech(aa) : tanh(aa)
    return bb
end

foldminus(xs) = Base.afoldl(-, xs...) # afoldl is very branch-heavy
xs = ntuple(identity, 16)
```
```julia
julia> @time gradient(qux, true, false, 1.0);
  0.146199 seconds (60.84 k allocations: 3.519 MiB, 99.73% compilation time) # 0.6.37
  0.135723 seconds (52.94 k allocations: 3.086 MiB, 99.86% compilation time) # This PR

julia> @btime gradient(qux, true, false, 1.0);
  3.378 μs (46 allocations: 1.31 KiB)
  3.044 μs (35 allocations: 720 bytes)

julia> @time gradient(foldminus, xs);
  4.785566 seconds (11.53 M allocations: 616.818 MiB, 2.59% gc time, 99.97% compilation time)
  4.428252 seconds (11.97 M allocations: 660.290 MiB, 3.03% gc time, 99.97% compilation time)

julia> @btime gradient(foldminus, $xs);
  111.256 μs (506 allocations: 20.30 KiB)
  151.316 ns (8 allocations: 848 bytes)
```
The `afoldl` example is particularly interesting because of how that function is defined. Despite the presence of a loop at the end, not requiring stacks for the block of conditionals is significantly faster. This could have immediate downstream impact for code like https://github.com/FluxML/Flux.jl/pull/1809#discussion_r777762381.
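To see why `afoldl` benefits so much: its definition is (schematically) a run of unrolled arity checks followed by a fallback loop. With this PR, only values that cross the loop's back-edge still need stacks; the straight-line conditional region no longer allocates any. A rough sketch of that shape (in Python for illustration; this is the structure of the function, not Base's actual code):

```python
def afoldl_shape(op, a, *xs):
    """Schematic of Base.afoldl's shape: unrolled branches for small
    arities, then a loop for the rest."""
    n = len(xs)
    if n == 0:
        return a            # branch-only region: no loop, so the
    y = op(a, xs[0])        # reverse pass needs no stacks for
    if n == 1:              # values defined here
        return y
    y = op(y, xs[1])
    if n == 2:
        return y
    for x in xs[2:]:        # looping region: values crossing the
        y = op(y, x)        # back-edge still need stacks
    return y
```

The conditional region dominates for small tuples, which is consistent with the `foldminus` benchmark above dropping from microseconds to nanoseconds.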
Next Steps
The Zygote test suite passes locally for me, so if CI + downstream is green then I think this should be a drop-in replacement for the current compiler code path. Per the comments, more optimizations may be possible for aspects such as calculating self-reachability. After looking through a bunch of IRTools code, there's probably a lot of low hanging fruit to optimize there as well.
Wow, awesome, really nice work @ToucheSir. If this is based on @Keno's original code, it probably makes sense to add a co-author to the commit? (Alternatively you could treat this as an update to his branch, but that might be a hassle.)
I may be able to help with review if I get some time (but please don't wait up if someone else gets there first).
Thanks @MikeInnes! Treating this as a branch update is a little beyond my ability since the original PR was filed before the IRTools transition, but I've now tagged the commit with co-authorship info.
Trying to track references across issues: my guess is that this is the solution to https://github.com/TuringLang/Turing.jl/issues/1754, or am I missing something?
If so, is this PR sufficiently solid that it can be tested (on Julia 1.7), or should I wait until it is merged?
Please do check this. It may not make much difference in compilation, but it should help with control-flow-heavy code. Besides, it's a good idea to test against Turing in general. We should add that to our downstream tests if we can find a subset of the testset that sufficiently checks Zygote correctness.
Friendly bump on this :)
I just came across this, and I'll say that this is huge for anything that uses Distributions.jl (which we do in Turing.jl) due to the amount of if-statements in StatsFuns.jl/Distributions.jl. I've literally shaved days of runtime off certain large models with Zygote by spending a grueling amount of effort tracking down and removing if-statements in StatsFuns.jl.
I'm currently trying to do some benchmarks to see exactly what sort of effect it has on both runtime and compile time for our use-cases.
@ToucheSir would you rebase?
So it unfortunately seems to significantly increase compilation time (and memory usage) in the example in https://github.com/TuringLang/Turing.jl/issues/1754. For 15 tilde-statements, it blows out my 32GB-memory laptop using this PR, while the current release (I haven't tested against master) has minimal memory usage (it still takes ages to compile).
Regarding the increase in compile time, you can also observe this in the currently running tests, e.g. `DiffEqFlux.jl/Layers`. Atm it has been running for ~6hrs, while in the previously merged PR it seems to have only taken ~20mins: https://github.com/FluxML/Zygote.jl/actions/runs/3260268471/jobs/5353708714
2/4 failures on nightly and all failures on stable+LTS should be squashed now. The remaining 2 nightly ones are because of a missing rule and have been reported at https://github.com/JuliaDiff/ChainRules.jl/issues/684.
> e.g. `DiffEqFlux.jl/Layers`. Atm it has been running for ~6hrs, while in the previously merged PR it seems to have only taken ~20mins:
This one has been mysteriously timing out before this PR as well. I'll have another look at https://github.com/TuringLang/Turing.jl/issues/1754 though. Last I checked (around the time of https://github.com/FluxML/Zygote.jl/pull/1195#issuecomment-1116549912) the changes here didn't make a difference to latency, so perhaps the compiler has become smarter since...