Should we have a context to indicate that we're not performing inference?
Sometimes we know that we're not performing inference in a model, in which case it can be useful to have the model perform some additional computations.
For example, people often end up using `Dirac` to save quantities that are deterministic when conditioned on the variables in a `Chains` (e.g. https://github.com/TuringLang/Turing.jl/issues/2059), but this can then cause issues since samplers will also consider these as random variables to be sampled. `generated_quantities` provides a partial solution to this, but:

1. It doesn't have the same behavior as `~`, which people often want.
2. It requires an additional pass over the chain after sampling.
Addressing (2) is non-trivial, as it requires changing both the internals of varinfos and how transitions are constructed in Turing.jl.
But addressing (1) could be done by adding a simple `NotInferringContext` + a macro `@isinferring`, which simply translates into a call `isinferring(__context__, __varinfo__)` or something. We could then either provide an alternative to `generated_quantities` which makes use of this, or we could include this in `generated_quantities` somehow (I'm leaning towards the former).
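A minimal sketch of what this could look like. All names here (`NotInferringContext`, `isinferring`, `childcontext`) follow the proposal above but are hypothetical, and the context types are self-contained mocks rather than DynamicPPL's actual API:

```julia
# Mock of a DynamicPPL-style context stack, for illustration only.
abstract type AbstractContext end

struct DefaultContext <: AbstractContext end

# Wrapper context signalling "we are not performing inference".
struct NotInferringContext{C<:AbstractContext} <: AbstractContext
    childcontext::C
end
NotInferringContext() = NotInferringContext(DefaultContext())

childcontext(::DefaultContext) = nothing
childcontext(ctx::NotInferringContext) = ctx.childcontext

# Walk the context stack; we are "inferring" unless a
# `NotInferringContext` occurs somewhere in it.
isinferring(::Nothing) = true
isinferring(::NotInferringContext) = false
isinferring(ctx::AbstractContext) = isinferring(childcontext(ctx))
```

The proposed `@isinferring` macro would then just expand into a call like `isinferring(__context__)` inside the model body.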
An alternative way of addressing (1) is to just provide a convenient way to convert a `Matrix{<:NamedTuple}` into a `Chains`, which I have existing code to do, but the annoying bit with this is that it's "asking a lot" of the user (they need to explicitly filter out the variables in the return value that they don't want in the chain, and then also make sure to return a `NamedTuple`, etc.). Hence I think I'm leaning towards supporting both approaches :shrug:
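For concreteness, the conversion's core step could look like the following. This is not the existing code referred to above, just a sketch assuming scalar-valued `NamedTuple`s; the name `flatten_namedtuples` is invented, and the final `MCMCChains.Chains(vals, names)` call is left aside:

```julia
# Hypothetical sketch: flatten a matrix of scalar-valued NamedTuples,
# sized (iterations, chains), into parameter names plus a 3-d array of
# shape (iterations, parameters, chains) -- the layout the
# `MCMCChains.Chains` constructor expects.
function flatten_namedtuples(m::Matrix{<:NamedTuple})
    names = collect(keys(first(m)))
    niter, nchains = size(m)
    vals = Array{Float64}(undef, niter, length(names), nchains)
    for c in 1:nchains, i in 1:niter, (j, k) in enumerate(names)
        vals[i, j, c] = getfield(m[i, c], k)
    end
    return names, vals
end
```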
Bumping this in the context of prediction.
It was attempted in #589 but got downvoted for now, as it was deemed undesirable to add more macros (see the discussion in that PR).
A thing to note is that if you want to do within-chain parallelism, then this is a handy feature.
A reason is that it's much easier to use `@addlogprob!` in patterns with (say) `@threads` over a loop of data observations; but then you sacrifice the generative quality of the Turing PPL. Having a "Not doing inference right now" context squares that circle.
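The pattern in question looks roughly like the following. This is a self-contained sketch of only the likelihood computation one would feed to `@addlogprob!` (the function names are invented, and the Normal logpdf is hand-rolled so the snippet has no dependencies); in a real model this would live inside `@model` and replace the observation `~` statements:

```julia
using Base.Threads

# Hand-rolled Normal logpdf, to keep the sketch dependency-free.
normal_logpdf(x, μ, σ) = -((x - μ) / σ)^2 / 2 - log(σ) - log(2π) / 2

# Threaded log-likelihood over the observations. Inside a Turing model
# one would call `@addlogprob! threaded_loglik(data, μ, σ)` instead of
# writing `data[i] ~ Normal(μ, σ)` -- faster, but no longer generative.
function threaded_loglik(data, μ, σ)
    partials = zeros(length(data))
    @threads for i in eachindex(data)
        partials[i] = normal_logpdf(data[i], μ, σ)
    end
    return sum(partials)
end
```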
Indeed. I've had to use similar approaches in the past due to the computational complexity of the model.
@yebai worth revisiting #589 ?
> over a loop of data observations; but then you sacrifice the generative quality of the Turing PPL. Having a "Not doing inference right now" context squares that circle.
Not sure I understand this. Can you clarify what you mean by "sacrifice the generative quality of the Turing PPL"? @SamuelBrand1
I believe @SamuelBrand1 means that you can drop the usage of `~`, which loses the niceties that it brings, e.g. sampling from the prior, in favour of the faster `@addlogprob!` path during inference (like we did extensively in the Epimap project to make it work @yebai )
Just saw this, but yes exactly!
@torfjelde and I had some further discussion today. While I agree this is a useful feature, I am not yet convinced there is a safe and robust design.
I suggest that we use something more specific than post-inference processing, e.g. `@evaluating_logprob`, to determine whether the current evaluation is for computing the log probability. This `@evaluating_logprob` approach has the advantage of not being closely coupled to a specific inference algorithm or post-inference analysis.
Suggestions on a better design are welcome!
I think we could use something like `Turing.@evaluation_mode()` to determine the current evaluation mode, e.g. log density, prediction, etc.
See, also, https://github.com/TuringLang/Turing.jl/issues/2239
Cc @devmotion @sunxd3 @mhauru for more comments.
Cc @seabbs
Tried to summarize some of my thoughts on the topic.
There are two issues here:
- We don't have well-defined "modes" of operation, and hence can't properly answer the question "are we performing post-inference analysis or inferring?".
- What is the best syntax to enable all these usages?
Modes of operation
- Back in the day, we were always sampling, and we couldn't really tell whether we were sampling to perform inference or if we were just sampling from the prior to, say, `predict` or do something else.
- Now we have different "modes" in the sense that we're not always sampling, but we still don't have a clear-cut definition of which is which. For example, some sampler implementations still make use of `SamplingContext` while others make use of `LogDensityFunction` without a `SamplingContext`.
This makes it so that it's difficult to answer the question "are we doing some post-inference analysis or are we inferring?"
But for most of the applications we have in mind, I think it's sufficient to only determine that we're, say, performing parameter extraction, i.e. evaluating with `ValuesAsInModelContext`, which is currently only used in the very last "layer" before returning the samples to the user.
Similarly, for methods such as `generated_quantities` and `predict`, this should be sufficient.
In such a case, we could have
```julia
# NOTE: `has_context_type` is not implemented but is easy enough to do
is_post_inference(context::AbstractContext) = has_context_type(context, ValuesAsInModelContext)
```
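The `has_context_type` mentioned in the NOTE could indeed be a simple recursive walk down the context stack. A self-contained sketch with mock context types (`LeafContext`/`WrapperContext` are invented stand-ins, not DynamicPPL's actual API):

```julia
abstract type AbstractContext end

# Mock leaf and wrapper contexts standing in for DynamicPPL's.
struct LeafContext <: AbstractContext end
struct WrapperContext{C<:AbstractContext} <: AbstractContext
    childcontext::C
end

childcontext(::LeafContext) = nothing
childcontext(ctx::WrapperContext) = ctx.childcontext

# True if a context of type `T` occurs anywhere in the stack.
has_context_type(::Nothing, ::Type) = false
has_context_type(ctx::AbstractContext, T::Type) =
    ctx isa T || has_context_type(childcontext(ctx), T)
```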
But doing something completely general like `@evaluation_mode` seems non-trivial, as we don't have clear definitions of what the different modes mean :confused:
Best syntax
Generally two categories of approaches:
- Macro-based: super-simple to implement because it doesn't have to touch all the internals. Example: `@iid xs ~ filldist(Normal(), n)`.
- Wrapper-based: provide wrappers of distributions, e.g. IID would be `xs ~ iid(Normal(), n)`, and then resolve this in the tilde-pipeline as demonstrated in #595 (though we would need a "generalization" of that PR, as it's specifically about handling latent distributions).
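A rough sketch of what the wrapper itself might look like. Everything here is hypothetical (the `IID` struct and `assume_iid` are invented names); the real resolution would happen inside the tilde-pipeline along the lines of #595:

```julia
# Hypothetical `iid` wrapper: marks a distribution as "n iid copies",
# so the tilde-pipeline can expand `xs ~ iid(dist, n)` into per-element
# statements `xs[i] ~ dist`, while `condition` and friends can still
# treat `xs` as a single variable.
struct IID{D}
    dist::D
    n::Int
end
iid(dist, n) = IID(dist, n)

# Sketch of the lowering step: given a function that draws one sample
# from the wrapped distribution, produce the n iid draws.
assume_iid(draw_one, wrapper::IID) = [draw_one(wrapper.dist) for _ in 1:wrapper.n]
```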
Macro-based
Macro-based methods have the benefit that we can defer the resolution of varnames, etc. to the `@model` macro, i.e. statements of the form `@iid xs ~ ...` directly result in statements of the form `xs[i] ~ ...`.
Buuut it's a) ugly, and b) means that `xs` is not always "in the model" (in the sense that we end up just seeing `xs[1] ~ ...`, etc. rather than `xs ~ ...`, which can cause confusion with other functionalities, e.g. `condition`; see below).
Wrapper-based
This is a bit more annoying to implement, but if done right it will be a) nicer syntactically, and b) allow better composition with the rest of the code-base.
Hence, this is probably the way to go.
Further issues / questions
- [ ] Representing something like `iid` does mean that there's some confusion about how `condition` works. For example, if I have something like `xs ~ iid(Normal(), n)`, do I `condition` `xs` using `model | (xs = ..., )` or do I do it on a per-element basis, i.e. `model | (@varname(xs[1]) => ..., ...)`?
  - I'm of the opinion of the former, but we need to document this / warn users in appropriate places. And this would work nicely with the wrapper-based approach.
I buy @torfjelde's reasoning here. I also agree that the former of the two conditioning syntaxes is better. For one, `xs[1]` doesn't appear in the program. Also, I think in the case of an iid wrapper, all the elements should all be conditioned or all not. Checking this and forcing the user to do so is just not a good experience, in my opinion.
Of the two proposals, the wrapper-based approach seems like it would work better for our use case. However, I am a little bit confused why it is needed at all for things that in theory already return iid components (i.e. `filldist`, `.~`). Couldn't these instead be modified to internally check the context and change behaviour accordingly (i.e. automatically using an `iid` wrapper)?
> Also, I think in the case of an iid wrapper, all the elements should all be conditioned or all not.
This seems like it would be very limiting?
> However, I am a little bit confused why it is needed at all for things that in theory already return iid components (i.e. `filldist`, `.~`).
`filldist` isn't just used for IID variables unfortunately 😕
> Couldn't these instead be modified to internally check the context and change behaviour accordingly (i.e. automatically using an `iid` wrapper)?
ATM there's no "context" to indicate this. As in, if you write

```julia
x ~ filldist(Normal(), 10)
```
in your model, Turing.jl / DynamicPPL.jl doesn't know whether this is a) 10 iid variables, or b) a 10-dim multivariate variable implemented in an efficient manner. Nor is there realistically a way for us to check this, unfortunately 😕 Hence we need something explicit, e.g. the `iid` wrapper as discussed above, which will be like `filldist` but also very clear about what it represents.
> `filldist` isn't just used for IID variables unfortunately 😕
Ah I guess I was confused by this: https://turinglang.org/docs/tutorials/docs-13-using-turing-performance-tips/index.html#special-care-for-zygote-and-tracker
> a) 10 iid variables, or b) a 10-dim multivariate variable implemented in an efficient manner.
My expectation was that `filldist` did know this internally and that the switch could be put inside `filldist`, but it sounds like that isn't the case.
Closed in favour of https://github.com/TuringLang/DynamicPPL.jl/issues/810