DynamicPPL.jl

Enhance wrapped distributions

Open: alyst opened this issue 2 years ago • 17 comments

Add a basic WrappedDistribution type for NoDist and NamedDist and teach them a few methods, such as length() and bijector(). I discovered that these methods are missing when trying to call

```julia
DynamicPPL.tilde_assume!!(context, NoDist(prior), @varname(v), varinfo)
```

where prior was a Product multivariate distribution. With the changes in this PR, the call works.

alyst, Jun 25 '22

I'm slightly worried about the additional complexity introduced by the new abstract type and functions such as wrapped_dist and wrapped_dist_type. Can't we just add whatever definition was missing?

In general, both distributions are only used internally in DynamicPPL and hence only the parts of the Distributions API relevant for DynamicPPL are implemented. What exactly was missing? Did you actually try to call tilde_assume!! directly?

devmotion, Jun 25 '22

> I'm slightly worried about the additional complexity introduced by the new abstract type and functions such as wrapped_dist and wrapped_dist_type. Can't we just add whatever definition was missing?

It's just one abstract type and a few standard boilerplate definitions around it (wrapped_distr() etc.). On the other hand, it avoids duplicating method definitions such as length(). I see your point, but I think both approaches have advantages in terms of maintenance. Before this patch I got errors about length() and bijector() being missing for NoDist, and I can see how more methods from the Distributions API might be required in the future; this PR makes it easier to add them.
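To illustrate the trade-off being discussed, here is a minimal sketch of the abstract-wrapper pattern in plain Julia. The names AbstractWrappedDist, wrapped_dist, ToyNoDist and ToyNamedDist are hypothetical stand-ins chosen for this example, not the definitions from this PR or from DynamicPPL, and plain arrays stand in for distributions so the sketch runs without Distributions.jl.

```julia
# Sketch of the "one abstract wrapper type" approach.
# AbstractWrappedDist, wrapped_dist, ToyNoDist and ToyNamedDist are
# hypothetical stand-ins, not DynamicPPL's actual definitions.
abstract type AbstractWrappedDist end

# Each wrapper only needs a way to reach the object it wraps.
wrapped_dist(d::AbstractWrappedDist) = d.dist

# Shared methods are then written once, on the abstract type.
Base.length(d::AbstractWrappedDist) = length(wrapped_dist(d))
Base.size(d::AbstractWrappedDist) = size(wrapped_dist(d))

struct ToyNoDist{D} <: AbstractWrappedDist
    dist::D
end

struct ToyNamedDist{D,N} <: AbstractWrappedDist
    dist::D
    name::N
end

# A plain array stands in for a multivariate distribution here;
# length and size are forwarded to it through wrapped_dist.
nodist = ToyNoDist([0.0, 1.0, 2.0])
named = ToyNamedDist([1.0, 2.0], :v)
println(length(nodist))  # 3
println(length(named))   # 2
```

Forwarding one more method (say, bijector) to every wrapper would be a single extra line here, which is the maintenance benefit argued above; the cost raised in review is that the abstract type and wrapped_dist become additional internal API to maintain.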

> Did you actually try to call tilde_assume!! directly?

Yes. I'm not using the @model macro; I'm using DynamicPPL directly to have more control and flexibility when generating statistical models.

alyst, Jun 25 '22

> It's just one abstract type and a few standard boilerplate definitions around it (wrapped_distr() etc.). On the other hand, it avoids duplicating method definitions such as length(). I see your point, but I think both approaches have advantages in terms of maintenance. Before this patch I got errors about length() and bijector() being missing for NoDist, and I can see how more methods from the Distributions API might be required in the future; this PR makes it easier to add them.

I can see that point, but because of the history of DynamicPPL, and VarInfo in particular, I'm probably biased against adding types and functionality that might be useful at some point in the future: at this point it is really unclear which methods in varinfo.jl are needed, which are useful, and which should be removed. That even motivated a complete refactor and rewrite, but it is still messy.

So my suggestion would be to:

  • add an MWE that currently fails to the tests,
  • and add only the missing definitions needed to make that test pass.
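The suggested workflow might look like the following minimal sketch, using Julia's Test standard library. ToyNoDist is a hypothetical stand-in for NoDist, and length is only an example of a missing method; the real MWE would presumably exercise tilde_assume!! with DynamicPPL's own types.

```julia
using Test

# Hypothetical stand-in for NoDist; not DynamicPPL's definition.
struct ToyNoDist{D}
    dist::D
end

# Second step: only the method the failing test needs,
# defined directly on the concrete type (no extra abstract type).
Base.length(d::ToyNoDist) = length(d.dist)

# First step: a minimal test that fails with a MethodError
# until the length method above is defined.
@testset "ToyNoDist forwards length" begin
    @test length(ToyNoDist([0.0, 1.0])) == 2
end
```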

> Did you actually try to call tilde_assume!! directly?

It would be interesting to know if that can be reproduced with a regular @model as well, or if there is some problem with how tilde_assume!! was called.

devmotion, Jun 25 '22

Can we at least add tests for every new function and type, and fix the CI errors?

I also think it would be nice to see what actually went wrong and what has to be fixed.

devmotion, Jun 27 '22

Oh, it seems @torfjelde may have already fixed the problems in https://github.com/TuringLang/DynamicPPL.jl/pull/360/commits/0f9765bda684b27202982cf95d11e8de07304f62?

devmotion, Jun 27 '22

> Oh it seems maybe @torfjelde has already fixed the problems in ...

It doesn't define the bijector for NoDist though.

I've added an MWE to the tests.

alyst, Jun 27 '22

> It doesn't define the bijector for NoDist though.

I actually deliberately didn't define it, because I'm uncertain whether we ever want to hit this code path. NoDist should represent "don't do anything with this variable", but if we at some point hit bijector(nodist), that indicates we might be trying to compute the logabsdetjac correction, which actually shouldn't be included in the log-joint computation :confused:

So are we certain adding this implementation isn't doing something silently incorrect?

EDIT: See https://github.com/TuringLang/DynamicPPL.jl/pull/414#discussion_r908428665
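To make the concern concrete, here is a worked equation under two stated assumptions: that bijector(NoDist(d)) would simply forward to bijector(d), and that NoDist contributes zero to the log density, as "don't do anything with this variable" implies. For a variable x mapped to unconstrained space by y = b(x), the change-of-variables formula adds a log-Jacobian term to the log density; the second line takes, as an example, a positive-support prior whose bijector is b = log applied elementwise:

```math
\begin{aligned}
\log p_Y(y) &= \log p_X\big(b^{-1}(y)\big) + \log\left|\det J_{b^{-1}}(y)\right|, \\
b = \log,\ b^{-1} = \exp:\quad \log\left|\det J_{\exp}(y)\right| &= \sum_i y_i .
\end{aligned}
```

With NoDist the first term is dropped, but if the forwarded bijector were used, the second term would not be, so the log-joint would silently pick up sum_i y_i for a variable that is supposed to contribute nothing.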

torfjelde, Jun 28 '22

I was just looking at https://github.com/TuringLang/Turing.jl/blob/master/src/stdlib/distributions.jl for completely unrelated reasons, and discovered some definitions of Bijectors.logpdf_with_trans(::NoDist, x, t) :open_mouth:

Regardless of whether they are useful etc., this seems like one of the worst places to hide them :smile:

devmotion, Jun 29 '22

> I was just looking at https://github.com/TuringLang/Turing.jl/blob/master/src/stdlib/distributions.jl for completely unrelated reasons, and discovered some definitions of Bijectors.logpdf_with_trans(::NoDist, x, t) :open_mouth:
>
> Regardless of whether they are useful etc., this seems like one of the worst places to hide them :smile:

Those shouldn't be there :flushed:

torfjelde, Jun 30 '22

bors try

alyst, Oct 01 '22

:lock: Permission denied

Existing reviewers: click here to make alyst a reviewer

bors[bot], Oct 01 '22

bors try

ParadaCarleton, Dec 19 '22

try

Build failed:

bors[bot], Dec 19 '22

bors try

ParadaCarleton, Dec 19 '22

@alyst Very sorry for the delay; it looks like the tests aren't passing at the moment.

ParadaCarleton, Dec 19 '22

try

Build failed:

bors[bot], Dec 19 '22

Maybe I missed something (I haven't checked this PR for a while), but I think @torfjelde's and my concerns above are still valid?

devmotion, Dec 21 '22