
Zygote's compilation scales badly with the number of `~` statements

Open torfjelde opened this issue 3 years ago • 33 comments

I'm not certain whether this can be considered a "Turing.jl issue", but I figured I would at least raise it here so people are aware.

The compilation time of Zygote scales quite badly with the number of ~ statements.

TL;DR: it takes almost 5 minutes to compile a model with 14 ~ statements. I don't have the result here, but at some point I tried one with 20 ~ statements, and it took a full ~23 mins to compile.

Demo

using Turing, Zygote
Turing.setadbackend(:zygote);
adbackend = Turing.Core.ZygoteAD();
spl = DynamicPPL.SampleFromPrior();
results = []
num_tildes = 0

Running the following snippet a few times gives a sense of the compilation times:

num_tildes += 1

Zygote.refresh()

expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);

f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);

t = @elapsed Turing.Core.gradient_logp(
   adbackend,
   vi[spl],
   vi,
   model,
   spl
);
push!(results, t)
results
14-element Vector{Any}:
  15.917412882
   9.600213651
  14.911253238
  24.699050206
  71.811476745
  49.158314248
  46.059394601
  57.44494627
  75.514564551
  94.927956369
 134.165383535
 156.079943416
 202.745837585
 273.93479382

That is, it takes almost 5 minutes to compile a model with 14 ~ statements. I don't have the result here, but at some point I tried one with 20 ~ statements, and it took a full ~23 mins to compile.

Additional info

versioninfo()
Julia Version 1.6.2
Commit 1b93d53fc4 (2021-07-14 15:36 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Core(TM) i7-10710U CPU @ 1.10GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-11.0.1 (ORCJIT, skylake)
Pkg.status()
    Status `/tmp/jl_PHYdgF/Project.toml`
[fce5fe82] Turing v0.19.2
[e88e6eb3] Zygote v0.6.32

torfjelde avatar Dec 21 '21 22:12 torfjelde

Thanks a lot for posting this!

wupeifan avatar Dec 21 '21 22:12 wupeifan

Do we have code that checks the compile times of models? How long has Zygote compilation been taking this long?

ParadaCarleton avatar Dec 25 '21 21:12 ParadaCarleton

How long has Zygote compilation been taking this long?

Since Julia 1.6 afaik.

wupeifan avatar Dec 25 '21 22:12 wupeifan

How long has Zygote compilation been taking this long?

Since Julia 1.6 afaik.

As in, the 1.6 update caused it? Is the compilation faster on 1.5?

ParadaCarleton avatar Dec 26 '21 17:12 ParadaCarleton

@torfjelde can probably get the right answer, if it matters. There are a lot of packages and versions interacting, so I'm not sure it necessarily matches a particular Julia version. It might take just as long to find the exact combination of Julia, Zygote, ChainRules, Turing, DynamicPPL, etc. versions that caused it as it would to actually solve the problem.

jlperla avatar Dec 26 '21 17:12 jlperla

As in, the 1.6 update caused it? Is the compilation faster on 1.5?

I don't think so, but that's only from my own experience rather than any benchmarking/profiling (which we need!), so I don't have a good answer to this question.

wupeifan avatar Dec 26 '21 18:12 wupeifan

Sorry to bump so soon @torfjelde, but do we have any sense of a timeline on this? Are we thinking a week, a month, etc.? We would help out, but I fear this is a little too tightly connected to the PPL macros for us to contribute to.

jlperla avatar Jan 07 '22 19:01 jlperla

I'm sorry, I can't put a timeline on this right now.

And I worry that it's not particularly related to Turing, unfortunately :confused: There was one change we made in Turing that I worried could have caused it, but when I tried with an older version, the performance was still horrible.

The best way to identify what's wrong is to just run the test above for different versions of Turing and Zygote.

torfjelde avatar Jan 07 '22 23:01 torfjelde

As of right now I have the following results:

| Julia | Args | Turing | Zygote | 1 | 5 | 10 | 15 | 20 |
|---|---|---|---|---|---|---|---|---|
| 1.6.1 | | 0.19 | 0.6 | 16.387332656 | 89.504752487 | 101.654405869 | 320.501781565 | N/A |
| 1.6.1 | | 0.18 | 0.6 | 15.740250148 | 72.75134579 | 92.373805461 | 310.690419646 | N/A |
| 1.6.1 | | 0.16 | 0.6 | 14.098175146 | 77.075566728 | 86.093118485 | 310.681374129 | N/A |
| 1.6.1 | | 0.15 | 0.6 | 13.674902932 | 75.142343745 | 83.177268954 | 309.479831351 | N/A |
| 1.6.1 | | 0.19.3 | 0.6.32 | 16.519126809 | 76.212964076 | 102.056456525 | 331.835923261 | N/A |
| 1.6.1 | | 0.19.3 | 0.6.30 | 16.416369038 | 73.552713013 | 98.009949908 | 325.177503251 | N/A |
| 1.6.1 | | 0.19.3 | 0.6.28 | 15.907693813 | 73.859609986 | 99.65385215 | 320.406822231 | N/A |
| 1.6.1 | | 0.19.3 | 0.6.25 | 16.155023105 | 78.113216025 | 91.386990818 | 323.182170592 | N/A |
| 1.6.1 | | 0.19.3 | 0.6.20 | 16.18978755 | 76.922928732 | 92.863883057 | 329.472662078 | N/A |
| 1.6.1 | | 0.19.3 | 0.6.17 | 16.674484174 | 45.536406584 | 91.975029367 | 324.107966298 | N/A |
| 1.6.1 | | 0.19.3 | 0.6.15 | 16.647893857 | 77.246173583 | 92.547574565 | 325.843565657 | N/A |
| 1.6.5 | | 0.19.3 | 0.6.32 | 16.507995835 | 75.232973774 | 102.738249637 | 320.804476864 | N/A |
| 1.6.5 | | 0.19.3 | 0.6.30 | 16.442945162 | 75.552669923 | 100.079247493 | 315.767280417 | N/A |
| 1.6.5 | | 0.19.3 | 0.6.28 | 15.696215339 | 41.383477294 | 93.911315279 | 329.02145826 | N/A |
| 1.6.5 | -O1 | 0.19.3 | 0.6.27 | 9.442187946 | 13.78550593 | 27.309818329 | 61.662899872 | |
| 1.6.5 | -O1 | 0.19.3 | 0.6.26 | 9.525658865 | 15.737620623 | 37.788878332 | 154.77089758 | N/A |
| 1.6.5 | | 0.19.3 | 0.6.25 | 16.103407208 | 78.827783495 | 92.628645679 | 323.749846052 | N/A |
| 1.6.5 | | 0.19.3 | 0.6.20 | 16.081257806 | 45.193458428 | 91.298156568 | 324.887103193 | N/A |
| 1.6.5 | | 0.19.3 | 0.6.17 | 16.661767166 | 117.089430786 | 91.768988429 | 322.823669619 | N/A |
| 1.6.5 | | 0.19.3 | 0.6.15 | 16.698116787 | 99.878925099 | 91.444128363 | 322.810816215 | N/A |
| 1.6.5 | -O1 | 0.19.3 | 0.6.33 | 9.949009976 | 14.697673851 | 29.464735085 | 66.522595927 | 132.225028216 |
| 1.6.5 | -O1 | 0.19.3 | 0.6.33 | 9.952505584 | 14.642345133 | 29.487756023 | 66.405126925 | 132.858871417 |
| 1.6.5 | -O1 | 0.19.3 | 0.6.32 | 10.213924716 | 14.90292511 | 29.43141585 | 65.774123029 | 132.584280038 |
| 1.6.5 | -O1 | 0.19.3 | 0.6.30 | 9.952360792 | 14.604632874 | 29.263546539 | 65.256550668 | 131.766387963 |
| 1.6.5 | -O1 | 0.19.3 | 0.6.28 | 9.61100625 | 14.16199726 | 28.106580684 | 61.752193716 | 128.178242579 |
| 1.6.5 | | 0.19.3 | #master | 16.492658709 | 76.261172697 | 99.347007007 | 318.410333441 | |
| 1.6.5 | | 0.19.3 | mcabbot:opt_level | 14.510376473 | 15.581238372 | 30.034043196 | 67.305733329 | |
| 1.8.2 | | 0.21.13 | 0.6.49 (#66cc60) | 16.799219872 | 48.980447826 | 182.013203064 | 1359.453378186 | N/A |
| 1.8.2 | | 0.21.13 | 0.6.49 (#ee8945) | 32.313442694 | 104.451519939 | 284.157696506 | N/A | N/A |

Columns with numeric headers (1, 5, 10, 15, 20) represent the number of ~ statements.

This is on the following system:

julia> versioninfo()
Julia Version 1.6.1
Commit 6aaedecc44 (2021-04-23 05:59 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Core(TM) i7-6850K CPU @ 3.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-11.0.1 (ORCJIT, broadwell)

The sad news is that I can't really glean anything useful from the above :confused: Here it seems to be "fine" (taking 5mins to compile is not good, but it's much better than the reported numbers).

Comments:

  • When I tried running this benchmark on my laptop (which has 32G memory), Zygote earlier than 0.6.17 blew up (on Julia 1.6.3), i.e. it seems as if the memory-usage of Zygote changed significantly from 0.6.17 to 0.6.18?
  • The rows whose version entry is missing the patch version, e.g. 0.19, use the most recent Zygote version compatible with that Turing version. I can find these out if need be (those numbers are from an earlier version of the script I'm using).

[2022-01-10 Mon 00:36]: I'm now running the same benchmarks on Julia 1.6.5 just to make sure it has nothing to do with weird interactions between the particular Julia version and Zygote.

Script I'm running
using Pkg; Pkg.activate(mktempdir())
TURING_VERSION = ENV["TURING_VERSION"]
ZYGOTE_VERSION = ENV["ZYGOTE_VERSION"]

@info "Trying to install Turing@$(TURING_VERSION) and Zygote@$(ZYGOTE_VERSION)"
Pkg.add(name="Turing", version=TURING_VERSION)
Pkg.add(name="Zygote", version=ZYGOTE_VERSION)

using Turing, Zygote

pkgversion(mod) = Pkg.TOML.parsefile(joinpath(dirname(Pkg.project().path), "Manifest.toml"))[string(mod)][1]["version"]
@info "Installed Turing@$(pkgversion(Turing)) and Zygote@$(pkgversion(Zygote))"

Turing.setadbackend(:zygote);
adbackend = Turing.Core.ZygoteAD();
spl = DynamicPPL.SampleFromPrior();

results = []

num_tildes = 1
@info "Running" num_tildes
Zygote.refresh()

expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);

f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);

t = @elapsed Turing.Core.gradient_logp(
   adbackend,
   vi[spl],
   vi,
   model,
   spl
);
push!(results, t)
@info "Result" t

num_tildes = 5
@info "Running" num_tildes
Zygote.refresh()

expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);

f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);

t = @elapsed Turing.Core.gradient_logp(
   adbackend,
   vi[spl],
   vi,
   model,
   spl
);
push!(results, t)
@info "Result" t

num_tildes = 10
@info "Running" num_tildes
Zygote.refresh()

expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);

f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);

t = @elapsed Turing.Core.gradient_logp(
   adbackend,
   vi[spl],
   vi,
   model,
   spl
);
push!(results, t)
@info "Result" t

num_tildes = 15
@info "Running" num_tildes
Zygote.refresh()

expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);

f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);

t = @elapsed Turing.Core.gradient_logp(
   adbackend,
   vi[spl],
   vi,
   model,
   spl
);
push!(results, t)
@info "Result" t

println(join(results, " | "))

if haskey(ENV, "OUTPUT_FILE")
    open(ENV["OUTPUT_FILE"], "a") do io
        write(io, "| ", join([VERSION; pkgversion(Turing); pkgversion(Zygote); results;], " | "), " |")
        write(io, "\n")
    end
end
Script for Turing >= 0.21
using Pkg

if any(Base.Fix1(haskey, ENV), ["TURING_VERSION", "ZYGOTE_VERSION"])
    # In this case, we create a new env and install the corresponding package versions.
    Pkg.activate(mktempdir())

    if haskey(ENV, "TURING_VERSION")
        TURING_VERSION = ENV["TURING_VERSION"]
        @info "Trying to install Turing@$(TURING_VERSION)"
        Pkg.add(name="Turing", version=TURING_VERSION)
    end

    if haskey(ENV, "ZYGOTE_VERSION")
        ZYGOTE_VERSION = ENV["ZYGOTE_VERSION"]
        @info "Trying to install Zygote@$(ZYGOTE_VERSION)"
        Pkg.add(name="Zygote", version=ZYGOTE_VERSION)
    end
end

using Turing, Zygote
using Turing: LogDensityProblems


if VERSION < v"1.6.2"
    pkginfo(mod) = Pkg.TOML.parsefile(joinpath(dirname(Pkg.project().path), "Manifest.toml"))[string(mod)][1]
else
    pkginfo(mod) = Pkg.TOML.parsefile(joinpath(dirname(Pkg.project().path), "Manifest.toml"))["deps"][string(mod)][1]
end

pkgversion(mod) = pkginfo(mod)["version"]
pkghash(mod) = pkginfo(mod)["git-tree-sha1"]

@info "Installed Turing@$(pkgversion(Turing)) [#$(pkghash(Turing))] and Zygote@$(pkgversion(Zygote)) [#$(pkghash(Zygote))]"

Turing.setadbackend(:zygote);
adbackend = Turing.Core.ZygoteAD();
spl = DynamicPPL.SampleFromPrior();

results = []

num_tildes = 1
@info "Running" num_tildes
Zygote.refresh()

expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);

f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);

ℓ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext());
ℓ_with_grad = LogDensityProblems.ADgradient(:Zygote, ℓ);
t = @elapsed LogDensityProblems.logdensity_and_gradient(ℓ_with_grad, vi[spl]);
push!(results, t)
@info "Result" t

num_tildes = 5
@info "Running" num_tildes
Zygote.refresh()

expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);

f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);

ℓ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext());
ℓ_with_grad = LogDensityProblems.ADgradient(:Zygote, ℓ);
t = @elapsed LogDensityProblems.logdensity_and_gradient(ℓ_with_grad, vi[spl]);
push!(results, t)
@info "Result" t

num_tildes = 10
@info "Running" num_tildes
Zygote.refresh()

expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);

f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);

ℓ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext());
ℓ_with_grad = LogDensityProblems.ADgradient(:Zygote, ℓ);
t = @elapsed LogDensityProblems.logdensity_and_gradient(ℓ_with_grad, vi[spl]);
push!(results, t)
@info "Result" t

num_tildes = 15
@info "Running" num_tildes
Zygote.refresh()

expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);

f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);

ℓ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext());
ℓ_with_grad = LogDensityProblems.ADgradient(:Zygote, ℓ);
t = @elapsed LogDensityProblems.logdensity_and_gradient(ℓ_with_grad, vi[spl]);
push!(results, t)
@info "Result" t

println(join(results, " | "))

if haskey(ENV, "OUTPUT_FILE")
    open(ENV["OUTPUT_FILE"], "a") do io
        write(io, "| ", join([VERSION; pkgversion(Turing); pkgversion(Zygote); results;], " | "), " |")
        write(io, "\n")
    end
end

Edits:

  • [2022-01-10 Mon 00:44] Added some preliminary results for Julia 1.6.5. Seems like nothing has changed.
  • [2022-01-10 Mon 12:13] Added preliminary result using the -O1 optimization flag. Insane difference.
  • [2022-01-10 Mon 13:23] Added more benchmarks with julia-1.6.3 -O1. It seems like something happened between [email protected] and [email protected] because the memory-usage on [email protected] (still running) is insane (currently using all of the 64G RAM available).
  • [2022-01-10 Mon 21:45] Added benchmarks for Zygote#master and Zygote#mcabbot:opt_level. Looks like mcabbot:opt_level does the trick.
  • [2022-11-11 Fri 13:06] Added benchmarks for Julia 1.8.2, Turing@0.21.13 and Zygote@0.6.49 + Zygote from https://github.com/FluxML/Zygote.jl/pull/1195

torfjelde avatar Jan 10 '22 00:01 torfjelde

Thanks @torfjelde

Is there any chance this stuff gets better on 1.7 with the latest Zygote? I know Turing doesn't support it yet, but does DynamicPPL let you test on 1.7?

jlperla avatar Jan 10 '22 01:01 jlperla

Is there any chance this stuff gets better on 1.7 with the latest Zygote? I know Turing doesn't support it yet, but does DynamicPPL let you test on 1.7?

I can, but it will take a bit of work (I'm currently using Turing functionality to compute the gradient).

Do you know when you started experiencing these issues btw?

Also, maybe some of the Zygote people have an idea of what's going on here, @mcabbott? TL;DR: the compilation time of Zygote.gradient blows up with the number of ~ statements in a Turing model.

torfjelde avatar Jan 10 '22 11:01 torfjelde

Possibly related: https://github.com/FluxML/Zygote.jl/issues/1119 and https://github.com/FluxML/Zygote.jl/issues/1126

EDIT: Seems like it. -O1 helps a lot.

torfjelde avatar Jan 10 '22 12:01 torfjelde

Wow, the -O1 really helps. See the results in the comment above.

EDIT: Even n=20 only results in ~2min of compilation.

torfjelde avatar Jan 10 '22 12:01 torfjelde

@torfjelde Am I reading that correctly that Julia 1.6.5 + Turing 0.19.3 + Zygote 0.6.33 brings it back to sanity?

jlperla avatar Jan 10 '22 17:01 jlperla

Am I reading that correctly that Julia 1.6.5 + Turing 0.19.3 + Zygote 0.6.33 brings it back to sanity?

No, specifically you need the -O1 optimization flag (the default is -O2). It seems as if Zygote + Julia 1.6 leads to some insane compilation times when the default optimization level is used.
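For reference, the optimization level is a per-process flag: it has to be passed when Julia starts and cannot be changed in a running session. A minimal sketch of how one might launch the benchmark script above with it (the script name is just a placeholder):

# Start Julia with a reduced optimization level (the default is -O2):
#     julia -O1 --project=. zygote_compile_bench.jl
# Inside the session, the active level can be checked with:
Base.JLOptions().opt_level  # 1 when started with -O1, 2 by default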

torfjelde avatar Jan 10 '22 20:01 torfjelde

See if https://github.com/FluxML/Zygote.jl/pull/1147 works as well as -O1 for this purpose. If you have any other benchmarks of runtime performance (i.e. @btime) it would also be interesting to see if those get worse.

mcabbott avatar Jan 10 '22 20:01 mcabbott

I can confirm that -O1 helps with our original problem by significantly decreasing the compilation time. Previously we had to wait around 30-40 min, and now it takes 2-3 Chopin preludes -- around 6 min -- to compile. I would say this is similar to what I had in Julia 1.5. Now the question is whether -O1 generates much less efficient code than -O3.

EDIT: preliminary experiments seem to generate similar computing times.

wupeifan avatar Jan 10 '22 21:01 wupeifan

See if FluxML/Zygote.jl#1147 works as well as -O1 for this purpose. If you have any other benchmarks of runtime performance (i.e. @btime) it would also be interesting to see if those get worse.

Will give it a go :+1:

Btw, not sure if this is more useful information, but when I try [email protected] even with -O1 the runtime blows up again (the nearest more recent version I tried which had reasonable compile time was 0.6.28).

torfjelde avatar Jan 10 '22 21:01 torfjelde

I've never heard of Chopin preludes being used as a unit of measurement, but people should do that more often :)

Now the question is whether -O1 generates much less efficient code than -O3.

Betting on no, as most performance-sensitive stuff is gated behind a rule. Maybe scalar- or control flow-heavy code, though the generated pullbacks for the latter are likely type unstable anyhow.
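To illustrate what "gated behind a rule" means here: the expensive numerical kernels usually have a hand-written ChainRules rrule (or Zygote @adjoint), so the only code Zygote has to generate, and the compiler has to optimize, is glue around those calls. A minimal sketch of such a rule; heavy_kernel is a made-up example, not an actual Turing/Zygote rule:

using ChainRulesCore, Zygote

# Hypothetical expensive kernel whose gradient is supplied by hand below, so
# Zygote never has to differentiate (or heavily optimize) its internals.
heavy_kernel(x) = sum(abs2, x)

function ChainRulesCore.rrule(::typeof(heavy_kernel), x)
    y = heavy_kernel(x)
    heavy_kernel_pullback(ȳ) = (NoTangent(), 2 .* x .* ȳ)
    return y, heavy_kernel_pullback
end

Zygote.gradient(heavy_kernel, randn(3))  # hits the rule, not Zygote's IR transform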

but when I try [email protected] even with -O1 the runtime blows up again

Fixed by https://github.com/FluxML/Zygote.jl/pull/909 perhaps? That was in 0.6.27.

ToucheSir avatar Jan 10 '22 21:01 ToucheSir

See if FluxML/Zygote.jl#1147 works as well as -O1 for this purpose. If you have any other benchmarks of runtime performance (i.e. @btime) it would also be interesting to see if those get worse.

Gave it a try; seems to do the trick! Benchmarks in table above.

Don't know what effect it has on performance though, but seems like it would be worth it.

torfjelde avatar Jan 10 '22 21:01 torfjelde

Fixed by FluxML/Zygote.jl#909 perhaps? That was in 0.6.27.

Ah, probably! I'll give it a go.

EDIT: Seems like indeed 0.6.27 improved things significantly :+1:

torfjelde avatar Jan 10 '22 21:01 torfjelde

Any progress on this issue by chance? Did a check with Julia 1.7 and the latest Turing, DynamicPPL, and Zygote and am still getting > 30 minute TTFG for my model with 20ish parameters. -O1 makes things more reasonable, as it did before.

Just want to make sure that everyone knows those two Zygote issues linked did not fix things.

jlperla avatar Mar 02 '22 00:03 jlperla

@ToucheSir @Keno Any progress on this?

ParadaCarleton avatar Jun 01 '22 15:06 ParadaCarleton

Unfortunately I have nothing concrete to report, but I have been looking into this over the past couple of months. Any help on grokking compilation latency + Zygote's internals would be much appreciated. I can't speak for Keno, but to my knowledge working on this is not on anyone else's plate.

ToucheSir avatar Jun 01 '22 15:06 ToucheSir

@torfjelde I'm not able to repro your latest timings on 1.8.2 locally with the following reduced MWE:

using Turing, Zygote
using Turing: LogDensityProblems
using SnoopCompileCore

# This helps a bit, ~4s
# @eval Turing.DynamicPPL begin
#   ChainRulesCore.@non_differentiable is_flagged(::VarInfo, ::VarName, ::String)
# end

Turing.setadbackend(:zygote);
adbackend = Turing.Core.ZygoteAD();
spl = DynamicPPL.SampleFromPrior();

num_tildes = 5
# num_tildes = 10

expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);

f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);

@info "starting eval"
ℓ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext());
ℓ_with_grad = LogDensityProblems.ADgradient(:Zygote, ℓ);
tinf = @snoopi_deep LogDensityProblems.logdensity_and_gradient(ℓ_with_grad, vi[spl]);
@info "done eval"

using SnoopCompile, ProfileView
@show tinf
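(ProfileView is loaded above so the inference profile can be visualized; a minimal sketch of that step using SnoopCompile's standard workflow on the tinf captured above -- the numbers below don't depend on it:)

fg = flamegraph(tinf)  # build a flame graph of where inference time went
ProfileView.view(fg)   # open an interactive view of the flame graph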

Results:

# v0.6.49
tinf = InferenceTimingNode: 38.138395/69.026222 on Core.Compiler.Timings.ROOT() with 388 direct children
# https://github.com/FluxML/Zygote.jl/pull/1195
tinf = InferenceTimingNode: 37.486037/65.298454 on Core.Compiler.Timings.ROOT() with 423 direct children

Versioninfo:

Julia Version 1.8.2
Commit 36034abf260 (2022-09-29 15:21 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 8 × Intel(R) Core(TM) i7-4790K CPU @ 4.00GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, haswell)
  Threads: 1 on 8 virtual cores

ToucheSir avatar Nov 11 '22 18:11 ToucheSir

I had a crack at figuring out why inference times also scale so badly. All outputs of that exploration may be found in this gist.

It turns out that we can capture a good chunk of the slowness and poor scaling with just pullback(model.f, ...). SnoopCompile didn't feel very helpful because it just reports high exclusive inference time for that function, but digging into the IR with SnoopCompile's Cthulhu integration did turn up something interesting.

To demonstrate, we should first examine @code_warntype output for the un-transformed function. DynamicPPL does generate a decent amount of code, but the compiler should be able to manage 600ish statements and 100ish slots[^1] without too much trouble.

Now let's see the output for the augmented primal function from pullback. Yes, you read that correctly. Zygote generates a function with over 4,300 statements and over 19,000 slots! I have no idea why both numbers are so high, but I suspect there is something up with the IRTools IR -> Julia IR translation that happens in Zygote[^2] (Edit: IRTools.slots! is the cause of the blow-up. See the output of each IRTools pass Zygote uses for the gory details). Given the sheer amount of code, I'm not surprised that LLVM times seem to be rather horrendous as well.

So what is to be done? The first thing that comes to mind is to figure out ~~why this is happening~~ how to fix the slot explosion on the IRTools side ~~, and how much Zygote's own passes might be responsible~~. That however would require someone quite familiar with IRTools internals. The more ambitious plan would be to ditch IRTools in Zygote completely and use CodeInfo/IRCode like other libraries (including TuringLang's own LibTask) have done. Having given this a try some months back, I think the biggest missing piece is someone who understands Zygote's pre-IRTools AD transform (or the implementing AD transforms on SSA IR in general) well enough to do the port. It's not clear how a lot of the logic in Zygote now would look if passes went from working on block args to Phi nodes, for example.
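For anyone who wants to poke at the transform without the full Turing setup, the IR Zygote builds can be printed for a small stand-in function. A rough sketch (toy is made up here; the statement/slot counts above were obtained on the real model evaluator via the Cthulhu integration, not from this):

using Zygote

# Stand-in for a model evaluator: a few scalar logpdf-like accumulations, one per `~`.
function toy(x)
    s = 0.0
    s += -x[1]^2 / 2
    s += -x[2]^2 / 2
    s += -x[3]^2 / 2
    return s
end

# Print the IRTools-level primal/adjoint IR Zygote generates for `toy`.
Zygote.@code_adjoint toy(randn(3))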

[^1]: I count slots because someone mentioned before that they can consume quite a bit of memory during compilation. So having too many of them in a function may be the culprit behind the memory blowup documented here and in https://github.com/TuringLang/Turing.jl/issues/1754#issuecomment-1008460319.

[^2]: Zygote used to operate on native Julia IR, but switched to IRTools a couple of years back. This means every function goes through a native IR (really CodeInfo) -> IRTools IR -> AD transform -> native IR pipeline.

ToucheSir avatar Nov 12 '22 05:11 ToucheSir

Diffractor will hopefully fix these issues, right?

yebai avatar Nov 12 '22 20:11 yebai

Last I asked there were no plans to support setfield! (on any type) or setindex! on Dicts/RefValues, both of which Turing seems to need. So I'm not terribly optimistic...

ToucheSir avatar Nov 12 '22 20:11 ToucheSir

Briefly going to comment on this to say--the solution to this issue is to use ReverseDiff or ForwardDiff.jl or (a few years down the line when it's mature) maybe some other autodiff solution like Enzyme.jl. Development on Zygote/IRTools and source-to-source AD in Julia (rather than LLVM) is effectively dead now.

ParadaCarleton avatar Oct 20 '23 06:10 ParadaCarleton

To be honest, ReverseDiff has other issues and ForwardDiff is not always an option. Actually, Zygote is developed much more actively (https://github.com/FluxML/Zygote.jl/commits/master) than ReverseDiff (https://github.com/JuliaDiff/ReverseDiff.jl/commits/master) or ForwardDiff (https://github.com/JuliaDiff/ForwardDiff.jl/commits/master). In ForwardDiff's case, there hasn't been any new release from master since the breaking change - which downstream packages did, IMO correctly, reject in a non-breaking release - was reapplied to the master branch; nobody wants to deal with (and possibly fix) the still-existing downstream issues, so nobody is willing to tag a new release from the ForwardDiff master branch.

devmotion avatar Oct 20 '23 07:10 devmotion