DynamicPPL.jl

Do we need to store linked variables in VarInfo?

Status: Open. mhauru opened this issue 9 months ago (10 comments).

As far as I can see, the only reason to do linking is that samplers need it. Moreover, samplers interface with a vector of floats, not a VarInfo. So really the only place where we need linked variables is when we call getindex_internal or something like vi[:]. If I'm mistaken about this and there's a different need, all of the below may be moot, so do say so.

This led me to wonder why we bother storing linked variables at all, rather than converting at getindex/setindex time. Currently our philosophy seems to be that samplers give and take vectors of linked floats. When interfacing with a sampler, we typically call vi[:] and unflatten(vi, vec_of_floats) to read/write those vectors of floats from/to the internal storage of the VarInfo, where the variables are linked if link has been called on vi at some point in its history. Then, whenever we need the unlinked versions of variables, such as when calling logpdf, we convert the internal linked storage to the unlinked form.

An alternative approach would be to say that VarInfo only ever stores unlinked variables. Rather than calling link on vi once at the beginning of sampling and then using getindex/setindex, you would call getindex_linked and setindex_linked (or unflatten_linked) whenever you wanted linked variables, and the linking would happen on the fly inside that call.

Needs linked                                Needs unlinked 
                                                           
┌──────────┐         ┌───────────┐          ┌─────────────┐
│          │         │           │          │             │
│ Sampler  │ ◄─────► │  VarInfo  │  ◄─────► │ logpdf call │
│          │         │           │          │             │
└──────────┘   ▲     └───────────┘     ▲    └─────────────┘
               │                       │                   
               │                       │                   
                                                           
          Alternative:             Current:                
          (Un)link here            (Un)link here           

I'm not sure if the alternative is better, but I do see some benefits to it:

  • VarInfos would carry less state. Currently, when you call getindex_internal or unflatten, you need to know (or check) whether this VarInfo has been linked at some point in its history to know what you'll get.
  • Sampler interface code would be explicit, at the call site, about whether it wants linked variables or not, which would make it a bit easier to understand.
  • VarInfo would be simpler. We could drop the istrans flags, and never worry about things like having some variables linked and others not. Note that we would still need to carry the linking transformations in the VarInfo, to know how to transform when someone calls e.g. getindex_linked.

This could also help us move towards something like SimpleVarInfo or the @parameters struct of https://github.com/TuringLang/Turing.jl/issues/2492, where the storage in VarInfo would more closely match how the model sees the variables. I haven't thought about this carefully, but I could imagine this leading to some performance gains and maybe code simplification.

This is currently more an idea to be discussed than a proposal to implement. Any thoughts? Downsides to linking at getindex/setindex time?

mhauru, Mar 10 '25 11:03

I don't recall a good reason for the current design. It could be something inherited from the good old days. One vague motivation was that bijectors used to be constructed on the fly at the logpdf call site, so that was the only place where linking could be performed.

@torfjelde @sunxd3, are you aware of any specific reasons for preferring the current approach?

yebai, Mar 11 '25 18:03

I don't know of a reason for it either. (JuliaBUGS always stores the untransformed values, but JuliaBUGS is simpler than DynamicPPL too.)

sunxd3, Mar 12 '25 07:03

There are ofc efficiency and numerical accuracy concerns with this (see below), but the main issue, which I don't quite see a way around atm, is variables whose constraints, and hence transformations, require knowledge of the realizations of other variables in the model.

An example is the demo_dynamic_support model in TestUtils. There, the support of one variable depends on the value the other variable takes, so the transformation used in (inv)linking can change between evaluations.

How would we handle this if we don't save the transformed variables in the varinfo?

Regarding efficiency and numerical accuracy:

We have to do two things with transformed variables:

  1. Compute constrained -> unconstrained and vice versa.
  2. Compute logabsdetjac correction to the logpdf.

What you're suggesting would lead to at least one more call of (1), since we would now have to do unconstrained (in the sampler) -> constrained (in the logpdf) -> unconstrained (in the sampler).
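A toy count of that extra conversion, in plain Julia. The model (a single x ~ Exponential(1), written out by hand as logpdf(x) = -x with link log) and all function names are assumptions for illustration, not DynamicPPL internals. One sampler step writes the unconstrained y, evaluates the log density, and reads y back.

```julia
# Toy count with hypothetical names; this is not DynamicPPL's API.
# Assumed model: x ~ Exponential(1), so logpdf(x) = -x for x > 0, with link
# y = log(x). Since x = exp(y), the log-Jacobian correction is log|dx/dy| = y.
const NCALLS = Ref(0)  # number of link/invlink applications
link(x) = (NCALLS[] += 1; log(x))
invlink(y) = (NCALLS[] += 1; exp(y))

# Current design: storage holds the linked value y.
function step_current(y)
    stored = y                  # unflatten(vi, y): stored as-is
    x = invlink(stored)         # the logpdf call needs constrained x
    lp = -x + stored            # log density in y-space
    return lp, stored           # vi[:]: hand y back with no conversion
end

# Alternative design: storage holds the constrained value x.
function step_alternative(y)
    stored = invlink(y)         # setindex_linked: convert on write
    lp = -stored + y            # logpdf reads x directly; Jacobian uses y
    return lp, link(stored)     # getindex_linked: convert back on read
end

# Usage: count conversions for one step of each design.
NCALLS[] = 0; lp_cur, y_cur = step_current(0.5); n_cur = NCALLS[]
NCALLS[] = 0; lp_alt, y_alt = step_alternative(0.5); n_alt = NCALLS[]
```

Under these assumptions the current design applies one transform per step and the alternative applies two. The value read back in the alternative is log(exp(y)), which need not be bit-identical to y.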

Moreover, with_logabsdet_jacobian can be used to compute both (1) and (2) in one go, which is more efficient than doing them separately.
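A minimal sketch of why computing the transform and its log-Jacobian together can be cheaper, for a hypothetical (0, 1)-valued variable with inverse link x = logistic(y), where log|dx/dy| = log(x) + log(1 - x). The helper names are made up; in the Julia ecosystem the combined operation is the one expressed by with_logabsdet_jacobian (from ChangesOfVariables.jl), which returns the transformed value and the log-abs-det-Jacobian as a pair.

```julia
# Toy sketch with hypothetical helper names: inverse link for a (0, 1)-valued
# variable, x = logistic(y), with log|dx/dy| = log(x) + log(1 - x).
# (Not hardened against overflow/underflow for very large |y|.)
logistic(y) = 1 / (1 + exp(-y))

# Separate calls: the forward map is evaluated twice.
invlink(y) = logistic(y)
function logabsdetjac_invlink(y)
    x = logistic(y)                  # recomputed just for the Jacobian
    return log(x) + log1p(-x)
end

# Combined call, in the spirit of with_logabsdet_jacobian: one evaluation of
# the forward map, reused for the Jacobian term.
function invlink_with_logjac(y)
    x = logistic(y)
    return x, log(x) + log1p(-x)
end

x1, lj1 = invlink(0.3), logabsdetjac_invlink(0.3)  # two logistic evaluations
x2, lj2 = invlink_with_logjac(0.3)                 # one logistic evaluation
```

Both routes give identical numbers here; the combined one simply avoids recomputing the shared intermediate, which matters more for costly transforms and under AD.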

For numerical accuracy, I believe we ran into some issues with LKJBijector or similar due to unnecessary passes through constrained -> unconstrained. I don't remember the details and so it might be unrelated, but issues on this should be floating around somewhere.

Both of the above are ofc made worse once AD is involved too.

Whether it's worth it or not is a different question.

torfjelde, Mar 12 '25 11:03

Related: https://github.com/TuringLang/Turing.jl/issues/1583

yebai, Mar 12 '25 21:03

Thanks @torfjelde, that's really helpful. I'll need to give both dynamic models and with_logabsdet_jacobian some thought. For dynamic models I wonder if this could be handled by storing in the VarInfo not a bijector, but a function that returns a bijector given the values of the other variables. This might also require storing the order in which variables appear in the model (related to #833), so that the aforementioned function would depend only on variables that came before. This might get clunky.
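A rough sketch of the "function that returns a bijector" idea, using the demo_dynamic_constraint model that appears further down this thread (m ~ Normal(), x ~ truncated(Normal(), m, Inf)). Everything here is hypothetical: Transform, ORDER, FACTORIES, link_all and invlink_all are not DynamicPPL API, and the transform log(x - m) is an assumed hand-written choice. Each factory only receives the constrained values of variables earlier in ORDER, which is where the stored variable order comes in.

```julia
# Hypothetical sketch, not DynamicPPL's API. For the model
#     m ~ Normal(); x ~ truncated(Normal(), m, Inf)
# each variable stores a factory that, given the constrained values of the
# variables before it, returns that variable's transform.
struct Transform
    link::Function     # constrained -> unconstrained
    invlink::Function  # unconstrained -> constrained
end

const ORDER = [:m, :x]  # order in which variables appear in the model
const FACTORIES = Dict{Symbol,Function}(
    :m => prev -> Transform(identity, identity),         # support is all of ℝ
    :x => prev -> Transform(x -> log(x - prev[:m]),      # support is (m, ∞)
                            y -> prev[:m] + exp(y)),
)

# Walk the variables in model order so each factory only sees earlier ones.
function invlink_all(ys::AbstractVector{<:Real})
    seen = Dict{Symbol,Float64}()
    xs = Float64[]
    for (name, y) in zip(ORDER, ys)
        x = FACTORIES[name](seen).invlink(y)
        push!(xs, x)
        seen[name] = x
    end
    return xs
end

function link_all(xs::AbstractVector{<:Real})
    seen = Dict{Symbol,Float64}()
    ys = Float64[]
    for (name, x) in zip(ORDER, xs)
        push!(ys, FACTORIES[name](seen).link(x))
        seen[name] = x
    end
    return ys
end

xs = invlink_all([1.0, 0.0])  # [1.0, 2.0]: m = 1.0, x = 1.0 + exp(0.0)
link_all(xs)                  # [1.0, 0.0]: log(2.0 - 1.0) = 0.0
```

Even in this tiny case the bookkeeping (ordering, passing earlier values around) is visible, which hints at the clunkiness mentioned above.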

mhauru, Mar 13 '25 10:03

This has also made me realise that what we do now with dynamic models is unintuitive as well. Namely,

using DynamicPPL, Distributions

@model function demo_dynamic_constraint()
    m ~ Normal()
    # The support of x is (m, ∞), so it moves whenever m does.
    x ~ truncated(Normal(), m, Inf)
end

when sampled in unconstrained space is such that the value of x may change even if the sampler does not change it, if m is changed. I've probably had this same realisation around three times since I started on the project, but I never seem to internalise it, and I think that's because I don't fundamentally understand what a VarInfo really is.
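A tiny Base-Julia illustration of this effect, using a hand-written transform for x rather than whatever DynamicPPL/Bijectors would construct: with support (m, ∞), an assumed link is y = log(x - m) with inverse x = m + exp(y). Keeping y fixed and moving only m moves x.

```julia
# Hand-written transform for x ~ truncated(Normal(), m, Inf), whose support is
# (m, ∞). This is an assumed choice, not necessarily what Bijectors.jl builds.
link_x(x, m) = log(x - m)
invlink_x(y, m) = m + exp(y)

y = link_x(2.0, 1.0)          # 0.0: the sampler's coordinate for x when m = 1.0
x_before = invlink_x(y, 1.0)  # 2.0
x_after = invlink_x(y, 1.5)   # 2.5: the sampler moved only m, yet x changed
```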

mhauru, Mar 13 '25 10:03

> For dynamic models I wonder if this could be handled by storing in the VarInfo not a bijector, but a function that returns a bijector given the values of the other variables

Hmm, not sure I'm fully on board with this being simpler than allowing the internal representation of VarInfo to be different from that of the model 😕

> when sampled in unconstrained space is such that the value of x may change even if the sampler does not change it, if m is changed

This is what you want to happen, though, right?

torfjelde, Mar 13 '25 12:03

> when sampled in unconstrained space is such that the value of x may change even if the sampler does not change it, if m is changed

> This is what you want to happen, though, right?

Unfortunately, no. This is desirable from a user perspective, but it creates difficulties for HMC: warm-up in HMC generally doesn't work very well when the support of model parameters changes during inference. That said, we can keep this feature, but it would be better to emit an informational message warning the modeller to change their model to avoid such uses.

yebai, Mar 13 '25 16:03

> Unfortunately, no.

I meant from the perspective of functionality; this will ofc make things different for samplers, but one could definitely implement samplers that take these things into account, and so it seems a bit drastic to not allow such a model to be used at all.

But I agree with informing the user 👍 I think this is the sort of case where people should only be doing it if they really know what they're doing.

torfjelde, Mar 13 '25 17:03

> This is what you want to happen, though, right?

I wouldn't say that it's what I want, but I can't think of a better alternative. I find that this sort of stuff makes it hard to reason about how things work, but maybe it's irreducible complexity (rather than reducible, i.e. fixable with a better design).

mhauru, Mar 13 '25 17:03