Zygote.jl

Elide stack generation outside of looping control flow

Open • ToucheSir opened this pull request 2 years ago • 6 comments

This PR ports @Keno's work in https://github.com/FluxML/Zygote.jl/pull/78 to the current (2022) Zygote codebase.

Because IRTools and Base Julia use slightly different IR representations, the core algorithm needed a few tweaks:

  1. Instead of inserting phi nodes, we need to add block arguments. This is a bit more tedious because it requires updating multiple blocks.
  2. On the bright side, we don't need to calculate an iterated dominance frontier for each block, though I'm not sure whether those savings are wiped out by the cost of calling IRTools.dominators.
  3. Blocks are iterated over in reverse order, which lets us iteratively narrow down the set of unaccounted-for alpha variables. Although forward_stacks! now theoretically runs in O(blocks * alphas) rather than O(max(blocks, alphas)), in practice the vast majority of alphas are eliminated very quickly (if not in the first loop iteration).
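The question behind eliding stacks is which blocks can run more than once per call. A block that cannot reach itself through the control-flow graph executes at most once, so a value computed there can be kept in a single slot instead of being pushed onto a stack. A minimal sketch of a naive self-reachability check follows, assuming a CFG given as a vector of successor lists. `self_reachable` is a hypothetical helper, not Zygote or IRTools API, and the PR itself works on IRTools IR and may compute this differently.

```julia
"""
    self_reachable(succs) -> BitVector

Hypothetical helper (not part of Zygote or IRTools): `succs[i]` lists the
successor blocks of block `i`. Returns `true` at `i` if block `i` can reach
itself, i.e. it sits inside a loop and may run more than once per call.
Naive cost: O(blocks * (blocks + edges)).
"""
function self_reachable(succs::Vector{Vector{Int}})
    n = length(succs)
    result = falses(n)
    for start in 1:n
        seen = falses(n)
        worklist = copy(succs[start])   # depth-first search from start's successors
        while !isempty(worklist)
            b = pop!(worklist)
            if b == start               # found a path back to the start block
                result[start] = true
                break
            end
            seen[b] && continue         # each block is expanded at most once
            seen[b] = true
            append!(worklist, succs[b])
        end
    end
    return result
end

# Diamond CFG: block 1 branches to 2 or 3, both fall through to 4. No loops.
diamond = [[2, 3], [4], [4], Int[]]
# Block 2 branches back to itself: a loop.
looped = [[2], [2, 3], Int[]]

self_reachable(diamond)  # all false: every value fits in a single slot
self_reachable(looped)   # true only for block 2: only its values need a stack
```

A diamond is roughly the shape a single `a ? x : y` lowers to, which is why branch-only functions like `qux` should need no stacks at all under this scheme.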

Performance Comparison

using Zygote, BenchmarkTools

function qux(a, b, x) # Simple control flow
   aa = a ? sin(x) : cos(x)
   bb = b ? sech(aa) : tanh(aa)
   return bb
end

foldminus(xs) = Base.afoldl(-, xs...) # afoldl is very branch-heavy

xs = ntuple(identity, 16)

julia> @time gradient(qux, true, false, 1.0);
  0.146199 seconds (60.84 k allocations: 3.519 MiB, 99.73% compilation time) # 0.6.37
  0.135723 seconds (52.94 k allocations: 3.086 MiB, 99.86% compilation time) # This PR

julia> @btime gradient(qux, true, false, 1.0);
  3.378 μs (46 allocations: 1.31 KiB) # 0.6.37
  3.044 μs (35 allocations: 720 bytes) # This PR

julia> @time gradient(foldminus, xs);
  4.785566 seconds (11.53 M allocations: 616.818 MiB, 2.59% gc time, 99.97% compilation time) # 0.6.37
  4.428252 seconds (11.97 M allocations: 660.290 MiB, 3.03% gc time, 99.97% compilation time) # This PR

julia> @btime gradient(foldminus, $xs);
  111.256 μs (506 allocations: 20.30 KiB) # 0.6.37
  151.316 ns (8 allocations: 848 bytes) # This PR

The afoldl example is particularly interesting because of how that function is defined: even though it still ends in a loop, no longer requiring stacks for its long chain of conditionals makes the gradient over 700× faster at runtime (111 μs vs. 151 ns). This could have an immediate downstream impact on code like https://github.com/FluxML/Flux.jl/pull/1809#discussion_r777762381.
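For context, `Base.afoldl` is written as a long, manually unrolled run of "apply `op`, then return early if the arguments are used up" steps, followed by a loop that only handles argument lists longer than the unrolled part. The miniature below mimics that shape with just three unrolled steps; `foldl_unrolled` is a hypothetical name, and Base's version unrolls many more. Under this PR, as I understand it, only values produced inside the trailing loop would still need a stack.

```julia
# Hypothetical miniature of the shape of Base.afoldl: unrolled early returns
# (branches, but no loop) followed by a genuine loop for longer inputs.
function foldl_unrolled(op, a, bs...)
    l = length(bs)
    y = a
    l == 0 && return y
    y = op(y, bs[1]); l == 1 && return y
    y = op(y, bs[2]); l == 2 && return y
    y = op(y, bs[3]); l == 3 && return y
    # Only argument lists longer than 3 ever reach this loop.
    for i in 4:l
        y = op(y, bs[i])
    end
    return y
end

xs = ntuple(identity, 16)
foldl_unrolled(-, xs...)  # same as foldl(-, xs), i.e. 1 - 2 - 3 - ... - 16
```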

Next Steps

The Zygote test suite passes locally for me, so if CI and downstream tests are green, I think this should be a drop-in replacement for the current compiler code path. Per the comments, further optimizations may be possible, e.g. for calculating self-reachability. After looking through a lot of IRTools code, I also think there's plenty of low-hanging fruit to optimize there.

ToucheSir, Apr 05 '22

Wow

CarloLucibello, Apr 05 '22

Awesome, really nice work @ToucheSir. If this is based on @Keno's original code it probably makes sense to add a co-author to the commit? (Alternatively you could treat this as an update to his branch, but that might be a hassle.)

I may be able to help with review if I get some time (but please don't wait for me if someone else gets there first).

MikeInnes, Apr 22 '22

Thanks @MikeInnes! Treating this as a branch update is a little beyond my ability since the original PR was filed before the IRTools transition, but I've now tagged the commit with co-authorship info.

ToucheSir, Apr 23 '22

Tracking references across issues, my guess is that this is the fix for https://github.com/TuringLang/Turing.jl/issues/1754, or am I missing something?

If so, is this PR solid enough to try out (on Julia 1.7), or should I wait until it is merged?

jlperla, May 03 '22

Please do try it. It may not make much difference to compilation time, but it should help with control-flow-heavy code. Besides, it's a good idea to test against Turing in general. We should add it to our downstream tests if we can find a subset of its test suite that sufficiently checks Zygote correctness.

DhairyaLGandhi, May 03 '22

Friendly bump on this :)

ToucheSir, Aug 10 '22

I just came across this, and I'll say that this is huge for anything that uses Distributions.jl (which we do in Turing.jl) due to the number of if-statements in StatsFuns.jl/Distributions.jl. I've literally shaved days off the runtime of certain large models with Zygote by spending a grueling amount of effort tracking down if-statements in StatsFuns.jl and removing them.

I'm currently trying to do some benchmarks to see exactly what sort of effect it has on both runtime and compile time for our use-cases.

torfjelde, Nov 10 '22

@ToucheSir would you rebase?

CarloLucibello, Nov 10 '22

So it unfortunately seems to significantly increase compilation time (and memory usage) in the example in https://github.com/TuringLang/Turing.jl/issues/1754. With 15 tilde statements, this PR exhausts the 32 GB of memory on my laptop, whereas the current release (I haven't tested against master) uses minimal memory (though it still takes ages to compile).

torfjelde, Nov 11 '22

Regarding the increase in compile time, you can also observe it in the currently running tests, e.g. DiffEqFlux.jl/Layers. At the moment it has been running for ~6 hrs, while in the previously merged PR it seems to have taken only ~20 mins: https://github.com/FluxML/Zygote.jl/actions/runs/3260268471/jobs/5353708714

torfjelde, Nov 11 '22

Two of the four failures on nightly, and all of the failures on stable and LTS, should now be fixed. The remaining two nightly failures are caused by a missing rule and have been reported at https://github.com/JuliaDiff/ChainRules.jl/issues/684.

> e.g. DiffEqFlux.jl/Layers. Atm it has been running for ~6hrs, while in the previously merged PR it seems to have only taken ~20mins:

This one was also mysteriously timing out before this PR. I'll have another look at https://github.com/TuringLang/Turing.jl/issues/1754, though. Last I checked (around the time of https://github.com/FluxML/Zygote.jl/pull/1195#issuecomment-1116549912), the changes here didn't make a difference to latency, so perhaps the compiler has become smarter since then...

ToucheSir, Nov 11 '22