DynamicPPL.jl
Enhance wrapped distributions
Add a basic WrappedDistribution type for NoDist and NamedDist, and teach them a few tricks like length() and bijector().
I discovered that these methods were missing when trying to call
DynamicPPL.tilde_assume!!(context, NoDist(prior), @varname(v), varinfo)
where prior was a Product multivariate distribution. With the changes in this PR, it works.
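To make the proposal concrete, here is a minimal sketch of the shared-supertype idea in plain Julia. It deliberately does not use Distributions.jl or DynamicPPL: ToyDist, ToyProduct, AbstractToyWrapper, ToyNoDist and ToyNamedDist are hypothetical stand-ins, and only the accessor name wrapped_dist is borrowed from this thread (its body here is an assumption). The point is that length() is forwarded once on the abstract type, so every wrapper inherits it without its own definition.

```julia
# Hypothetical stand-ins, NOT the real Distributions.jl / DynamicPPL types.
abstract type ToyDist end

# Stand-in for a multivariate product distribution with `n` components.
struct ToyProduct <: ToyDist
    n::Int
end
Base.length(d::ToyProduct) = d.n

# Shared supertype: by convention every wrapper stores its payload in `dist`.
abstract type AbstractToyWrapper <: ToyDist end
wrapped_dist(w::AbstractToyWrapper) = w.dist

# Forward once on the supertype; all wrappers inherit this method.
Base.length(w::AbstractToyWrapper) = length(wrapped_dist(w))

struct ToyNoDist{D<:ToyDist} <: AbstractToyWrapper
    dist::D
end

struct ToyNamedDist{D<:ToyDist} <: AbstractToyWrapper
    dist::D
    name::Symbol
end

prior = ToyProduct(3)
println(length(ToyNoDist(prior)))         # 3
println(length(ToyNamedDist(prior, :v)))  # 3
```

Adding another Distributions-style method to all wrappers then takes one forwarding line on AbstractToyWrapper, which is the maintenance argument made for this PR.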
I'm slightly worried about the additional complexity introduced by the new abstract type and functions such as wrapped_dist and wrapped_dist_type. Can't we just add whatever definition was missing?
In general, both distributions are only used internally in DynamicPPL, and hence only the parts of the Distributions API relevant for DynamicPPL are implemented. What exactly was missing? Did you actually try to call tilde_assume!! directly?
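For comparison, a sketch of the alternative suggested here, again with hypothetical stand-in types (PlainDist, PlainProduct, PlainNoDist, PlainNamedDist) rather than the real NoDist and NamedDist: no new abstract type or accessor functions, just the one missing method defined directly on each wrapper.

```julia
# Hypothetical stand-ins, NOT the real Distributions.jl / DynamicPPL types.
abstract type PlainDist end

struct PlainProduct <: PlainDist
    n::Int
end
Base.length(d::PlainProduct) = d.n

# Two independent wrapper types with no shared wrapper supertype.
struct PlainNoDist{D<:PlainDist} <: PlainDist
    dist::D
end

struct PlainNamedDist{D<:PlainDist} <: PlainDist
    dist::D
    name::Symbol
end

# Only the method that was reported missing is added, once per wrapper type.
Base.length(d::PlainNoDist) = length(d.dist)
Base.length(d::PlainNamedDist) = length(d.dist)

println(length(PlainNoDist(PlainProduct(3))))  # 3
```

Each additional method costs one line per wrapper type, but nothing is introduced beyond what is actually needed.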
I'm slightly worried about the additional complexity introduced by the new abstract type and functions such as wrapped_dist and wrapped_dist_type. Can't we just add whatever definition was missing?
It's just one abstract type and a few standard boilerplate definitions around it (wrapped_distr() etc.). OTOH it avoids duplicating method definitions like length() etc. I see your point, but I think both approaches have advantages in terms of maintenance. Before this patch I had errors about length() and bijector() being missing for NoDist, but I can see how more methods from the Distributions API might be required in the future, so this PR makes it easier to add them.
Did you actually try to call tilde_assume!! directly?
Yes, I'm not using the @model macro; I'm using DynamicPPL directly to have more control and flexibility when generating statistical models.
It's just one abstract type and a few standard boilerplate definitions around it (wrapped_distr() etc.). OTOH it avoids duplicating method definitions like length() etc. I see your point, but I think both approaches have advantages in terms of maintenance. Before this patch I had errors about length() and bijector() being missing for NoDist, but I can see how more methods from the Distributions API might be required in the future, so this PR makes it easier to add them.
I can see that point, but I'm probably biased against adding types and functionality that might only become useful at some point in the future, due to the history of DynamicPPL, and VarInfo in particular: at this point it is really unclear which methods in varinfo.jl are needed, which are useful, and which should be removed. That even motivated a complete refactor and rewrite, but it is still messy.
So my suggestion would be to
- add an MWE to the tests that currently fails,
- and add only the missing definitions that make the test pass.
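The test-first workflow suggested above could look roughly like this, sketched with the Test standard library and hypothetical stand-in types (MweProduct, MweNoDist). The real MWE would exercise DynamicPPL.tilde_assume!! with NoDist(prior), whose exact setup is not shown in this thread. Without the Base.length method marked below, the test errors with a MethodError; adding just that one definition makes it pass.

```julia
using Test

# Hypothetical stand-ins for a Product prior and NoDist (NOT the real types).
struct MweProduct
    n::Int
end
Base.length(d::MweProduct) = d.n

struct MweNoDist{D}
    dist::D
end

# The single missing definition; remove it and the test below errors
# with a MethodError for length(::MweNoDist{MweProduct}).
Base.length(d::MweNoDist) = length(d.dist)

@testset "length of a wrapped multivariate prior" begin
    prior = MweProduct(3)
    @test length(MweNoDist(prior)) == 3
end
```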
Did you actually try to call tilde_assume!! directly?
It would be interesting to know whether that can be reproduced with a regular @model as well, or whether there is some problem with how tilde_assume!! was called.
Can we at least add tests for every new function and type, and fix the CI errors?
I also think it would be nice to see what actually went wrong and what has to be fixed.
Oh, it seems @torfjelde may have already fixed the problems in https://github.com/TuringLang/DynamicPPL.jl/pull/360/commits/0f9765bda684b27202982cf95d11e8de07304f62?
It doesn't define the bijector for NoDist, though.
I've added an MWE to the tests.
It doesn't define the bijector for NoDist though.
I actually left this out deliberately, because I'm uncertain whether we ever want to hit this. NoDist should represent "don't do anything with this variable", but if we at some point hit bijector(nodist), then this indicates that we might be trying to compute the logabsdetjac correction, which actually shouldn't be included in the log-joint computation :confused:
So are we certain that adding this implementation isn't silently doing something incorrect?
EDIT: See https://github.com/TuringLang/DynamicPPL.jl/pull/414#discussion_r908428665
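The concern can be stated with the standard change-of-variables formula (a general identity, not DynamicPPL code). If a variable x with density p_X is sampled in unconstrained space through a bijector f, with y = f(x), its log-density contribution in terms of y is

$$
\log p_Y(y) = \log p_X\big(f^{-1}(y)\big) + \log\left|\det J_{f^{-1}}(y)\right|
$$

Assuming NoDist is meant to contribute zero log-density, zeroing the first term is not sufficient: if bijector(nodist) simply forwarded to the wrapped distribution's bijector, the second (Jacobian) term could still end up in the log-joint unless it is explicitly excluded.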
I was just looking at https://github.com/TuringLang/Turing.jl/blob/master/src/stdlib/distributions.jl for completely unrelated reasons, and discovered some definitions of Bijectors.logpdf_with_trans(::NoDist, x, t) :open_mouth:
Regardless of whether they are useful etc., this seems like one of the worst places to hide them :smile:
Those shouldn't be there :flushed:
bors try
@alyst Very sorry for the delay; looks like tests aren't passing ATM.
Maybe I missed something (I haven't checked this PR for a while), but I think @torfjelde's and my concerns above are still valid?