
Document the role of transforms

Open velochy opened this issue 1 year ago • 7 comments

Describe the issue:

If a transform is applied to an observed variable, it seems to be ignored outright.

In the example below, the observations lie outside the interval given to pm.Normal, but the model still fits as if the transform were not there. Sampling the posterior predictive also ignores the transform, giving values well outside the interval.

Transforms are currently not very well documented, so I might be misunderstanding something, but having an Interval transform keep values within the interval even for observations seems like a sensible expectation, so I'm filing this as a bug.

Reproducible code example:

import pymc as pm
import numpy as np

with pm.Model() as model:
    # Observations deliberately placed outside the (-1, 1) interval
    obs = pm.MutableData("obs", np.array([2] * 10 + [-2] * 10))

    sd = pm.HalfNormal("sd", sigma=2)
    val = pm.Normal(
        "val",
        mu=0,
        sigma=sd,
        transform=pm.distributions.transforms.Interval(-1, 1),
        observed=obs,
    )

    idata = pm.sample()
    pp = pm.sample_posterior_predictive(idata)

    # Both extremes end up far outside (-1, 1)
    pp.posterior_predictive.val.min(), pp.posterior_predictive.val.max()

Error message:

No response

PyMC version information:

PyMC 5.9.2

Context for the issue:

No response

velochy avatar Nov 29 '23 09:11 velochy

Transforms play no role in forward sampling methods (prior/posterior predictive), nor for observed variables during MCMC (pm.sample), since observed values are fixed there.

I think you're misunderstanding their use. Most users shouldn't bother with transforms: they exist so that MCMC can run in an unconstrained space, and they are not meant to change the meaning of the model (there are some exceptions like Ordered, but that's never a default).
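
A quick way to see this (a minimal sketch, assuming PyMC 5.x naming conventions, where the transformed value variable gets a "_log__" suffix):

import pymc as pm

with pm.Model() as m:
    sd = pm.HalfNormal("sd", sigma=2)

# The default log transform only changes the value variable the sampler
# steps on; the HalfNormal prior on sd itself is unchanged.
print(m.free_RVs)    # [sd]
print(m.value_vars)  # [sd_log__]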

Sounds like you want a TruncatedNormal instead, if the bounds are part of your generative model.
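
A minimal sketch of the example above rewritten that way (note the hypothetical observations are moved inside the bounds, since a truncated likelihood assigns zero probability to values outside them):

import numpy as np
import pymc as pm

data = np.array([0.5] * 10 + [-0.5] * 10)  # hypothetical data within the bounds

with pm.Model() as model:
    sd = pm.HalfNormal("sd", sigma=2)
    # The bounds are now part of the likelihood, so they are respected
    # both during sampling and in posterior predictive draws.
    val = pm.TruncatedNormal("val", mu=0, sigma=sd, lower=-1, upper=1, observed=data)
    idata = pm.sample()
    pp = pm.sample_posterior_predictive(idata)

print(pp.posterior_predictive.val.min(), pp.posterior_predictive.val.max())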

ricardoV94 avatar Nov 29 '23 11:11 ricardoV94

Ok, thank you for the quick reply. I think a truncated normal should solve my actual use case (technically, I'm using a GMM, but switching it to a Mixture of TruncatedNormal components should do the trick).

Would it maybe make sense to update the transforms documentation with a paragraph about when they should and should not be used, to avoid confusing future developers like me?

velochy avatar Nov 29 '23 12:11 velochy

That being said: the example I demonstrated should not be too hard to make work, by backward-transforming the observations and then forward-transforming the model outputs for prior/posterior predictive. It would add some interesting flexibility.

velochy avatar Nov 29 '23 12:11 velochy

The transforms are not just deterministic transformations; they are accompanied by a Jacobian term so that, when doing MCMC sampling, the prior specified by the RV is respected. There's no principled way of doing that with forward sampling.
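
Concretely, the correction is the usual change-of-variables rule: if the sampler works on an unconstrained value $y$ and $x = g(y)$ is the backward (unconstrained-to-constrained) map, then

$$\log p_Y(y) = \log p_X\bigl(g(y)\bigr) + \log \left|\frac{dg}{dy}\right|$$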

The forward pass of an interval-transformed Normal wouldn't look like either a Normal or a TruncatedNormal.

If you want a deterministic transformation that is part of the generative graph you should write it explicitly.
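
For example, a sketch (not tied to the original model; the sigmoid squashing and the (-1, 1) bounds here are arbitrary choices) of an explicit generative transformation:

import pymc as pm

with pm.Model() as model:
    raw = pm.Normal("raw", mu=0, sigma=1)
    # Explicit part of the generative graph: squash the latent value into (-1, 1).
    # Prior and posterior predictive sampling will respect this, because it is
    # an ordinary deterministic node, not a sampler-side reparameterization.
    val = pm.Deterministic("val", -1 + 2 * pm.math.sigmoid(raw))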

ricardoV94 avatar Nov 29 '23 12:11 ricardoV94

Ok. That would explain the use of exponential and sigmoid in the interval transformation code I noticed :)

Thank you again for clarifying. Maybe it still makes sense to add a "mostly for internal use" warning to https://www.pymc.io/projects/docs/en/stable/api/distributions/transforms.html ?

velochy avatar Nov 29 '23 12:11 velochy

Regarding them not working in prior/posterior predictive, and this being a surprise to users: there's a proposal to distinguish between default automatic transforms and user-defined ones: #5674

We could then issue a warning that they have no effect in those cases, like we do for models with Potentials.

ricardoV94 avatar Nov 29 '23 12:11 ricardoV94

I think a note in the docs could be super valuable. Explain the normal use cases, and what they do and do not do.

ricardoV94 avatar Nov 29 '23 13:11 ricardoV94

Hello, first time contributor here. I'd struggled to wrap my head round the transforms and figured that working on this issue would help me learn about them.

In writing the docs I came across some things that are still unclear to me; I thought I'd mention them here, either to get feedback on incorporating them into the current PR, or so they can become separate issues if need be:

Is there any intended convention for the naming of Transform subclasses?

In some cases, such as LogTransform, the name corresponds to the Transform's forward transformation. In others, such as Ordered, the name corresponds to the backward transformation. I think the lack of a clear convention is one of the reasons users get confused about when and how to use transforms.

Is there a principled distinction between the Transforms defined in pymc.distributions.transforms and those defined in pymc.logprob.transforms?

  • It's not the case that pymc.distributions.transforms is limited to default transforms. For example, LogExpM1 is not registered as a default for any distribution.
  • It's not the case that pymc.distributions.transforms is the subset of transformations for unconstraining variables (i.e. transformations to the full real line), since Sinh is one such transformation in pymc.logprob.transforms but not in pymc.distributions.transforms.
  • The module naming suggests that pymc.distributions.transforms is intended to define distribution families, but evidently that's not a principled distinction, since a large number of transforms in this module are actually imported from pymc.logprob.transforms.
  • If we are saying that Transforms are principally for internal use, and nothing stops a power user from passing a transform from pymc.logprob.transforms to the transform argument of a random variable constructor, then it's not clear to me what pymc.distributions.transforms is trying to circumscribe.

I'm wondering whether this is a useful distinction because (1) splitting them between two modules is currently creating issues with the sphinx-generated pages, (2) semantically it makes things harder to reason about, and (3) if we are to have a separate subclass of Transforms, maybe it needs to be a narrower subclass for things like Ordered, which change the generative structure of the model.

What is the reason for exposing both Transform class constructors and instances in the API docs?

Overall I wasn't super clear on what the intended best practice was here, both in terms of what the user should be using (Ordered() vs ordered) and in terms of what should be documented in the API. For example, we had both LogExpM1/Ordered and log_exp_m1/ordered, but only simplex and not SimplexTransform. In the PR I took the view that we are documenting both (and so added the "missing" constructors), but that actually makes it hard to take a fully consistent approach. For example, I assume one reason SimplexTransform etc. weren't documented is that they are defined in pymc.logprob.transforms, so if we now include them, should we also include IntervalTransform in addition to Interval? I assume not, but it highlights the ambiguity.

mkusnetsov avatar Mar 29 '24 23:03 mkusnetsov

Is there any intended convention for the naming of Transform subclasses?

I don't think so. What you describe looks like a difference in emphasis between 1) transforms used to unconstrain sampling and 2) transforms used to distort sampling.

Is there a principled distinction between the Transforms defined in pymc.distributions.transforms and those defined in pymc.logprob.transforms?

This is partly a legacy issue, partly a partial (no pun) overlap issue. The transforms in pymc.distributions.transforms are the user-facing ones. Some of them correspond to a subset of the transforms from the logprob module, which are used in more internal contexts (such as logprob inference). There are also some transforms that are only useful for users, like Ordered, which are defined only in pymc.distributions.transforms.

TLDR: Transforms needed for internal consumption are defined in the logprob module. Transforms needed for user consumption are defined or made available in the distributions module. Users shouldn't have to know about those in the logprob module, but it won't hurt them if they stumble across them accidentally either.

What is the reason for exposing both Transform class constructors and instances in the API docs?

Some transforms have to be parameterized before use (e.g., ZeroSumTransform requires knowing the number of zero-sum axes). Other transforms can always be used the same way, so we instantiate them once for users. There has been some back and forth recently, so there may be cases that are not up to date.
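
To illustrate the two patterns (a small sketch; the Interval bounds here are arbitrary):

from pymc.distributions import transforms

# Parameterized transform: must be instantiated with its arguments.
interval = transforms.Interval(lower=-1, upper=1)

# Parameter-free transforms: ready-made module-level instances are provided.
log = transforms.log          # an instance of LogTransform
ordered = transforms.ordered  # an instance of Ordered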

ricardoV94 avatar Apr 05 '24 12:04 ricardoV94