AdvancedPS.jl
[WIP] New libtask interface
Integrate refactor from https://github.com/TuringLang/Libtask.jl/pull/179
Two things worth noting:
- Dealing with the RNG will be the user's responsibility. Before:
mutable struct Model <: AdvancedPS.AbstractGenericModel
    mu::Float64
    sig::Float64
    Model() = new()
end

function (model::Model)(rng::Random.AbstractRNG)
    model.sig = rand(rng, Beta(1, 1)) # AdvancedPS took care of syncing these
    Libtask.produce(model.sig)
    model.mu = rand(rng, Normal())
    Libtask.produce(model.mu)
end
and now:
function (model::Model)()
    rng = Libtask.get_dynamic_scope() # We now need to query the RNG explicitly
    model.sig = rand(rng, Beta(1, 1))
    Libtask.produce(model.sig)
    rng = Libtask.get_dynamic_scope() # and do it every time we want to sample random values
    model.mu = rand(rng, Normal())
    Libtask.produce(model.mu)
end
- How do we keep track of model state between tasks? Pretty sure we don't want to look inside tapedtask.fargs: https://github.com/TuringLang/AdvancedPS.jl/blob/50d493c7482a12961415cdbeacd73cca83d8b554/ext/AdvancedPSLibtaskExt.jl#L89-L91
@willtebbutt
Thanks for having a look at this!
- Dealing with the RNG will be the user's responsibility. Before
Does this have any implications for integration with Turing.jl? i.e. does not passing an RNG to the model cause any trouble downstream? (To be clear, I have no idea -- I'm not suggesting that it does or doesn't.)
- How do we keep track of model state between tasks? Pretty sure we don't want to look inside tapedtask.fargs
I agree re not wanting to dig into tapedtask.fargs. Could you elaborate a little on what is required here? My understanding was that task copying would handle this -- i.e. when you copy a task, all references to the model get updated, so from the perspective of the code inside the task, things just continue as normal.
As with the first item, I'm not sure exactly what the requirements are here, so I may have misunderstood something basic about what you need to do.
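For what it's worth, here is the copy behaviour I have in mind (a sketch of my understanding of the intended semantics, not verified here):

    using Libtask

    function f()
        Libtask.produce(1)
        Libtask.produce(2)
    end

    t1 = TapedTask(nothing, f)
    consume(t1)    # 1
    t2 = copy(t1)  # snapshot the task; references inside it are rebound to the copy
    consume(t2)    # 2 -- the copy resumes from where t1 left off
    consume(t1)    # 2 -- the original continues independently of the copy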
- We can drop this one; it really only applies when AdvancedPS is used with Libtask outside of Turing. We will probably sunset that (or target people who supposedly know enough about Libtask).
- Still not 100% sure about Turing, but we need something like this to manage the reference particle in the Particle Gibbs loop. Here's an MVP that should replicate a simple loop of the algorithm:
using AdvancedPS
using Libtask
using Random
using Distributions
using SSMProblems

mutable struct Model <: AdvancedPS.AbstractGenericModel
    x::Float64
    y::Float64
    Model() = new()
end

function (model::Model)()
    rng = Libtask.get_dynamic_scope()
    model.x = rand(rng, Beta(1, 1))
    Libtask.produce(model.x)
    rng = Libtask.get_dynamic_scope()
    model.y = rand(rng, Normal(0, model.x))
    Libtask.produce(model.y)
end
rng = AdvancedPS.TracedRNG()
Random.seed!(rng, 10)
model = Model()
trace = AdvancedPS.Trace(model, rng)
# Sample `x`
AdvancedPS.advance!(trace)
trace2 = AdvancedPS.fork(trace)
key = AdvancedPS.state(trace.rng.rng)
seeds = AdvancedPS.split(key, 2)
Random.seed!(trace.rng, seeds[1])
Random.seed!(trace2.rng, seeds[2])
# Inherit `x` across independent particles
AdvancedPS.advance!(trace)
AdvancedPS.advance!(trace2)
println("Parent particle")
println(trace.model.f)
println("Child particle")
println(trace2.model.f)
println("Model with actual sampled values is in ctask.fargs")
println(trace2.model.ctask.fargs[1])
# Create reference particle
# Suppose we select the previous 'child' particle
ref = AdvancedPS.forkr(trace2)
println("Did we keep all the generated values ?")
println(ref.model.f) # If we just copy the tapedtask, we don't get the sampled values in the `Model`
# Note, this is only a problem when creating a reference trajectory,
# sampled values are properly captured during the execution of the task
@FredericWantiez can we store trace.rng inside the TapedTask instead of the trace? That way, when copying a TapedTask, we would also copy trace.rng.
@willtebbutt I think 2) might also be a problem for Turing, when looking at this part: https://github.com/TuringLang/Turing.jl/blob/afb5c44d6dc1736831f45620328c9d5681748111/src/mcmc/particle_mcmc.jl#L140-L142
Two small issues I found cleaning up the tests.
Libtask returns a value after the last produce statement:
function f()
    Libtask.produce(1)
    Libtask.produce(2)
end

t1 = TapedTask(nothing, f)
consume(t1) # 1
consume(t1) # 2
consume(t1) # 2 (?)
Libtask doesn't catch some of the produce statements:
mutable struct NormalModel <: AdvancedPS.AbstractGenericModel
    a::Float64
    b::Float64
    NormalModel() = new()
end

function (m::NormalModel)()
    # First latent variable.
    rng = Libtask.get_dynamic_scope()
    m.a = a = rand(rng, Normal(4, 5))
    # First observation.
    AdvancedPS.observe(Normal(a, 2), 3)
    # Second latent variable.
    rng = Libtask.get_dynamic_scope()
    m.b = b = rand(rng, Normal(a, 1))
    # Second observation.
    AdvancedPS.observe(Normal(b, 2), 1.5)
    return nothing
end
rng = AdvancedPS.TracedRNG()
t = TapedTask(rng, NormalModel())
consume(t) # some float
consume(t) # 0 (?)
consume(t) # 0 (?)
This works fine if I call Libtask.produce explicitly instead of observe.
EDIT: Changing observe to something like this seems to work:
function AdvancedPS.observe(dist::Distributions.Distribution, x)
    Libtask.produce(Distributions.loglikelihood(dist, x))
    return nothing
end
If we store both rng and varinfo in the scoped variable, then the following suggestions will address (2); a combined sketch follows the list:
- store `varinfo` in the `Trace` struct, then change here to `Libtask.set_dynamic_scope!(trace.model.ctask, (trace.rng, trace.varinfo))`
- change here and here to `rng, varinfo = Libtask.get_dynamic_scope()`
- change here to `transition = SMCTransition(model, particle.varinfo, weight)`
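Roughly, the three edits combine into (a sketch; the exact call sites are the ones referenced above, and `varinfo` as a `Trace` field is the new assumption):

    # Sampler side: push both the RNG and the varinfo into the task's scope.
    Libtask.set_dynamic_scope!(trace.model.ctask, (trace.rng, trace.varinfo))

    # Model side: unpack both wherever the scope is queried.
    rng, varinfo = Libtask.get_dynamic_scope()

    # Transition construction: read the varinfo back off the particle.
    transition = SMCTransition(model, particle.varinfo, weight)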
That should work. I have a branch against Turing that tries to do this, but it seems like one copy is not quite correct.
The other solution is to use one replay step before the transition, to repopulate the varinfo properly:
new_particle = AdvancedPS.replay(particle)
transition = SMCTransition(model, new_particle.model.f.varinfo, weight)
state = SMCState(particles, 2, logevidence)
return transition, state
@willtebbutt running models against this PR I see a large performance drop:
using Libtask
using AdvancedPS
using Distributions
using Random
mutable struct NormalModel <: AdvancedPS.AbstractGenericModel
    a::Float64
    b::Float64
    NormalModel() = new()
end

function (m::NormalModel)()
    # First latent variable.
    rng = Libtask.get_dynamic_scope()
    m.a = a = rand(rng, Normal(4, 5))
    # First observation.
    AdvancedPS.observe(Normal(a, 2), 3)
    # Second latent variable.
    rng = Libtask.get_dynamic_scope()
    m.b = b = rand(rng, Normal(a, 1))
    # Second observation.
    AdvancedPS.observe(Normal(b, 2), 1.5)
end
@time sample(NormalModel(), AdvancedPS.PG(10), 20; progress=false)
On master:
1.816623 seconds (5.92 M allocations: 311.647 MiB, 1.52% gc time, 96.09% compilation time)
On this PR:
72.085056 seconds (369.62 M allocations: 17.322 GiB, 2.83% gc time, 77.21% compilation time)
Thanks for the data point. Essentially the final item on my todo list is sorting out various type inference issues in the current implementation. Once they're done, we should see substantially improved performance.
That should work. I have a branch against Turing that tries to do this, but it seems like one copy is not quite correct.
The varinfo variable is updated during inference. I think we have to carefully ensure the correct varinfo is stored in the scoped variable.
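i.e. presumably something like re-installing the scoped values after every fork, so that the copy sees its own rng/varinfo rather than the parent's (a sketch, assuming the tuple layout suggested above):

    # After copying a particle, refresh the scoped values on the new task.
    new_trace = AdvancedPS.fork(trace)
    Libtask.set_dynamic_scope!(new_trace.model.ctask, (new_trace.rng, new_trace.varinfo))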
cc @mhauru @FredericWantiez
@willtebbutt running models against this PR I see a large performance drop:
@FredericWantiez I'm finally looking at sorting out the performance of the Libtask updates. I'm struggling to replicate the performance of your example on the current versions of packages, because I find that it errors. My environment is
(jl_4fXu3W) pkg> st
Status `/private/var/folders/z7/0fkyw8ms795b7znc_3vbvrsw0000gn/T/jl_4fXu3W/Project.toml`
[576499cb] AdvancedPS v0.6.1
[31c24e10] Distributions v0.25.118
[6f1fad26] Libtask v0.8.8
[9a3f8284] Random v1.11.0
I tried it on LTS and 1.11.4.
In particular, I'm seeing the error:
ERROR: BoundsError: attempt to access 0-element Vector{Any} at index [1]
Stacktrace:
[1] throw_boundserror(A::Vector{Any}, I::Tuple{Int64})
@ Base ./essentials.jl:14
[2] getindex
@ ./essentials.jl:916 [inlined]
[3] _infer(f::NormalModel, args_type::Tuple{DataType})
@ Libtask ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:45
[4] Libtask.TapedFunction{…}(f::NormalModel, args::AdvancedPS.TracedRNG{…}; cache::Bool, deepcopy_types::Type)
@ Libtask ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:72
[5] TapedFunction
@ ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:62 [inlined]
[6] _
@ ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:80 [inlined]
[7] TapedFunction
@ ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:80 [inlined]
[8] #TapedTask#15
@ ~/.julia/packages/Libtask/bxGQF/src/tapedtask.jl:76 [inlined]
[9] TapedTask
@ ~/.julia/packages/Libtask/bxGQF/src/tapedtask.jl:70 [inlined]
[10] LibtaskModel
@ ~/.julia/packages/AdvancedPS/O1Ftx/ext/AdvancedPSLibtaskExt.jl:27 [inlined]
[11] AdvancedPS.Trace(::NormalModel, ::AdvancedPS.TracedRNG{UInt64, 1, Random123.Philox2x{UInt64, 10}})
@ AdvancedPSLibtaskExt ~/.julia/packages/AdvancedPS/O1Ftx/ext/AdvancedPSLibtaskExt.jl:49
[12] (::AdvancedPSLibtaskExt.var"#2#3"{NormalModel, Nothing, Bool, Int64})(i::Int64)
@ AdvancedPSLibtaskExt ~/.julia/packages/AdvancedPS/O1Ftx/ext/AdvancedPSLibtaskExt.jl:140
[13] iterate
@ ./generator.jl:48 [inlined]
[14] _collect(c::UnitRange{…}, itr::Base.Generator{…}, ::Base.EltypeUnknown, isz::Base.HasShape{…})
@ Base ./array.jl:811
[15] collect_similar
@ ./array.jl:720 [inlined]
[16] map
@ ./abstractarray.jl:3371 [inlined]
[17] step(rng::TaskLocalRNG, model::NormalModel, sampler::AdvancedPS.PG{…}, state::Nothing; kwargs::@Kwargs{})
@ AdvancedPSLibtaskExt ~/.julia/packages/AdvancedPS/O1Ftx/ext/AdvancedPSLibtaskExt.jl:134
[18] macro expansion
@ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:0 [inlined]
[19] macro expansion
@ ~/.julia/packages/AbstractMCMC/FSyVk/src/logging.jl:16 [inlined]
[20] mcmcsample(rng::TaskLocalRNG, model::NormalModel, sampler::AdvancedPS.PG{…}, N::Int64; progress::Bool, progressname::String, callback::Nothing, num_warmup::Int64, discard_initial::Int64, thinning::Int64, chain_type::Type, initial_state::Nothing, kwargs::@Kwargs{})
@ AbstractMCMC ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:142
[21] mcmcsample
@ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:107 [inlined]
[22] #sample#20
@ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:59 [inlined]
[23] sample
@ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:52 [inlined]
[24] #sample#19
@ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:21 [inlined]
[25] sample
@ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:18 [inlined]
[26] macro expansion
@ ./timing.jl:581 [inlined]
[27] top-level scope
@ ./REPL[10]:1
Some type information was truncated. Use `show(err)` to see complete types.
Any idea whether I'm doing something wrong?
But, additionally, the latest version of the PR should address the various performance issues we previously had. There is one important change though: you need to pass a type to Libtask.get_dynamic_scope, which should be the type of the thing that it's going to return. We need this because there's no way to make the container typed (I assume that the previous implementation had a similar limitation). The docstring has been updated to reflect the changes.
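Concretely, where a model previously called Libtask.get_dynamic_scope(), it now needs the concrete type of the scoped value (the TracedRNG type parameters below are the ones used later in this thread):

    # The concrete type must be supplied so the query is type-stable.
    T = AdvancedPS.TracedRNG{UInt64, 1, AdvancedPS.Random123.Philox2x{UInt64, 10}}
    rng = Libtask.get_dynamic_scope(T)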
@willtebbutt if you're testing against the released version of Libtask/AdvancedPS you need to explicitly pass the RNG in the model definition, something like this:
function (model::Model)(rng::Random.AbstractRNG) # Add the RNG as an argument
    model.sig = rand(rng, Beta(1, 1))
    Libtask.produce(model.sig)
    model.mu = rand(rng, Normal())
    Libtask.produce(model.mu)
end
This now runs faster with AdvancedPS (dc5e594) and Libtask (8e7f784):
# run once to trigger compilation
julia> @time sample(NormalModel(), AdvancedPS.PG(10), 20; progress=false);
2.986750 seconds (7.31 M allocations: 380.449 MiB, 0.88% gc time, 99.51% compilation time)
# second time runs faster
julia> @time sample(NormalModel(), AdvancedPS.PG(10), 20; progress=false);
0.012714 seconds (32.85 k allocations: 18.581 MiB, 19.87% gc time)
Code
(@temp) pkg> add AdvancedPS#fred/libtask-revamp
(@temp) pkg> add Libtask#wct/refactor
using Libtask
using AdvancedPS
using Distributions
using Random
mutable struct NormalModel <: AdvancedPS.AbstractGenericModel
    a::Float64
    b::Float64
    NormalModel() = new()
end

function (m::NormalModel)()
    T = AdvancedPS.TracedRNG{UInt64, 1, AdvancedPS.Random123.Philox2x{UInt64, 10}}
    # First latent variable.
    rng = Libtask.get_taped_globals(T)
    m.a = a = rand(rng, Normal(4, 5))
    # First observation.
    AdvancedPS.observe(Normal(a, 2), 3)
    # Second latent variable.
    rng = Libtask.get_taped_globals(T)
    m.b = b = rand(rng, Normal(a, 1))
    # Second observation.
    AdvancedPS.observe(Normal(b, 2), 1.5)
end
@time sample(NormalModel(), AdvancedPS.PG(10), 20; progress=false)
The new Libtask has now been released. CI failures reveal that a few more fixes are required on the AdvancedPS side.
Most of the test failures seemed to come down to a renamed function (although the opaque error message "Unbound GlobalRef not allowed in value position" that they were yielding concerns me somewhat).
The remaining test failure is about addreference!/current_trace, which I'm a bit confused about. I thought we only wanted to store the RNG in the task-local storage/taped_globals. However, addreference! stores the whole trace object. In Turing.jl this is used to get the RNG from the trace (something we can replace with getting it from taped_globals), but also to get a VarInfo from the trace. See these lines in Turing.jl's mcmc/particle_mcmc.jl:
function trace_local_varinfo_maybe(varinfo)
    try
        trace = AdvancedPS.current_trace()
        return trace.model.f.varinfo
    catch e
        # NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`.
        if e == KeyError(:__trace) || current_task().storage isa Nothing
            return varinfo
        else
            rethrow(e)
        end
    end
end
and
function DynamicPPL.assume(
    rng,
    spl::Sampler{<:Union{PG,SMC}},
    dist::Distribution,
    vn::VarName,
    _vi::AbstractVarInfo,
)
    vi = trace_local_varinfo_maybe(_vi)
    [...]
Why do we need to do this? If we do need to do this, do we have to rework our use of taped_globals to store not only an RNG but also a VarInfo?
Briefly, addreference!/current_trace can be safely removed. They are replaced by set_taped_globals/get_taped_globals.
(A full explanation requires me explaining the history of Libtask...)
EDIT: (rng, trace.model.f.varinfo) can be saved to the TapedTask directly using set_taped_globals. This replaces the old design of storing them in task-local storage and keeping a reference to the task in each TapedTask (i.e. addreference!).
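On the Turing side that would look something like the following sketch (TRNG and TVarInfo stand in for the concrete types; the set_taped_globals!/get_taped_globals calls are as described above):

    # When (re)starting a particle: store both the RNG and the varinfo on the task.
    Libtask.set_taped_globals!(trace.model.ctask, (trace.rng, trace.model.f.varinfo))

    # Inside assume/observe: query them back, supplying the stored tuple's type.
    rng, varinfo = Libtask.get_taped_globals(Tuple{TRNG, TVarInfo})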
I've just started pushing onto your PR @FredericWantiez, I hope you don't mind. I can make a separate PR if you prefer.
Tests should now pass, but please don't review or merge yet. I don't trust that I've done this right until I see it work with Turing.jl. I'll try to get that done locally now.
Could I get some reviews on this please? I've now got some of Turing's ParticleGibbs tests passing locally using this, which gives me more confidence that it's largely correct (one bug was found in the process). Making a release of this would help run the full test suite of Turing.jl and see if anything else comes up.
I would request @FredericWantiez's review, but you are the PR owner so I can't.
Thanks @mhauru, I did a careful pass. It looks correct to me. Let's simplify more here so the code becomes less mysterious.
Thanks @yebai. Sorry about the general messiness here, since my understanding of the code is poor I only tried to do the minimal edits needed to get it to work.
Sorry about the general messiness here, since my understanding of the code is poor I only tried to do the minimal edits needed to get it to work.
No worries, I am probably the one to blame for the messy code. My suggestions merely encourage further simplification, now that some of the old heuristics have become unnecessary.
I've removed the addreference! function by having the backreference added when a LibtaskTrace object is created. This makes it harder to forget to add it, and makes our interface simpler. The backreference itself still remains, because we use it to access the VarInfo of the trace. While it's true that we store the trace in the TapedGlobals.other field only to have access to trace.model.f.varinfo, we unfortunately can't simplify that by storing the VarInfo directly, because Turing isn't a dependency of AdvancedPS and thus we have no visibility into things like the structure of a DPPL.Model object.
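A rough sketch of what that change amounts to (make_trace is a hypothetical helper to illustrate the shape, not the actual constructor in the diff):

    # Attach the backreference when the trace is built, so callers can no
    # longer forget to add it themselves.
    function make_trace(model, rng) # hypothetical helper
        trace = AdvancedPS.Trace(model, rng)
        # Store the trace itself in the taped globals so the running model
        # can later reach trace.model.f.varinfo.
        Libtask.set_taped_globals!(trace.model.ctask, trace)
        return trace
    end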