
Adding StatsBase.predict to the API

Open sethaxen opened this issue 2 years ago • 7 comments

In Turing, StatsBase.predict is overloaded to dispatch on DynamicPPL.Model and MCMCChains.Chains (https://github.com/TuringLang/Turing.jl/blob/d76d914231db0198b99e5ca5d69d80934ee016b3/src/inference/Inference.jl#L532-L564). This effectively performs batch prediction, conditioning the model on each draw in the chain and calling rand on the model. We also want to do the same thing for InferenceData (see #465).

It would be convenient if StatsBase.predict were added to the DynamicPPL API; it's already an indirect dependency of this package. As suggested by @devmotion in https://github.com/TuringLang/DynamicPPL.jl/pull/465#discussion_r1111647150, its default implementation could just call rand on a conditioned model:

```julia
StatsBase.predict(rng::AbstractRNG, model::DynamicPPL.Model, x) = rand(rng, condition(model, x))
StatsBase.predict(model::DynamicPPL.Model, x) = predict(Random.default_rng(), model, x)
```
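Assuming those two definitions were in place, usage might look like the following sketch (the `demo` model and the parameter value are hypothetical, purely for illustration):

```julia
using Random
using DynamicPPL, Distributions
import StatsBase

# The proposed default implementation, as sketched above:
StatsBase.predict(rng::Random.AbstractRNG, model::DynamicPPL.Model, x) =
    rand(rng, condition(model, x))
StatsBase.predict(model::DynamicPPL.Model, x) =
    StatsBase.predict(Random.default_rng(), model, x)

# Hypothetical toy model, only for illustration:
@model function demo()
    μ ~ Normal(0, 1)
    y ~ Normal(μ, 1)
end
model = demo()

# Condition on a draw of μ and sample the remaining variable y:
yhat = StatsBase.predict(Random.Xoshiro(42), model, (μ = 0.5,))
# yhat is (by default) a NamedTuple containing a draw of y
```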

sethaxen avatar Feb 20 '23 09:02 sethaxen

Maybe this could even be part of AbstractPPL and be defined on AbstractPPL.AbstractProbabilisticProgram: condition is part of its API; only rand is not clearly specified there yet (it probably should be anyway).

devmotion avatar Feb 20 '23 10:02 devmotion

Yeah, makes sense.

sethaxen avatar Feb 20 '23 10:02 sethaxen

I'm down with this, but it's worth pointing out that just calling rand(rng, condition(model, x)) is probably not the greatest idea, as it defaults to returning a NamedTuple, which can blow up compilation times for many models.

And regarding adding this to APPL: we need to propagate that change back to v0.5 too, because v0.6 is currently not compatible with DPPL (see #440).

torfjelde avatar Feb 20 '23 10:02 torfjelde

> I'm down with this, but it's worth pointing out that just calling rand(rng, condition(model, x)) is probably not the greatest idea, as it defaults to returning a NamedTuple, which can blow up compilation times for many models.

Would rand(rng, OrderedDict, condition(model, x)) be the way to go then?

sethaxen avatar Feb 20 '23 11:02 sethaxen

> Would rand(rng, OrderedDict, condition(model, x)) be the way to go then?

For maximal model compatibility, yes. But you do of course take a performance hit as a result :confused:
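If OrderedDict were chosen as the return type, the default discussed above would change only in the rand call; a minimal sketch, assuming the same hypothetical predict method:

```julia
using Random
using OrderedCollections: OrderedDict
using DynamicPPL
import StatsBase

# Same proposed default, but requesting an OrderedDict return type,
# which sidesteps NamedTuple-related compilation blowup at some runtime cost:
StatsBase.predict(rng::Random.AbstractRNG, model::DynamicPPL.Model, x) =
    rand(rng, OrderedDict, condition(model, x))
```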

torfjelde avatar Feb 20 '23 11:02 torfjelde

Hrm. Maybe then predict should use a NamedTuple if x is a NamedTuple (imperfect, because you can have few parameters but many data points). Or provide an API for specifying the return type, like rand does (though supporting two optional positional parameters, rng and T, complicates the interface).

sethaxen avatar Feb 20 '23 11:02 sethaxen

> Or provide an API for specifying the return type, like rand does (though supporting two optional positional parameters, rng and T, complicates the interface)

Adding T to predict (with some default) would be in line with our API for rand, though - the type T can already be specified there.
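Concretely, mirroring rand's optional type argument might look something like this (signatures hypothetical, not settled API):

```julia
using Random
using DynamicPPL
import StatsBase

# Optional return-type argument T, mirroring rand(rng, T, model):
StatsBase.predict(rng::Random.AbstractRNG, ::Type{T}, model::DynamicPPL.Model, x) where {T} =
    rand(rng, T, condition(model, x))
StatsBase.predict(::Type{T}, model::DynamicPPL.Model, x) where {T} =
    StatsBase.predict(Random.default_rng(), T, model, x)

# Defaulting T to NamedTuple, matching rand's default:
StatsBase.predict(rng::Random.AbstractRNG, model::DynamicPPL.Model, x) =
    StatsBase.predict(rng, NamedTuple, model, x)
StatsBase.predict(model::DynamicPPL.Model, x) =
    StatsBase.predict(Random.default_rng(), model, x)
```

Callers wanting maximal model compatibility could then opt into predict(OrderedDict, model, x) while the default stays fast for simple models.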

devmotion avatar Feb 20 '23 11:02 devmotion