Turing.jl
Zygote's compilation scales badly with the number of `~` statements
I'm not certain whether this can be considered a "Turing.jl issue", but I figured I would at least raise it here so people are aware.
The compilation time of Zygote scales quite badly with the number of `~` statements.
TL;DR: it takes almost 5 minutes to compile a model with 14 `~` statements. I don't have the result here, but at some point I tried one with 20 `~` statements, and it took a full ~23 mins to compile.
Demo
```julia
using Turing, Zygote

Turing.setadbackend(:zygote);
adbackend = Turing.Core.ZygoteAD();
spl = DynamicPPL.SampleFromPrior();

results = []
num_tildes = 0
```
Each run of the following snippet adds one more `~` statement to the model, so running it a number of times gives a sense of how the compilation time grows:
```julia
num_tildes += 1

Zygote.refresh()
expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);
f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);

t = @elapsed Turing.Core.gradient_logp(
    adbackend,
    vi[spl],
    vi,
    model,
    spl
);
push!(results, t)
results
```

```
14-element Vector{Any}:
 15.917412882
  9.600213651
 14.911253238
 24.699050206
 71.811476745
 49.158314248
 46.059394601
 57.44494627
 75.514564551
 94.927956369
134.165383535
156.079943416
202.745837585
273.93479382
```
That is, it takes almost 5 minutes to compile a model with 14 `~` statements. I don't have the result here, but at some point I tried one with 20 `~` statements, and it took a full ~23 mins to compile.
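As a rough back-of-the-envelope check on the numbers above (plain arithmetic, not a new measurement): doubling the number of `~` statements from 7 to 14 multiplies the time by roughly 6, i.e. a local growth rate of about n^2.6. A single ratio is a noisy estimate, and the first entries presumably also include one-off compilation of Turing/Zygote internals, so treat this as indicative only.

```julia
# First-call times (seconds) copied from the `results` vector above;
# entry k corresponds to a model with k `~` statements.
times = [15.917412882, 9.600213651, 14.911253238, 24.699050206,
         71.811476745, 49.158314248, 46.059394601, 57.44494627,
         75.514564551, 94.927956369, 134.165383535, 156.079943416,
         202.745837585, 273.93479382]

# Ratio of the time at 14 tildes to the time at 7 tildes (n doubled).
ratio = times[14] / times[7]

# If t(n) ≈ c * n^p locally, doubling n multiplies t by 2^p, so p = log2(ratio).
p = log2(ratio)

println("t(14)/t(7) = ", round(ratio; digits=2), ", local exponent p ≈ ", round(p; digits=2))
```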
Additional info
```julia
versioninfo()
```

```
Julia Version 1.6.2
Commit 1b93d53fc4 (2021-07-14 15:36 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Core(TM) i7-10710U CPU @ 1.10GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-11.0.1 (ORCJIT, skylake)
```

```julia
Pkg.status()
```

```
Status `/tmp/jl_PHYdgF/Project.toml`
  [fce5fe82] Turing v0.19.2
  [e88e6eb3] Zygote v0.6.32
```
Thanks a lot for posting this!
Do we have code that checks the compile times of models? How long has Zygote compilation been taking this long?
> How long has Zygote compilation been taking this long?
Since Julia 1.6 afaik.
> > How long has Zygote compilation been taking this long?
>
> Since Julia 1.6 afaik.
As in, the 1.6 update caused it? Is the compilation faster on 1.5?
@torfjelde can probably get the right answer, if it matters. There are a lot of packages and versions interacting, so I'm not sure it necessarily matches a particular Julia version. It might take just as long to find the exact combination of Julia, Zygote, ChainRules, Turing, DynamicPPL, etc. versions that caused it as it would to actually solve the problem.
> As in, the 1.6 update caused it? Is the compilation faster on 1.5?

I don't think so, but that's based only on my experience rather than any benchmarking/profiling (which we need!), so I don't have a good answer to this question.
Sorry to bump so soon @torfjelde, but do we have any sense of a timeline on this? Are we thinking a week, a month, etc.? We would help out, but I fear this is a little too tightly connected to the PPL macros for us to contribute to.
I'm sorry, I can't put a timeline on this right now.
And I worry and suspect it's not particularly related to Turing, unfortunately :confused: There was one change we made in Turing that I worried could have caused it, but when I tried with an older version the performance was still horrible.
The best way to identify what's wrong is to just run the test above for different versions of Turing and Zygote.
As of right now I have the following results:
| Julia | Julia args | Turing | Zygote | 1 | 5 | 10 | 15 | 20 |
|---|---|---|---|---|---|---|---|---|
| 1.6.1 | | 0.19 | 0.6 | 16.387332656 | 89.504752487 | 101.654405869 | 320.501781565 | N/A |
| 1.6.1 | | 0.18 | 0.6 | 15.740250148 | 72.75134579 | 92.373805461 | 310.690419646 | N/A |
| 1.6.1 | | 0.16 | 0.6 | 14.098175146 | 77.075566728 | 86.093118485 | 310.681374129 | N/A |
| 1.6.1 | | 0.15 | 0.6 | 13.674902932 | 75.142343745 | 83.177268954 | 309.479831351 | N/A |
| 1.6.1 | | 0.19.3 | 0.6.32 | 16.519126809 | 76.212964076 | 102.056456525 | 331.835923261 | N/A |
| 1.6.1 | | 0.19.3 | 0.6.30 | 16.416369038 | 73.552713013 | 98.009949908 | 325.177503251 | N/A |
| 1.6.1 | | 0.19.3 | 0.6.28 | 15.907693813 | 73.859609986 | 99.65385215 | 320.406822231 | N/A |
| 1.6.1 | | 0.19.3 | 0.6.25 | 16.155023105 | 78.113216025 | 91.386990818 | 323.182170592 | N/A |
| 1.6.1 | | 0.19.3 | 0.6.20 | 16.18978755 | 76.922928732 | 92.863883057 | 329.472662078 | N/A |
| 1.6.1 | | 0.19.3 | 0.6.17 | 16.674484174 | 45.536406584 | 91.975029367 | 324.107966298 | N/A |
| 1.6.1 | | 0.19.3 | 0.6.15 | 16.647893857 | 77.246173583 | 92.547574565 | 325.843565657 | N/A |
| 1.6.5 | | 0.19.3 | 0.6.32 | 16.507995835 | 75.232973774 | 102.738249637 | 320.804476864 | N/A |
| 1.6.5 | | 0.19.3 | 0.6.30 | 16.442945162 | 75.552669923 | 100.079247493 | 315.767280417 | N/A |
| 1.6.5 | | 0.19.3 | 0.6.28 | 15.696215339 | 41.383477294 | 93.911315279 | 329.02145826 | N/A |
| 1.6.5 | -O1 | 0.19.3 | 0.6.27 | 9.442187946 | 13.78550593 | 27.309818329 | 61.662899872 | |
| 1.6.5 | -O1 | 0.19.3 | 0.6.26 | 9.525658865 | 15.737620623 | 37.788878332 | 154.77089758 | N/A |
| 1.6.5 | | 0.19.3 | 0.6.25 | 16.103407208 | 78.827783495 | 92.628645679 | 323.749846052 | N/A |
| 1.6.5 | | 0.19.3 | 0.6.20 | 16.081257806 | 45.193458428 | 91.298156568 | 324.887103193 | N/A |
| 1.6.5 | | 0.19.3 | 0.6.17 | 16.661767166 | 117.089430786 | 91.768988429 | 322.823669619 | N/A |
| 1.6.5 | | 0.19.3 | 0.6.15 | 16.698116787 | 99.878925099 | 91.444128363 | 322.810816215 | N/A |
| 1.6.5 | -O1 | 0.19.3 | 0.6.33 | 9.949009976 | 14.697673851 | 29.464735085 | 66.522595927 | 132.225028216 |
| 1.6.5 | -O1 | 0.19.3 | 0.6.33 | 9.952505584 | 14.642345133 | 29.487756023 | 66.405126925 | 132.858871417 |
| 1.6.5 | -O1 | 0.19.3 | 0.6.32 | 10.213924716 | 14.90292511 | 29.43141585 | 65.774123029 | 132.584280038 |
| 1.6.5 | -O1 | 0.19.3 | 0.6.30 | 9.952360792 | 14.604632874 | 29.263546539 | 65.256550668 | 131.766387963 |
| 1.6.5 | -O1 | 0.19.3 | 0.6.28 | 9.61100625 | 14.16199726 | 28.106580684 | 61.752193716 | 128.178242579 |
| 1.6.5 | | 0.19.3 | #master | 16.492658709 | 76.261172697 | 99.347007007 | 318.410333441 | |
| 1.6.5 | | 0.19.3 | mcabbott:opt_level | 14.510376473 | 15.581238372 | 30.034043196 | 67.305733329 | |
| 1.8.2 | | 0.21.13 | 0.6.49 (#66cc60) | 16.799219872 | 48.980447826 | 182.013203064 | 1359.453378186 | N/A |
| 1.8.2 | | 0.21.13 | 0.6.49 (#ee8945) | 32.313442694 | 104.451519939 | 284.157696506 | N/A | N/A |
Numeric column headers give the number of `~` statements; the entries are first-call times in seconds, as measured with `@elapsed` in the scripts below.
This is on the following system:
```
julia> versioninfo()
Julia Version 1.6.1
Commit 6aaedecc44 (2021-04-23 05:59 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Core(TM) i7-6850K CPU @ 3.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-11.0.1 (ORCJIT, broadwell)
```
The sad news is that I can't really glean anything useful from the above :confused: Here it seems to be "fine" (taking 5 mins to compile is not good, but it's much better than the reported numbers).
Comments:
- When I tried running this benchmark on my laptop (which has 32G memory), Zygote versions earlier than 0.6.17 blew up (on Julia 1.6.3), i.e. it seems as if the memory usage of Zygote changed significantly from 0.6.17 to 0.6.18?
- The version columns that are missing a patch version, e.g. 0.19, are using the most recent version of Zygote that is compatible with the corresponding Turing version. I can find these out if need be (the numbers are from an earlier version of the script I'm using).
[2022-01-10 Mon 00:36]: I'm now running the same benchmarks on Julia 1.6.5 just to make sure it has nothing to do with weird interactions between the particular Julia version and Zygote.
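Reading the -O1 rows of the table against the default-flag rows (again just arithmetic on the reported numbers): at 15 `~` statements on Julia 1.6.5 with Turing 0.19.3, running with `-O1` cuts the time by roughly a factor of five for each Zygote version that was measured both ways.

```julia
# 15-tilde timings (seconds) from the table: Julia 1.6.5, Turing 0.19.3.
# Keys are Zygote versions; values are (default flags, -O1).
timings = Dict(
    "0.6.32" => (320.804476864, 65.774123029),
    "0.6.30" => (315.767280417, 65.256550668),
    "0.6.28" => (329.02145826, 61.752193716),
)

# Speed-up from -O1 for each Zygote version.
speedups = Dict(v => t_default / t_o1 for (v, (t_default, t_o1)) in timings)

for v in sort(collect(keys(speedups)))
    println("Zygote ", v, ": -O1 is ", round(speedups[v]; digits=1), "x faster at 15 tildes")
end
```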
Script I'm running
```julia
using Pkg; Pkg.activate(mktempdir())

TURING_VERSION = ENV["TURING_VERSION"]
ZYGOTE_VERSION = ENV["ZYGOTE_VERSION"]
@info "Trying to install Turing@$(TURING_VERSION) and Zygote@$(ZYGOTE_VERSION)"
Pkg.add(name="Turing", version=TURING_VERSION)
Pkg.add(name="Zygote", version=ZYGOTE_VERSION)

using Turing, Zygote

# Assumes a format-1.0 Manifest.toml, where packages are listed at the top level.
pkgversion(mod) = Pkg.TOML.parsefile(joinpath(dirname(Pkg.project().path), "Manifest.toml"))[string(mod)][1]["version"]
@info "Installed Turing@$(pkgversion(Turing)) and Zygote@$(pkgversion(Zygote))"

Turing.setadbackend(:zygote);
adbackend = Turing.Core.ZygoteAD();
spl = DynamicPPL.SampleFromPrior();

results = []

num_tildes = 1
@info "Running" num_tildes
Zygote.refresh()
expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);
f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);
t = @elapsed Turing.Core.gradient_logp(
    adbackend,
    vi[spl],
    vi,
    model,
    spl
);
push!(results, t)
@info "Result" t

num_tildes = 5
@info "Running" num_tildes
Zygote.refresh()
expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);
f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);
t = @elapsed Turing.Core.gradient_logp(
    adbackend,
    vi[spl],
    vi,
    model,
    spl
);
push!(results, t)
@info "Result" t

num_tildes = 10
@info "Running" num_tildes
Zygote.refresh()
expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);
f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);
t = @elapsed Turing.Core.gradient_logp(
    adbackend,
    vi[spl],
    vi,
    model,
    spl
);
push!(results, t)
@info "Result" t

num_tildes = 15
@info "Running" num_tildes
Zygote.refresh()
expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);
f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);
t = @elapsed Turing.Core.gradient_logp(
    adbackend,
    vi[spl],
    vi,
    model,
    spl
);
push!(results, t)
@info "Result" t

println(join(results, " | "))

if haskey(ENV, "OUTPUT_FILE")
    open(ENV["OUTPUT_FILE"], "a") do io
        write(io, "| ", join([VERSION; pkgversion(Turing); pkgversion(Zygote); results;], " | "), " |")
        write(io, "\n")
    end
end
```
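The script above (like the demo at the top of the issue) generates each benchmark model by splicing `~` statements into an empty function definition before handing it to `DynamicPPL.model`. The following Base-only sketch reproduces just that expression-building step so the generated source can be inspected; the helper name `build_demo_expr` is hypothetical, and nothing here calls DynamicPPL, Turing or Zygote.

```julia
# Hypothetical helper mirroring the expression-building lines of the script:
# build `function demoN() ... end` whose body holds N statements `xj ~ Normal()`.
# Nothing is evaluated, so `~` and `Normal` are plain symbols and no packages are needed.
function build_demo_expr(num_tildes::Int)
    expr = Base.remove_linenums!(:(function $(Symbol(:demo, num_tildes))() end))
    mainbody = last(expr.args)  # the (initially empty) body block of the function
    append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j in 1:num_tildes])
    return expr
end

ex = build_demo_expr(3)
println(ex)  # prints the generated function definition
```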
Script for Turing >= 0.21
```julia
using Pkg

if any(Base.Fix1(haskey, ENV), ["TURING_VERSION", "ZYGOTE_VERSION"])
    # In this case, we create a new env and install the corresponding package versions.
    Pkg.activate(mktempdir())
    if haskey(ENV, "TURING_VERSION")
        TURING_VERSION = ENV["TURING_VERSION"]
        @info "Trying to install Turing@$(TURING_VERSION)"
        Pkg.add(name="Turing", version=TURING_VERSION)
    end
    if haskey(ENV, "ZYGOTE_VERSION")
        ZYGOTE_VERSION = ENV["ZYGOTE_VERSION"]
        @info "Trying to install Zygote@$(ZYGOTE_VERSION)"
        Pkg.add(name="Zygote", version=ZYGOTE_VERSION)
    end
end

using Turing, Zygote
using Turing: LogDensityProblems

# Format-1.0 manifests list packages at the top level, while format-2.0 manifests
# nest them under "deps"; checking the file itself handles both, whatever the Julia version.
function pkginfo(mod)
    manifest = Pkg.TOML.parsefile(joinpath(dirname(Pkg.project().path), "Manifest.toml"))
    return get(manifest, "deps", manifest)[string(mod)][1]
end
pkgversion(mod) = pkginfo(mod)["version"]
pkghash(mod) = pkginfo(mod)["git-tree-sha1"]
@info "Installed Turing@$(pkgversion(Turing)) [#$(pkghash(Turing))] and Zygote@$(pkgversion(Zygote)) [#$(pkghash(Zygote))]"

Turing.setadbackend(:zygote);
adbackend = Turing.Core.ZygoteAD();
spl = DynamicPPL.SampleFromPrior();

results = []

num_tildes = 1
@info "Running" num_tildes
Zygote.refresh()
expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);
f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);
ℓ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext());
ℓ_with_grad = LogDensityProblems.ADgradient(:Zygote, ℓ);
t = @elapsed LogDensityProblems.logdensity_and_gradient(ℓ_with_grad, vi[spl]);
push!(results, t)
@info "Result" t

num_tildes = 5
@info "Running" num_tildes
Zygote.refresh()
expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);
f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);
ℓ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext());
ℓ_with_grad = LogDensityProblems.ADgradient(:Zygote, ℓ);
t = @elapsed LogDensityProblems.logdensity_and_gradient(ℓ_with_grad, vi[spl]);
push!(results, t)
@info "Result" t

num_tildes = 10
@info "Running" num_tildes
Zygote.refresh()
expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);
f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);
ℓ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext());
ℓ_with_grad = LogDensityProblems.ADgradient(:Zygote, ℓ);
t = @elapsed LogDensityProblems.logdensity_and_gradient(ℓ_with_grad, vi[spl]);
push!(results, t)
@info "Result" t

num_tildes = 15
@info "Running" num_tildes
Zygote.refresh()
expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);
f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);
ℓ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext());
ℓ_with_grad = LogDensityProblems.ADgradient(:Zygote, ℓ);
t = @elapsed LogDensityProblems.logdensity_and_gradient(ℓ_with_grad, vi[spl]);
push!(results, t)
@info "Result" t

println(join(results, " | "))

if haskey(ENV, "OUTPUT_FILE")
    open(ENV["OUTPUT_FILE"], "a") do io
        write(io, "| ", join([VERSION; pkgversion(Turing); pkgversion(Zygote); results;], " | "), " |")
        write(io, "\n")
    end
end
```
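Both scripts read `TURING_VERSION`, `ZYGOTE_VERSION` and (optionally) `OUTPUT_FILE` from the environment, so a table like the one above can be filled by launching one Julia process per version combination. A minimal sketch of such a driver using Base's `addenv` (available since Julia 1.6); the script path, output file name and the particular version list are illustrative assumptions, not what was actually used.

```julia
# Hypothetical driver: build one `julia` invocation per (Turing, Zygote, flags)
# combination. The benchmark scripts read TURING_VERSION, ZYGOTE_VERSION and
# OUTPUT_FILE from the environment; the file names below are assumptions.
script = "zygote_compile_bench.jl"   # hypothetical path to one of the scripts above
output = "results.md"                # the script appends one table row per run here

combos = [
    ("0.19.3", "0.6.32", String[]),
    ("0.19.3", "0.6.32", ["-O1"]),
    ("0.19.3", "0.6.28", String[]),
    ("0.19.3", "0.6.28", ["-O1"]),
]

# Interpolating an empty flag vector into a command adds no arguments.
cmds = [addenv(`julia $flags $script`,
               "TURING_VERSION" => turing,
               "ZYGOTE_VERSION" => zygote,
               "OUTPUT_FILE" => output)
        for (turing, zygote, flags) in combos]

# To actually run the benchmarks (each run installs packages into a fresh temp env):
# foreach(run, cmds)
```

Note that the scripts do not write the command-line flags into their output rows, so `-O1` runs would have to be labelled separately when assembling the table.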
Edits:
- [2022-01-10 Mon 00:44] Added some preliminary results for Julia 1.6.5. Seems like nothing has changed.
- [2022-01-10 Mon 12:13] Added a preliminary result using the `-O1` optimization flag. Insane difference.
- [2022-01-10 Mon 13:23] Added more benchmarks with `julia-1.6.3 -O1`. It seems like something happened between [email protected] and [email protected] because the memory usage on [email protected] (still running) is insane (currently using all of the 64G RAM available).
- [2022-01-10 Mon 21:45] Added benchmarks for `Zygote#master` and `Zygote#mcabbott:opt_level`. Looks like `mcabbott:opt_level` does the trick.
- [2022-11-11 Fri 13:06] Added benchmarks for Julia 1.8.2, [email protected] and [email protected] + Zygote from https://github.com/FluxML/Zygote.jl/pull/1195
Thanks @torfjelde
Is there any chance this stuff gets better on 1.7 with the latest Zygote? I know Turing doesn't support it yet, but does DynamicPPL let you test on 1.7?
> Is there any chance this stuff gets better on 1.7 with the latest Zygote? I know Turing doesn't support it yet, but does DynamicPPL let you test on 1.7?
I can but it will take a bit of work (currently using Turing functionality to compute the gradient).
Do you know when you started experiencing these issues btw?
Also, maybe some of the Zygote people have an idea what's going on here, @mcabbott? TL;DR: compilation time of `Zygote.gradient` blows up w.r.t. the number of `~` in a Turing model.
Possibly related: https://github.com/FluxML/Zygote.jl/issues/1119 and https://github.com/FluxML/Zygote.jl/issues/1126
EDIT: Seems like it. -O1 helps a lot.
Wow, the -O1 really helps. See the results in the comment above.
EDIT: Even n=20 only results in ~2min of compilation.
@torfjelde Am I reading that correctly that Julia 1.6.5 + Turing 0.19.3 + Zygote 0.6.33 brings it back to sanity?
> Am I reading that correctly that Julia 1.6.5 + Turing 0.19.3 + Zygote 0.6.33 brings it back to sanity?

No, specifically you need the `-O1` optimization flag (the default is `-O2`). It seems as if Zygote + Julia 1.6 leads to some insane compilation times when the default optimizations are used.
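Because the workaround is a process-wide command-line flag, it can be worth having a benchmark check which level it is actually running under. A small sketch using `Base.JLOptions()`, an internal (not formally documented) Base struct whose `opt_level` field holds the value selected with `-O`; treat it as a convenience rather than a stable API.

```julia
# Optimization level of the *current* Julia process: `julia -O1 script.jl`
# reports 1, while plain `julia` reports the default level (2).
opt_level = Base.JLOptions().opt_level
println("running with -O", opt_level)

# A guard one might put at the top of a benchmark script (illustrative only):
if opt_level != 1
    @warn "Not running with -O1; Zygote compile times may be much longer" opt_level
end
```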
See if https://github.com/FluxML/Zygote.jl/pull/1147 works as well as -O1 for this purpose. If you have any other benchmarks of runtime performance (i.e. @btime) it would also be interesting to see if those get worse.
I can confirm that -O1 helps with our original problem by significantly decreasing the compilation time. Previously we have to wait for around 30-40min, and now it takes 2-3 Chopin preludes -- around 6min -- to compile. I would say this is similar to what I had in Julia 1.5.
Now the question is whether -O1 generates much less efficient code than -O3.
EDIT: preliminary experiments seem to show similar run times.
> See if FluxML/Zygote.jl#1147 works as well as `-O1` for this purpose. If you have any other benchmarks of runtime performance (i.e. `@btime`) it would also be interesting to see if those get worse.

Will give it a go :+1:
Btw, not sure if this is more useful information, but when I try [email protected] even with `-O1` the runtime blows up again (the nearest more recent version I tried which had reasonable compile time was 0.6.28).
I've never heard of Chopin preludes being used as a unit of measurement, but people should do that more often :)
> Now the question is whether `-O1` generates much less efficient code than `-O3`.

Betting on no, as most performance-sensitive stuff is gated behind a rule. Maybe scalar- or control flow-heavy code, though the generated pullbacks for the latter are likely type unstable anyhow.
> but when I try [email protected] even with `-O1` the runtime blows up again
Fixed by https://github.com/FluxML/Zygote.jl/pull/909 perhaps? That was in 0.6.27.
> See if FluxML/Zygote.jl#1147 works as well as `-O1` for this purpose. If you have any other benchmarks of runtime performance (i.e. `@btime`) it would also be interesting to see if those get worse.
Gave it a try; seems to do the trick! Benchmarks in table above.
Don't know what effect it has on performance though, but seems like it would be worth it.
> Fixed by FluxML/Zygote.jl#909 perhaps? That was in 0.6.27.
Ah, probably! I'll give it a go.
EDIT: Seems like indeed 0.6.27 improved things significantly :+1:
Any progress on this issue by chance? Did a check with Julia 1.7 and the latest Turing, DynamicPPL, and Zygote and am still getting > 30 minute TTFG for my model with 20ish parameters. -O1 makes things more reasonable, as it did before.
Just want to make sure that everyone knows those two Zygote issues linked did not fix things.
@ToucheSir @Keno Any progress on this?
Unfortunately I have nothing concrete to report, but I have been looking into this over the past couple of months. Any help on grokking compilation latency + Zygote's internals would be much appreciated. I can't speak for Keno, but to my knowledge working on this is not on anyone else's plate.
@torfjelde I'm not able to repro your latest timings on 1.8.2 locally with the following reduced MWE:
```julia
using Turing, Zygote
using Turing: LogDensityProblems
using SnoopCompileCore

# This helps a bit, ~4s
# @eval Turing.DynamicPPL begin
#     ChainRulesCore.@non_differentiable is_flagged(::VarInfo, ::VarName, ::String)
# end

Turing.setadbackend(:zygote);
adbackend = Turing.Core.ZygoteAD();
spl = DynamicPPL.SampleFromPrior();

num_tildes = 5
# num_tildes = 10
expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);
f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);

@info "starting eval"
ℓ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext());
ℓ_with_grad = LogDensityProblems.ADgradient(:Zygote, ℓ);
tinf = @snoopi_deep LogDensityProblems.logdensity_and_gradient(ℓ_with_grad, vi[spl]);
@info "done eval"

using SnoopCompile, ProfileView
@show tinf
```
Results:
```
# v0.6.49
tinf = InferenceTimingNode: 38.138395/69.026222 on Core.Compiler.Timings.ROOT() with 388 direct children
# https://github.com/FluxML/Zygote.jl/pull/1195
tinf = InferenceTimingNode: 37.486037/65.298454 on Core.Compiler.Timings.ROOT() with 423 direct children
```
Versioninfo:
```
Julia Version 1.8.2
Commit 36034abf260 (2022-09-29 15:21 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 8 × Intel(R) Core(TM) i7-4790K CPU @ 4.00GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, haswell)
  Threads: 1 on 8 virtual cores
```
I had a crack at figuring out why inference times also scale so badly. All outputs of that exploration may be found in this gist.
It turns out that we can capture a good chunk of the slowness and poor scaling with just `pullback(model.f, ...)`. SnoopCompile didn't feel very helpful because it just reports high exclusive inference time for that function, but digging into the IR with SnoopCompile's Cthulhu integration did turn up something interesting.
To demonstrate, we should first examine `@code_warntype` output for the un-transformed function. DynamicPPL does generate a decent amount of code, but the compiler should be able to manage 600ish statements and 100ish slots[^1] without too much trouble.
Now let's see the output for the augmented primal function from `pullback`. Yes, you read that correctly. Zygote generates a function with over 4,300 statements and over 19,000 slots! I have no idea why both numbers are so high, but I suspect there is something up with the IRTools IR -> Julia IR translation that happens in Zygote[^2] (Edit: `IRTools.slots!` is the cause of the blow-up. See the output of each IRTools pass Zygote uses for the gory details). Given the sheer amount of code, I'm not surprised that LLVM times seem to be rather horrendous as well.
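For anyone wanting to reproduce this kind of count on their own functions, plain Base Julia can report it: the number of statements is the length of a `CodeInfo`'s `code` vector and the number of slots is the length of its `slotnames`. A sketch on a toy function (the model function and Zygote's generated primal are not reproduced here); `optimize=false` is used because unoptimized typed IR is roughly what `@code_warntype` displays. `CodeInfo` fields are compiler internals, so this is a debugging aid rather than a stable API.

```julia
# Toy stand-in for a model body; any method works the same way.
function demo_fn(x)
    y = 2x
    z = y + 1
    return z * x
end

# Lowered IR (before inference) and unoptimized typed IR for one signature.
lowered = only(code_lowered(demo_fn, (Float64,)))
typed, rettype = only(code_typed(demo_fn, (Float64,); optimize=false))

println("lowered: ", length(lowered.code), " statements, ", length(lowered.slotnames), " slots")
println("typed:   ", length(typed.code), " statements, ", length(typed.slotnames), " slots")
println("return type: ", rettype)
```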
So what is to be done? The first thing that comes to mind is to figure out ~~why this is happening~~ how to fix the slot explosion on the IRTools side ~~, and how much Zygote's own passes might be responsible~~. That however would require someone quite familiar with IRTools internals. The more ambitious plan would be to ditch IRTools in Zygote completely and use CodeInfo/IRCode like other libraries (including TuringLang's own LibTask) have done. Having given this a try some months back, I think the biggest missing piece is someone who understands Zygote's pre-IRTools AD transform (or implementing AD transforms on SSA IR in general) well enough to do the port. It's not clear how a lot of the logic in Zygote now would look if passes went from working on block args to Phi nodes, for example.
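To make the "block args vs. Phi nodes" point concrete: Julia's native SSA IR merges values arriving from different predecessor blocks with `Core.PhiNode` statements, whereas IRTools expresses the same thing by giving basic blocks arguments. A small Base-only sketch that finds the phi nodes Julia's optimizer produces for a simple loop (the exact IR printed varies between Julia versions).

```julia
# A loop whose carried values (`s` and `i`) must be merged at the loop header.
function sumto(n)
    s = 0
    for i in 1:n
        s += i
    end
    return s
end

# Optimized typed IR (optimize=true is the default for code_typed).
ci, _ = only(code_typed(sumto, (Int,)))
phis = filter(stmt -> stmt isa Core.PhiNode, ci.code)

println(length(phis), " phi nodes in optimized IR for sumto")
foreach(println, phis)
```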
[^1]: I count slots because someone mentioned before that they can consume quite a bit of memory during compilation. So having too many of them in a function may be the culprit behind the memory blowup documented here and in https://github.com/TuringLang/Turing.jl/issues/1754#issuecomment-1008460319.
[^2]: Zygote used to operate on native Julia IR, but switched to IRTools a couple of years back. This means every function goes through a native IR (really CodeInfo) -> IRTools IR -> AD transform -> native IR pipeline.
Diffractor will hopefully fix these issues, right?
Last I asked there were no plans to support setfield! (on any type) or setindex! on Dicts/RefValues, both of which Turing seems to need. So I'm not terribly optimistic...
Briefly going to comment on this to say--the solution to this issue is to use ReverseDiff or ForwardDiff.jl or (a few years down the line when it's mature) maybe some other autodiff solution like Enzyme.jl. Development on Zygote/IRTools and source-to-source AD in Julia (rather than LLVM) is effectively dead now.
To be honest, ReverseDiff has other issues and ForwardDiff is not always an option. Actually, Zygote is developed much more actively (https://github.com/FluxML/Zygote.jl/commits/master) than ReverseDiff (https://github.com/JuliaDiff/ReverseDiff.jl/commits/master) or ForwardDiff (https://github.com/JuliaDiff/ForwardDiff.jl/commits/master). There hasn't been any new ForwardDiff release from master since the breaking change, which downstream packages (correctly, IMO) rejected in a non-breaking release, was reapplied to the master branch; nobody wants to deal with, and possibly fix, the downstream issues that still exist, so nobody is willing to tag a new release from the ForwardDiff master branch.