
[WIP] New libtask interface

Open: FredericWantiez opened this issue

Integrate refactor from https://github.com/TuringLang/Libtask.jl/pull/179

Two things worth noting:

  1. Dealing with the RNG will be the user's responsibility. Before:

using AdvancedPS, Distributions, Libtask, Random

mutable struct Model <: AdvancedPS.AbstractGenericModel
  mu::Float64
  sig::Float64

  Model() = new()
end


function (model::Model)(rng::Random.AbstractRNG)
  model.sig = rand(rng, Beta(1, 1))  # AdvancedPS took care of syncing these
  Libtask.produce(model.sig)

  model.mu = rand(rng, Normal())
  Libtask.produce(model.mu)
end

and now:

function (model::Model)()
  rng = Libtask.get_dynamic_scope() # We now need to query the RNG explicitly
  model.sig = rand(rng, Beta(1, 1))
  Libtask.produce(model.sig)

  rng = Libtask.get_dynamic_scope() # and do it every time we want to sample random values
  model.mu = rand(rng, Normal())
  Libtask.produce(model.mu)
end
  2. How do we keep track of model state between tasks? I'm fairly sure we don't want to look inside tapedtask.fargs: https://github.com/TuringLang/AdvancedPS.jl/blob/50d493c7482a12961415cdbeacd73cca83d8b554/ext/AdvancedPSLibtaskExt.jl#L89-L91
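A plain-Julia sketch of the symptom behind item 2, under the assumption (suggested by the MVP later in this thread, where the sampled values only show up in ctask.fargs) that the running task mutates its own copy of the model rather than the object the trace exposes. ToyModel, ToyTrace, run_step! and the deepcopy at construction are hypothetical stand-ins, not AdvancedPS or Libtask API.

```julia
using Random

mutable struct ToyModel
    x::Float64
    ToyModel() = new(NaN)   # NaN marks "not sampled yet"
end

# Hypothetical stand-in for a trace: `f` is the model object the trace
# exposes, `fargs` holds what the running task actually mutates.
struct ToyTrace
    f::ToyModel
    fargs::Tuple{ToyModel}
end

# Assumption: the task receives its own copy of the model when created.
ToyTrace(m::ToyModel) = ToyTrace(m, (deepcopy(m),))

# One "advance!" of the task: sample into the task's copy of the model.
function run_step!(t::ToyTrace, rng::AbstractRNG)
    t.fargs[1].x = rand(rng)
    return t
end

trace = run_step!(ToyTrace(ToyModel()), Xoshiro(1))
println(isnan(trace.f.x))        # true: the trace's model never saw the draw
println(isnan(trace.fargs[1].x)) # false: the value lives in the task's copy
```

Anything that copies only the exposed model (or only the task) therefore risks losing one half of the state, which is the question item 2 raises.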

@willtebbutt

FredericWantiez, Mar 23 '25 11:03

Thanks for having a look at this!

  1. Dealing with the RNG will be the user's responsibility.

Does this have any implications for integration with Turing.jl? That is, does not passing an RNG to the model cause any trouble downstream? (To be clear, I have no idea; I'm not suggesting that it does or doesn't in particular.)

  2. How do we keep track of model state between tasks? Pretty sure we don't want to look inside tapedtask.fargs

I agree about not wanting to dig into tapedtask.fargs. Could you elaborate a little on what is required here? My understanding was that task copying would handle this, i.e. when you copy a task, all references to the model get updated, so from the perspective of the code inside the task, things just continue as normal.

As with the first item, I'm not sure exactly what the requirements are here, so I may have misunderstood something basic about what you need to do.

willtebbutt, Mar 25 '25 07:03

  1. We can drop this one; it really only applies when AdvancedPS is used with Libtask outside of Turing. We will probably sunset that use case (or target people who presumably know enough about Libtask).

  2. I'm still not 100% sure about Turing, but we need something like this to manage the reference particle in the Particle Gibbs loop. Here's an MVP that should replicate a simple loop of the algorithm:

using AdvancedPS
using Libtask
using Random
using Distributions
using SSMProblems


mutable struct Model <: AdvancedPS.AbstractGenericModel
  x::Float64
  y::Float64

  Model() = new()
end


function (model::Model)()
  rng = Libtask.get_dynamic_scope()
  model.x = rand(rng, Beta(1,1))
  Libtask.produce(model.x)

  rng = Libtask.get_dynamic_scope()
  model.y = rand(rng, Normal(0, model.x))
  Libtask.produce(model.y)
end

rng = AdvancedPS.TracedRNG()
Random.seed!(rng, 10)
model = Model()

trace = AdvancedPS.Trace(model, rng)
# Sample `x`
AdvancedPS.advance!(trace)

trace2 = AdvancedPS.fork(trace)

key = AdvancedPS.state(trace.rng.rng)
seeds = AdvancedPS.split(key, 2)

Random.seed!(trace.rng, seeds[1])
Random.seed!(trace2.rng, seeds[2])

# Inherit `x` across independent particles
AdvancedPS.advance!(trace)
AdvancedPS.advance!(trace2)

println("Parent particle")
println(trace.model.f)
println("Child particle")
println(trace2.model.f)
println("Model with actual sampled values is in ctask.fargs")
println(trace2.model.ctask.fargs[1])

# Create reference particle
# Suppose we select the previous 'child' particle
ref = AdvancedPS.forkr(trace2)
println("Did we keep all the generated values ?")
println(ref.model.f) # If we just copy the tapedtask, we don't get the sampled values in the `Model`
# Note: this is only a problem when creating a reference trajectory;
# sampled values are properly captured during the execution of the task.
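The fork-then-reseed step of the MVP (fork the trace, split the parent's key into two seeds, reseed parent and child, advance both) can be mimicked with the Random stdlib alone. In this sketch Xoshiro stands in for AdvancedPS.TracedRNG, and drawing two UInt64 seeds from the parent RNG stands in for AdvancedPS.split; ToyParticle, advance! and fork here are hypothetical, not the AdvancedPS functions of the same name.

```julia
using Random

# Hypothetical particle: an RNG plus the values sampled so far.
struct ToyParticle
    rng::Xoshiro
    values::Vector{Float64}
end

advance!(p::ToyParticle) = (push!(p.values, rand(p.rng)); p)
fork(p::ToyParticle) = ToyParticle(copy(p.rng), copy(p.values))

particle = ToyParticle(Xoshiro(10), Float64[])
advance!(particle)                  # sample `x`
child = fork(particle)              # child inherits `x` and the RNG state

# Stand-in for AdvancedPS.split: derive two seeds from the parent's RNG.
seeds = rand(particle.rng, UInt64, 2)
Random.seed!(particle.rng, seeds[1])
Random.seed!(child.rng, seeds[2])

advance!(particle)                  # `y` is now sampled independently
advance!(child)
```

Both particles share the first value and diverge afterwards, which is the behaviour the MVP relies on before building the reference particle.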

FredericWantiez, Mar 25 '25 19:03

println(ref.model.f) # If we just copy the tapedtask, we don't get the sampled values in the Model. Note: this is only a problem when creating a reference trajectory.

@FredericWantiez can we store trace.rng inside TapedTask instead of trace? That way, copying a TapedTask would also copy trace.rng.

yebai, Mar 26 '25 12:03

@willtebbutt I think item 2 might also be a problem for Turing; see this part: https://github.com/TuringLang/Turing.jl/blob/afb5c44d6dc1736831f45620328c9d5681748111/src/mcmc/particle_mcmc.jl#L140-L142

FredericWantiez, Apr 01 '25 20:04

Two small issues I found while cleaning up the tests:

Libtask still returns a value after the last produce statement has been consumed:

using Libtask

function f()
  Libtask.produce(1)
  Libtask.produce(2)
end

t1 = TapedTask(nothing, f)
consume(t1)  # 1
consume(t1)  # 2
consume(t1)  # 2 (?)
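For comparison, Base's Channel gives the generator semantics this report seems to expect: the consumer receives exactly the produced values, the function's own return value is not delivered, and asking for more after the producer has finished signals completion instead of repeating the last value. Channel is only a stand-in here; this sketch says nothing about how TapedTask is implemented.

```julia
function produce_two(c::Channel{Int})
    put!(c, 1)
    put!(c, 2)
    return 99          # a return value is not delivered to the consumer
end

ch = Channel{Int}(produce_two)
first_val = take!(ch)   # 1
second_val = take!(ch)  # 2

# Once the producer returns, the channel is closed and take! throws.
finished = try
    take!(ch)
    false
catch e
    e isa InvalidStateException
end
```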

Libtask doesn't catch some of the produce statements (here, the ones reached through AdvancedPS.observe):

using AdvancedPS, Distributions, Libtask

mutable struct NormalModel <: AdvancedPS.AbstractGenericModel
  a::Float64
  b::Float64

  NormalModel() = new()
end

function (m::NormalModel)()
  # First latent variable.
  rng = Libtask.get_dynamic_scope()
  m.a = a = rand(rng, Normal(4, 5))

  # First observation.
  AdvancedPS.observe(Normal(a, 2), 3)

  # Second latent variable.
  rng = Libtask.get_dynamic_scope()
  m.b = b = rand(rng, Normal(a, 1))

  # Second observation.
  AdvancedPS.observe(Normal(b, 2), 1.5)
  return nothing
end

rng = AdvancedPS.TracedRNG()
t = TapedTask(rng, NormalModel())

consume(t) # some float
consume(t) # 0 (?)
consume(t) # 0 (?)

This works fine if I call Libtask.produce explicitly instead of observe.

EDIT: Changing observe to something like this seems to work:

function AdvancedPS.observe(dist::Distributions.Distribution, x)
    Libtask.produce(Distributions.loglikelihood(dist, x))
    return nothing
end
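The EDIT above turns observe itself into a produce point that yields the log-likelihood. A dependency-free sketch of that pattern, again with a Channel standing in for the task, a hypothetical observe! helper, and a hand-written Normal log-density (normal_logpdf, also hypothetical) in place of Distributions.loglikelihood:

```julia
using Random

# Log-density of Normal(μ, σ) at x, written out to avoid extra dependencies.
normal_logpdf(μ, σ, x) = -0.5 * ((x - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

# Each "observe" yields the observation's log-likelihood to the consumer.
observe!(c::Channel{Float64}, μ, σ, x) = put!(c, normal_logpdf(μ, σ, x))

function toy_normal_model(c::Channel{Float64}, rng::AbstractRNG)
    a = 4 + 5 * randn(rng)          # first latent:  a ~ Normal(4, 5)
    observe!(c, a, 2, 3.0)          # first observation
    b = a + randn(rng)              # second latent: b ~ Normal(a, 1)
    observe!(c, b, 2, 1.5)          # second observation
    return nothing
end

logliks = collect(Channel{Float64}(c -> toy_normal_model(c, Xoshiro(1))))
```

Every observation becomes exactly one value for the consumer, so a particle filter driving the task sees one weight update per observe.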

FredericWantiez, Apr 05 '25 09:04

If we store both rng and varinfo in the scoped variable, then the following suggestions will address (2):

  • store varinfo in the Trace struct, then change here to Libtask.set_dynamic_scope!(trace.model.ctask, (trace.rng, trace.varinfo))
  • change here and here to rng, varinfo = Libtask.get_dynamic_scope()
  • change here to transition = SMCTransition(model, particle.varinfo, weight)
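A minimal sketch of the idea behind these suggestions, assuming the task object itself owns the (rng, varinfo) pair so that copying the task copies both together. ToyTapedTask, get_globals, step! and the Vector{Float64} standing in for a VarInfo are all hypothetical; the calls named in the suggestions are Libtask.set_dynamic_scope! and Libtask.get_dynamic_scope.

```julia
using Random

# Hypothetical task-like object that owns its "dynamic scope".
mutable struct ToyTapedTask
    globals::Tuple{Xoshiro,Vector{Float64}}   # (rng, varinfo stand-in)
end

get_globals(t::ToyTapedTask) = t.globals

# Copying the task deep-copies its globals, so RNG state and record travel together.
Base.copy(t::ToyTapedTask) = ToyTapedTask(deepcopy(t.globals))

# Model step: query (rng, varinfo) from the task, sample, record.
function step!(t::ToyTapedTask)
    rng, vi = get_globals(t)
    push!(vi, rand(rng))
    return t
end

t1 = ToyTapedTask((Xoshiro(1), Float64[]))
step!(t1)
t2 = copy(t1)          # copies RNG state and recorded values
step!(t1)
step!(t2)
```

Because the copy carries the same RNG state, both tasks record identical values here, but in separate containers; with reseeding after the copy they would diverge, as in the MVP.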

yebai, Apr 08 '25 14:04

That should work. I have a branch against Turing that tries to do this, but it seems that one copy is not quite correct.

The other solution is to use one replay step before the transition, to repopulate the varinfo properly:

    new_particle = AdvancedPS.replay(particle)
    transition = SMCTransition(model, new_particle.model.f.varinfo, weight)
    state = SMCState(particles, 2, logevidence)
    return transition, state
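Replay works because rerunning a model from the same RNG state reproduces the same draws, so a fresh record can be repopulated after the fact. A plain-Julia sketch of that property; run_model! and the Vector{Float64} record are hypothetical stand-ins for the model and its VarInfo, and copy(::Xoshiro) stands in for snapshotting a TracedRNG.

```julia
using Random

# Hypothetical model: draws two values and records them.
function run_model!(record::Vector{Float64}, rng::AbstractRNG)
    a = randn(rng)
    push!(record, a)
    push!(record, a + randn(rng))
    return record
end

rng = Xoshiro(42)
snapshot = copy(rng)                  # RNG state before the run
original = run_model!(Float64[], rng)

# Later: replay from the snapshot into a fresh record.
replayed = run_model!(Float64[], copy(snapshot))
```

The trade-off versus storing the varinfo in the task is one extra model evaluation per transition.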

FredericWantiez, Apr 08 '25 19:04

@willtebbutt running models against this PR I see a large performance drop:

using Libtask
using AdvancedPS
using Distributions
using Random

mutable struct NormalModel <: AdvancedPS.AbstractGenericModel
    a::Float64
    b::Float64

    NormalModel() = new()
end

function (m::NormalModel)()
    # First latent variable.
    rng = Libtask.get_dynamic_scope()
    m.a = a = rand(rng, Normal(4, 5))

    # First observation.
    AdvancedPS.observe(Normal(a, 2), 3)

    # Second latent variable.
    rng = Libtask.get_dynamic_scope()
    m.b = b = rand(rng, Normal(a, 1))

    # Second observation.
    AdvancedPS.observe(Normal(b, 2), 1.5)
end

@time sample(NormalModel(), AdvancedPS.PG(10), 20; progress=false)

On master:

1.816623 seconds (5.92 M allocations: 311.647 MiB, 1.52% gc time, 96.09% compilation time)

On this PR:

72.085056 seconds (369.62 M allocations: 17.322 GiB, 2.83% gc time, 77.21% compilation time)

FredericWantiez, Apr 08 '25 19:04

Thanks for the data point. Essentially the final item on my todo list is sorting out various type inference issues in the current implementation. Once they're done, we should see substantially improved performance.

willtebbutt, Apr 09 '25 07:04

That should work. I have a branch against Turing that tries to do this, but it seems that one copy is not quite correct.

The varinfo variable is updated during inference. I think we have to carefully ensure the correct varinfo is stored in the scoped variable.

cc @mhauru @FredericWantiez

yebai, Apr 09 '25 10:04

@willtebbutt running models against this PR I see a large performance drop:

@FredericWantiez I'm finally looking at sorting out the performance of the Libtask updates. I'm struggling to reproduce the timings of your example with the current package versions, because it errors for me. My environment is:

(jl_4fXu3W) pkg> st
Status `/private/var/folders/z7/0fkyw8ms795b7znc_3vbvrsw0000gn/T/jl_4fXu3W/Project.toml`
  [576499cb] AdvancedPS v0.6.1
  [31c24e10] Distributions v0.25.118
  [6f1fad26] Libtask v0.8.8
  [9a3f8284] Random v1.11.0

I tried it on Julia LTS and 1.11.4.

In particular, I'm seeing the error:

ERROR: BoundsError: attempt to access 0-element Vector{Any} at index [1]
Stacktrace:
  [1] throw_boundserror(A::Vector{Any}, I::Tuple{Int64})
    @ Base ./essentials.jl:14
  [2] getindex
    @ ./essentials.jl:916 [inlined]
  [3] _infer(f::NormalModel, args_type::Tuple{DataType})
    @ Libtask ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:45
  [4] Libtask.TapedFunction{…}(f::NormalModel, args::AdvancedPS.TracedRNG{…}; cache::Bool, deepcopy_types::Type)
    @ Libtask ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:72
  [5] TapedFunction
    @ ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:62 [inlined]
  [6] _
    @ ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:80 [inlined]
  [7] TapedFunction
    @ ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:80 [inlined]
  [8] #TapedTask#15
    @ ~/.julia/packages/Libtask/bxGQF/src/tapedtask.jl:76 [inlined]
  [9] TapedTask
    @ ~/.julia/packages/Libtask/bxGQF/src/tapedtask.jl:70 [inlined]
 [10] LibtaskModel
    @ ~/.julia/packages/AdvancedPS/O1Ftx/ext/AdvancedPSLibtaskExt.jl:27 [inlined]
 [11] AdvancedPS.Trace(::NormalModel, ::AdvancedPS.TracedRNG{UInt64, 1, Random123.Philox2x{UInt64, 10}})
    @ AdvancedPSLibtaskExt ~/.julia/packages/AdvancedPS/O1Ftx/ext/AdvancedPSLibtaskExt.jl:49
 [12] (::AdvancedPSLibtaskExt.var"#2#3"{NormalModel, Nothing, Bool, Int64})(i::Int64)
    @ AdvancedPSLibtaskExt ~/.julia/packages/AdvancedPS/O1Ftx/ext/AdvancedPSLibtaskExt.jl:140
 [13] iterate
    @ ./generator.jl:48 [inlined]
 [14] _collect(c::UnitRange{…}, itr::Base.Generator{…}, ::Base.EltypeUnknown, isz::Base.HasShape{…})
    @ Base ./array.jl:811
 [15] collect_similar
    @ ./array.jl:720 [inlined]
 [16] map
    @ ./abstractarray.jl:3371 [inlined]
 [17] step(rng::TaskLocalRNG, model::NormalModel, sampler::AdvancedPS.PG{…}, state::Nothing; kwargs::@Kwargs{})
    @ AdvancedPSLibtaskExt ~/.julia/packages/AdvancedPS/O1Ftx/ext/AdvancedPSLibtaskExt.jl:134
 [18] macro expansion
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:0 [inlined]
 [19] macro expansion
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/logging.jl:16 [inlined]
 [20] mcmcsample(rng::TaskLocalRNG, model::NormalModel, sampler::AdvancedPS.PG{…}, N::Int64; progress::Bool, progressname::String, callback::Nothing, num_warmup::Int64, discard_initial::Int64, thinning::Int64, chain_type::Type, initial_state::Nothing, kwargs::@Kwargs{})
    @ AbstractMCMC ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:142
 [21] mcmcsample
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:107 [inlined]
 [22] #sample#20
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:59 [inlined]
 [23] sample
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:52 [inlined]
 [24] #sample#19
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:21 [inlined]
 [25] sample
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:18 [inlined]
 [26] macro expansion
    @ ./timing.jl:581 [inlined]
 [27] top-level scope
    @ ./REPL[10]:1
Some type information was truncated. Use `show(err)` to see complete types.

Any idea whether I'm doing something wrong?

willtebbutt, Apr 15 '25 09:04

Additionally, the latest version of the PR should address the various performance issues we previously had. There is one important change, though: you need to pass a type to Libtask.get_dynamic_scope, namely the type of the value it is going to return. We need this because there's no way to make the container typed (I assume that the previous implementation had a similar limitation). The docstring has been updated to reflect the changes.
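Why passing a type helps: reading from a field typed Any infers as Any, while asserting a caller-supplied type lets the compiler infer a concrete result for the rest of the model. A sketch with a hypothetical Globals container and getters (not Libtask's implementation), checked with Base.return_types:

```julia
# Hypothetical untyped container, as the globals slot has to be.
struct Globals
    val::Any
end

get_untyped(g::Globals) = g.val
# The caller states the expected type; the assertion narrows inference.
get_typed(g::Globals, ::Type{T}) where {T} = g.val::T

g = Globals(1.5)
untyped_ret = only(Base.return_types(get_untyped, (Globals,)))            # Any
typed_ret = only(Base.return_types(get_typed, (Globals, Type{Float64})))  # Float64
```

If the stored value does not match the requested type, the assertion throws a TypeError at runtime rather than silently falling back to dynamic dispatch.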

willtebbutt, Apr 15 '25 10:04

@willtebbutt if you're testing against the released versions of Libtask/AdvancedPS, you need to pass the RNG explicitly in the model definition, something like this:

function (model::Model)(rng::Random.AbstractRNG) # Add the RNG as argument
  model.sig = rand(rng, Beta(1, 1)) 
  Libtask.produce(model.sig)

  model.mu = rand(rng, Normal())
  Libtask.produce(model.mu)
end

FredericWantiez, Apr 15 '25 17:04

This now runs faster with AdvancedPS (dc5e594) and Libtask (8e7f784)

# run once to trigger compilation
julia> @time sample(NormalModel(), AdvancedPS.PG(10), 20; progress=false);
  2.986750 seconds (7.31 M allocations: 380.449 MiB, 0.88% gc time, 99.51% compilation time)

# second time runs faster
julia> @time sample(NormalModel(), AdvancedPS.PG(10), 20; progress=false);
  0.012714 seconds (32.85 k allocations: 18.581 MiB, 19.87% gc time)
Code:
(@temp) pkg> add AdvancedPS#fred/libtask-revamp
(@temp) pkg> add Libtask#wct/refactor

using Libtask
using AdvancedPS
using Distributions
using Random

mutable struct NormalModel <: AdvancedPS.AbstractGenericModel
    a::Float64
    b::Float64

    NormalModel() = new()
end

function (m::NormalModel)()
    # First latent variable.
    T = AdvancedPS.TracedRNG{UInt64, 1, AdvancedPS.Random123.Philox2x{UInt64, 10}}
    rng = Libtask.get_taped_globals(T)
    m.a = a = rand(rng, Normal(4, 5))

    # First observation.
    AdvancedPS.observe(Normal(a, 2), 3)

    # Second latent variable.
    rng = Libtask.get_taped_globals(T)
    m.b = b = rand(rng, Normal(a, 1))

    # Second observation.
    AdvancedPS.observe(Normal(b, 2), 1.5)
end

@time sample(NormalModel(), AdvancedPS.PG(10), 20; progress=false)

yebai, Apr 22 '25 17:04

New Libtask has now been released. CI failures reveal that a few more fixes are required from AdvancedPS.

yebai, May 09 '25 19:05

AdvancedPS.jl documentation for PR #114 is available at: https://TuringLang.github.io/AdvancedPS.jl/previews/PR114/

github-actions[bot], May 09 '25 19:05

Most of the test failures seemed to be caused by a renamed function (although the opaque error message "Unbound GlobalRef not allowed in value position" that they produced concerns me somewhat).

The remaining test failure is about addreference!/current_trace, which I'm a bit confused about. I thought we only wanted to store the RNG in the task-local storage/taped_globals. However, addreference! stores the whole trace object. In Turing.jl this is used to get the RNG from the trace (something we can replace with getting it from taped_globals), but also to get a VarInfo from the trace. See these lines in Turing.jl's mcmc/particle_mcmc.jl:

function trace_local_varinfo_maybe(varinfo)
    try
        trace = AdvancedPS.current_trace()
        return trace.model.f.varinfo
    catch e
        # NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`.
        if e == KeyError(:__trace) || current_task().storage isa Nothing
            return varinfo
        else
            rethrow(e)
        end
    end
end

and

function DynamicPPL.assume(
    rng,
    spl::Sampler{<:Union{PG,SMC}},
    dist::Distribution,
    vn::VarName,
    _vi::AbstractVarInfo,
)
    vi = trace_local_varinfo_maybe(_vi)
    [...]

Why do we need to do this? If we do need to do this, do we have to rework our use of taped_globals to store not only an RNG but also a VarInfo?
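The catch clause above checks for KeyError(:__trace), which suggests that current_trace() reads Julia's task-local storage. Assuming that, the old lookup-with-fallback pattern can be reproduced with Base alone; varinfo_maybe is a hypothetical name, and this mirrors the heuristic rather than Turing's actual code.

```julia
# Return the value stored under :__trace in the current task, or the fallback
# when running outside a trace (the key is missing).
function varinfo_maybe(fallback)
    try
        return task_local_storage(:__trace)
    catch e
        e == KeyError(:__trace) || rethrow()
        return fallback
    end
end

outside = varinfo_maybe(:passed_in)   # no :__trace key, so the fallback is used
inside = task_local_storage(:__trace, :from_trace) do
    varinfo_maybe(:passed_in)         # key is set for the duration of the block
end
```

Storing the needed values directly on the TapedTask removes the need for this exception-driven fallback.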

mhauru, Jun 02 '25 08:06

Briefly, addreference!/current_trace can be safely removed. They are replaced by set_taped_globals/get_taped_globals.

(A full explanation requires me explaining the history of Libtask...)

EDIT: (rng, trace.model.f.varinfo) can be saved to the TapedTask directly using set_taped_globals. This replaces the old design of storing them in task-local storage and keeping a reference to the task in each TapedTask (i.e. addreference!).

yebai, Jun 02 '25 08:06

I've just started pushing onto your PR @FredericWantiez, I hope you don't mind. I can make a separate PR if you prefer.

Tests should now pass, but please don't review or merge yet. I don't trust that I've done this right until I see it work with Turing.jl. I'll try to get that done locally now.

mhauru, Jun 06 '25 08:06

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Please upload report for BASE (main@1ad89ec).

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #114   +/-   ##
=======================================
  Coverage        ?   96.27%           
=======================================
  Files           ?        8           
  Lines           ?      429           
  Branches        ?        0           
=======================================
  Hits            ?      413           
  Misses          ?       16           
  Partials        ?        0           


codecov[bot], Jun 06 '25 08:06

Could I get some reviews on this please? I've now got some of Turing's ParticleGibbs tests passing locally using this, which gives me more confidence that it's largely correct (one bug was found in the process). Making a release of this would help run the full test suite of Turing.jl and see if anything else comes up.

mhauru, Jun 19 '25 14:06

I would request @FredericWantiez's review, but you are the PR owner so I can't.

mhauru, Jun 19 '25 14:06

Thanks @mhauru, I did a careful pass. It looks correct to me. Let's simplify more here so the code becomes less mysterious.

yebai, Jun 19 '25 15:06

Thanks @yebai. Sorry about the general messiness here; since my understanding of the code is poor, I only tried to make the minimal edits needed to get it to work.

mhauru, Jun 19 '25 15:06

Sorry about the general messiness here; since my understanding of the code is poor, I only tried to make the minimal edits needed to get it to work.

No worries, I am probably the one to blame for the messy code. My suggestions only encourage you to do more to simplify the code now that some heuristics have become unnecessary.

yebai, Jun 19 '25 15:06

I've removed the addreference! function by having the backreference added when a LibtaskTrace object is created. This makes it harder to forget to add it, and makes our interface simpler. The backreference itself still remains, because we use it to access the VarInfo of the trace. While it's true that we store the trace in the TapedGlobals.other field only to have access to trace.model.f.varinfo, we unfortunately can't simplify that by storing the VarInfo directly, because Turing isn't a dependency of AdvancedPS and thus we have no visibility into things like the structure of a DPPL.Model object.
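A sketch of installing the backreference inside the constructor so that it cannot be forgotten. ToyGlobalsTask and ToyLibtaskTrace are hypothetical stand-ins; in the PR, the backreference lives in the TapedGlobals.other field of the task belonging to a LibtaskTrace.

```julia
# Hypothetical task with an untyped slot for extra globals.
mutable struct ToyGlobalsTask
    other::Any
end

struct ToyLibtaskTrace
    model::Any
    task::ToyGlobalsTask
    function ToyLibtaskTrace(model)
        task = ToyGlobalsTask(nothing)
        trace = new(model, task)
        task.other = trace        # backreference added at construction time
        return trace
    end
end

tr = ToyLibtaskTrace(:my_model)
```

Code running inside the task can then reach tr.model (and, in Turing's case, the VarInfo hanging off it) without AdvancedPS needing to know that model's structure.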

mhauru, Jun 19 '25 16:06