Libtask.jl

Completely Refactor

Open willtebbutt opened this issue 8 months ago • 5 comments

Per #177 this is an attempt to significantly refactor the internals for robustness.

Todo:

  • [x] understand what the desired semantics are regarding copying, per the discussion in #178
  • [x] determine whether the task field of a TapedTask ought actually to be part of the public interface -- issue opened for discussion at https://github.com/TuringLang/AdvancedPS.jl/issues/113
  • [x] implement an opt-in mechanism for nested calls containing produce statements
  • [x] tidy up main body of code
  • [x] document how PhiNodes are handled more thoroughly (include a worked example)
  • [x] document the intent of the code + the high-level structure. In particular, re-work the README to explain what the public interface is, and what it does.
  • [x] get all CI passing
  • [x] check that we can make AdvancedPS and Turing work using this PR
  • [ ] kwarg support
  • [x] improve nesting mechanism to ensure that we always catch everything
  • [ ] optimise nested calls (currently throwing away all return type information to get prototype running)
  • [ ] benchmark to ensure no regressions
  • [ ] add BBCode section to Mooncake docs so that this implementation is intelligible to everyone
  • [ ] add more high-level documentation

Linked Issues:

  • https://github.com/TuringLang/Libtask.jl/issues/177
  • https://github.com/TuringLang/Libtask.jl/issues/178
  • https://github.com/TuringLang/AdvancedPS.jl/issues/113

Closes #165 Closes #167 Closes #171 Closes https://github.com/TuringLang/Libtask.jl/issues/176

willtebbutt avatar Feb 27 '25 19:02 willtebbutt

@yebai regarding the discussions in the meeting today, I wonder if the following might work as a canonical way to modify specific bits of the function in ways which are transparent to the user. We define a higher-order function

function replace_args!(f::F, t::TapedTask) where {F}
    f(t.args)
    return nothing
end

I want to do this via a specific function as I don't want the args field to be part of the public interface.

To e.g. copy a task and split its rng, you might do something like the following:

# Define your function. Anywhere you want the "correct" rng value, you query the `rng_ref[]`.
function f(rng_ref::Ref{<:AbstractRNG}, other_args...)
    produce(rand(rng_ref[]))
    produce(rand(rng_ref[]))
end

# Construct a TapedTask.
t = TapedTask(f, Ref(initial_rng), other_args...)

# Copy it.
t_copy = copy(t)

# We defined `f`, and we know for sure that the first argument contains the RNG that
# we want to split. We just modify this.
replace_args!(t_copy) do args
    rng_ref = args[1]
    rng_ref[] = split(rng_ref[])
    return nothing
end

We would have to be careful to document how to / how not to use this mechanism (in particular, to make it clear what gets updated and what doesn't), but I think it might be able to do what we need.
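
To make the "what gets updated and what doesn't" point concrete, here is a minimal, self-contained sketch of the pattern. It assumes a toy stand-in type, `ToyTask` (hypothetical, not Libtask's `TapedTask`), whose `copy` deep-copies the arguments; whether the real `copy` behaves this way is exactly the kind of thing the documentation would need to state.

```julia
# Toy stand-in for a task: NOT Libtask's TapedTask, just enough to show the pattern.
struct ToyTask{A<:Tuple}
    args::A
end

# Assumption: copying deep-copies the arguments, so Refs are not shared with the original.
Base.copy(t::ToyTask) = ToyTask(deepcopy(t.args))

# Same shape as the helper proposed above, applied to the toy type.
function replace_args!(f::F, t::ToyTask) where {F}
    f(t.args)
    return nothing
end

t = ToyTask((Ref(1), "other"))
t_copy = copy(t)

# Mutate the first argument of the copy in place via its Ref.
replace_args!(t_copy) do args
    seed_ref = args[1]
    seed_ref[] = seed_ref[] + 100
    return nothing
end

# Only the copy sees the change; the original Ref is untouched.
println((t.args[1][], t_copy.args[1][]))
```

Note that only values reachable through a mutable container (here a `Ref`) can be changed this way; the argument tuple itself is immutable, so an argument cannot be swapped for a new object.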

willtebbutt avatar Mar 03 '25 16:03 willtebbutt

function replace_args!(f::F, t::TapedTask) where {F}
   f(t.args)
   return nothing
end

I confirm this would work.

# Define your function. Anywhere you want the "correct" rng value, you query the `rng_ref[]`.
function f(rng_ref::Ref{<:AbstractRNG}, other_args...)
   produce(rand(rng_ref[]))
   produce(rand(rng_ref[]))
end

IIUC, such an approach requires replacing all usage of VarInfo and RNG with a reference. This would require significant modification in DynamicPPL. Or are you thinking about automatically implementing an extra transform in Libtask to replace rng with rng_ref[]?

EDIT: AdvancedPS and Turing could retrieve particle-specific varinfo and rng from TapedTask. This approach is akin to trace_local_[varinfo|rng]_maybe, but moving task local storage into the TapedTask struct by replacing AdvancedPS.current_trace() and AdvancedPS.addreference!.

yebai avatar Mar 03 '25 17:03 yebai

Cc @mhauru can provide details on what a model function looks like.

yebai avatar Mar 03 '25 17:03 yebai

I've been thinking a bit further about this, and I wonder whether we ought just to use the new ScopedValues feature in Base? We could use this, and then just move all trace-specific items from task-local storage to trace-local storage.

e.g. if we add a variable to Libtask

using Base.ScopedValues # in Base from Julia 1.11

const task_cache = ScopedValue{Any}(nothing) # maybe change types for stability

then the implementation of consume can be something like

function consume(t::TapedTask)
    with(task_cache => t.cache) do
        return t.mc(t.args...)
    end
end

Then if you want to use something stored in the cache field of a TapedTask, you just refer to Libtask.task_cache in the function you've got produce statements in, rather than current_task().storage.

edit: this feature only officially landed in Julia 1.11, but it is based on an implementation which supports 1.8+, so we should be able to use one or the other selectively.

edit2: alternatively, we only apply the upgrade in this PR to 1.11 and up.
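
As a rough, self-contained illustration of the scoping behaviour (Julia 1.11+ only): `toy_cache`, `current_cache` and `run_with_cache` below are hypothetical names standing in for `Libtask.task_cache`, the user's function, and `consume` respectively. Nothing here is Libtask API.

```julia
using Base.ScopedValues  # in Base from Julia 1.11

# Hypothetical stand-in for the proposed Libtask.task_cache.
const toy_cache = ScopedValue{Any}(nothing)

# Code running inside the scope reads the current value with `[]`.
current_cache() = toy_cache[]

# Hypothetical stand-in for consume: run `f` with `cache` bound for its dynamic extent.
run_with_cache(f, cache) = with(f, toy_cache => cache)

inside = run_with_cache(current_cache, Dict(:particle => 1))
outside = current_cache()

# Inside the `with` call the bound value is visible; outside, the default is restored.
println((inside, outside))
```

The value is read with `toy_cache[]`, so code containing produce statements would presumably refer to `Libtask.task_cache[]` rather than the bare binding.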

willtebbutt avatar Mar 04 '25 15:03 willtebbutt

Status Update: nested calls now basically work; I just need to tidy up the implementation a bit. I've also tidied up the whole implementation and documented it much more thoroughly. There's more that needs doing, but it's almost at the point where we can try to get it to work with AdvancedPS + Turing, to make sure that we do indeed have all of the features that we think we need.

willtebbutt avatar Mar 11 '25 19:03 willtebbutt

@FredericWantiez I think your problems should now be fixed. Could you let me know if it's working okay for you on the AdvancedPS side of things?

willtebbutt avatar Apr 08 '25 14:04 willtebbutt

Yes, that fixed the issue with the return value! Thanks.

FredericWantiez avatar Apr 08 '25 19:04 FredericWantiez

@mhauru @sunxd3 apologies in advance for the rather large PR.

When reviewing, I suggest that you ignore everything in src/bbcode.jl and src/utils.jl -- they are copies of utility functionality from Mooncake. Additionally, if anything is unclear or feels under-documented, please let me know rather than spending too long trying to understand it -- I'm keen to add additional docs etc., so it's really rather helpful to know where things aren't straightforward to understand.

willtebbutt avatar Apr 15 '25 16:04 willtebbutt

@sunxd3 thanks for the feedback, and apologies for taking a while to address it.

I think get_taped_globals using task_local_storage is a great idea, but maybe explain a bit how and why to use task_local_storage? (I recall this is to support some necessary state for SMC like RNG?)

I think I discuss this in the docstring for TapedTask -- do you think it would be helpful if I were to embellish the explanation a bit?

It took me a bit of effort to understand how refs are used to enable isolation between executions of copied TapedTasks; some high-level description would help a lot (maybe this already exists and I just missed it)

This is a good point. I don't think I discuss this properly anywhere -- I agree that it needs a good high-level explanation. Will add and ask for further review.
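
In the meantime, the general idea can be sketched with a toy that is much simpler than the real thing. `ToyResumable` and `step!` are hypothetical names and this is not the code Libtask generates: it only assumes that all state which must survive between resumptions is held in Refs, and that copying duplicates those Refs.

```julia
# A toy resumable computation: NOT Libtask's generated code.
# All state that must survive between resumptions lives in Refs.
struct ToyResumable
    position::Base.RefValue{Int}  # which step to run next
    acc::Base.RefValue{Int}       # a local variable carried across steps
end

ToyResumable() = ToyResumable(Ref(1), Ref(0))

# Copying duplicates the Refs, so the copy resumes from the same point
# but no longer shares any state with the original.
Base.copy(r::ToyResumable) = ToyResumable(Ref(r.position[]), Ref(r.acc[]))

# Run one step and return the "produced" value.
function step!(r::ToyResumable)
    r.acc[] += 10 * r.position[]
    r.position[] += 1
    return r.acc[]
end

r = ToyResumable()
step!(r)        # advance the original by one step
r2 = copy(r)    # fork after that step
a = step!(r)    # original and copy continue from the same state
b = step!(r2)
c = step!(r)    # advancing the original further does not affect the copy
println((a, b, c, r2.acc[]))
```

Because each copy owns its own Refs, advancing one never changes what the other will produce next.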

willtebbutt avatar May 05 '25 13:05 willtebbutt

~~Actually, @sunxd3, where do you think might be a good location for the docs for the refs? I'm not 100% sure where to put them, and it might be easier for someone whose head is less in the code.~~

willtebbutt avatar May 05 '25 14:05 willtebbutt

@sunxd3 @mhauru this is ready for another pass when you get some time. Apologies for taking a while to address your comments.

willtebbutt avatar May 05 '25 14:05 willtebbutt

To add to Xianda's and Markus's comments, I suggest adding an implementation note to explain the key ideas, specifically how phi-nodes are handled and the copious use of Refs and opaque closures. This would significantly aid in fixing compatibility issues with future Julia releases.

yebai avatar May 06 '25 16:05 yebai

@yebai I've linked to the current phi node explainer above -- if you think it needs a more thorough explanation please let me know.

I take your point regarding the Refs and OpaqueClosures though. I'll add some extra docs there.

willtebbutt avatar May 09 '25 11:05 willtebbutt

Okay. All comments are now addressed, and I believe that this is good to go. @yebai happy to merge if you are.

willtebbutt avatar May 09 '25 12:05 willtebbutt

Libtask.jl documentation for PR #179 is available at: https://TuringLang.github.io/Libtask.jl/previews/PR179/

github-actions[bot] avatar May 09 '25 12:05 github-actions[bot]

Thanks @mhauru

willtebbutt avatar May 09 '25 15:05 willtebbutt

Hurray! 🎉 Thank you so much @willtebbutt, fantastic work!

mhauru avatar May 12 '25 07:05 mhauru