sbi

PyTorch `distribution.support.check()` fails for `TransformedDistribution`s

Open rdgao opened this issue 1 year ago • 3 comments

This is not really an sbi bug but a PyTorch one. Still, it is worth being aware of if you use custom prior distributions:

In PyTorch's `TransformedDistribution`, the support is computed as `support = self.transforms[-1].codomain`. This ignores any constraints of the base distribution and of any intermediate transforms, which is problematic when the base distribution is bounded, e.g., a `Uniform`. As a result, posterior samples that lie outside the prior bounds are accepted when checked with `sbiutils.within_support()`. This is a known issue in PyTorch.
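A small sketch of the behaviour described above, assuming a recent PyTorch release (the thread does not state a version). With two chained transforms, the reported support is just the codomain of the last one. A negative value therefore passes the check, even though neither the base `Uniform` nor the intermediate `ExpTransform` could ever produce it.

```python
from torch import ones, zeros
from torch.distributions import AffineTransform, ExpTransform, TransformedDistribution, Uniform

base = Uniform(zeros(1), ones(1))  # support [0, 1]
# exp maps [0, 1] to [1, e]; scaling by 2 then maps that to [2, 2e]
dist = TransformedDistribution(base, [ExpTransform(), AffineTransform(0.0, 2.0)])

# The reported support is the codomain of the last transform (all reals) ...
print(dist.support is dist.transforms[-1].codomain)
# ... so a negative value passes, although the true support is [2, 2e]
print(dist.support.check(-5.0 * ones(1)))
```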

To fix it, one can either:

  1. manually reassign `transforms[-1].codomain` to, e.g., the transformed bounds of the `Uniform` distribution (hacky but fast), or
  2. in `sbiutils.within_support()`, call `.log_prob()` to detect samples that raise an error, which happens when they lie outside the distribution's support (works in general, but ugly and slow).
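Option 2 could look roughly like the sketch below. `within_support_via_log_prob` is a hypothetical helper, not sbi's actual `within_support()`. It sets `validate_args=True` explicitly so that the base `Uniform` raises a `ValueError` when a pulled-back value leaves its support. It also treats a non-finite log-probability as outside the support. Samples are evaluated one at a time because a single bad sample makes a batched call raise, which is why this approach is slow.

```python
import torch
from torch import ones, zeros
from torch.distributions import ExpTransform, TransformedDistribution, Uniform


def within_support_via_log_prob(dist, samples: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper for option 2: True where log_prob succeeds and is finite.

    Assumes the first dimension of `samples` indexes individual samples.
    """
    mask = []
    for sample in samples:  # one log_prob call per sample, hence slow
        try:
            log_prob = dist.log_prob(sample)
            mask.append(bool(torch.isfinite(log_prob).all()))
        except ValueError:
            # raised by argument validation when a value leaves some support
            mask.append(False)
    return torch.tensor(mask)


base = Uniform(zeros(1), ones(1), validate_args=True)
dist = TransformedDistribution(base, ExpTransform(), validate_args=True)
samples = torch.tensor([[2.0], [100.0]])  # true support is [1, e]
print(within_support_via_log_prob(dist, samples))
```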

@michaeldeistler anything else to add?

rdgao, Sep 05 '22 15:09

Thanks for creating this!

Just to reiterate: this error will occur, e.g., during `.train()` of SNPE-C in the second round, when `prior.log_prob()` is called. It raises an error because it is evaluated on samples outside the prior support.
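A minimal reproduction of that failure mode, using a hypothetical custom prior defined as `exp` of a `Uniform` on [0, 1], so its true support is [1, e]. A value outside that range passes `prior.support.check()` because only positivity is checked. `prior.log_prob()` then raises because the base `Uniform` validates the pulled-back value `log(100)`. This is a sketch of the mechanism, not sbi's training code.

```python
from torch import ones, zeros
from torch.distributions import ExpTransform, TransformedDistribution, Uniform

# Hypothetical prior: exp of Uniform(0, 1), true support [1, e]
prior = TransformedDistribution(Uniform(zeros(1), ones(1), validate_args=True), ExpTransform())

theta = 100.0 * ones(1)  # outside [1, e]
passes_check = bool(prior.support.check(theta).all())  # only checks theta > 0

try:
    prior.log_prob(theta)
    raised = False
except ValueError:
    # the base Uniform rejects log(100), which is outside [0, 1]
    raised = True

print(passes_check, raised)
```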

michaeldeistler, Sep 05 '22 16:09

And to reproduce:

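```python
from torch import ones, zeros
from torch.distributions import AffineTransform, TransformedDistribution, Uniform

base = Uniform(zeros(1), ones(1))  # support: [0, 1]
# loc=0, scale=1 is the identity, so the true support is still [0, 1]
dist = TransformedDistribution(base, AffineTransform(zeros(1), ones(1)))
print(dist.support.check(100 * ones(1)))
# -> tensor([True]), because the support is AffineTransform's codomain (all reals)
```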

michaeldeistler, Sep 05 '22 16:09

Interestingly, if you try to set the codomain as suggested in option 1 above, via `dist.transforms[-1].codomain = interval(0, 1)`, it throws an error, perhaps appropriately: `AttributeError: can't set attribute`.
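The difference seems to come from how the two transform classes declare `codomain`. In recent PyTorch releases (an assumption, not checked against the version used here), `AffineTransform.codomain` is defined with `constraints.dependent_property`, a subclass of Python's `property` without a setter, so instance assignment fails. `ExpTransform.codomain`, in contrast, is a plain class attribute that an instance attribute can shadow. The hypothetical helper below inspects this statically. The exact wording of the `AttributeError` message varies with the Python version.

```python
import inspect

from torch import ones, zeros
from torch.distributions import AffineTransform, ExpTransform
from torch.distributions.constraints import interval


def codomain_is_read_only(transform) -> bool:
    """Hypothetical helper: True if `codomain` cannot be assigned on instances."""
    # getattr_static returns the raw class attribute without invoking descriptors
    attr = inspect.getattr_static(type(transform), "codomain")
    return isinstance(attr, property) and attr.fset is None


affine = AffineTransform(zeros(1), ones(1))
print(codomain_is_read_only(affine))          # property without a setter
print(codomain_is_read_only(ExpTransform()))  # plain class attribute

try:
    affine.codomain = interval(0.0, 1.0)
except AttributeError as err:
    print("AttributeError:", err)  # wording depends on the Python version
```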

Another example, using `ExpTransform`, where setting the codomain is possible:

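```python
import math

from torch import ones, zeros
from torch.distributions import ExpTransform, TransformedDistribution, Uniform
from torch.distributions.constraints import interval

base = Uniform(zeros(1), ones(1))
dist = TransformedDistribution(base, ExpTransform())
print(dist.support.check(100 * ones(1)))  # tensor([True]): ExpTransform's codomain is (0, inf)

# ExpTransform.codomain is a class attribute, so it can be shadowed on the instance.
# Use the transformed bounds of the base Uniform: exp([0, 1]) = [1, e].
dist.transforms[-1].codomain = interval(1.0, math.e)
print(dist.support.check(100 * ones(1)))  # tensor([False])
```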
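Rather than hard-coding the bounds for option 1, they could be computed by pushing the base `Uniform`'s `low` and `high` through the transforms. `transformed_uniform_bounds` is a hypothetical helper and only makes sense for elementwise, monotonic transforms such as `exp` or an affine map. The override itself still only works for transforms whose `codomain` can be assigned, like `ExpTransform`.

```python
import torch
from torch import ones, zeros
from torch.distributions import ExpTransform, TransformedDistribution, Uniform
from torch.distributions.constraints import interval


def transformed_uniform_bounds(dist: TransformedDistribution):
    """Hypothetical helper: push a Uniform base's bounds through the transforms.

    Only valid for elementwise, monotonic transforms (e.g. exp, affine).
    """
    low, high = dist.base_dist.low, dist.base_dist.high
    for t in dist.transforms:
        low, high = t(low), t(high)
        # a decreasing transform (e.g. negative scale) swaps the bounds
        low, high = torch.minimum(low, high), torch.maximum(low, high)
    return low, high


dist = TransformedDistribution(Uniform(zeros(1), ones(1)), ExpTransform())
low, high = transformed_uniform_bounds(dist)  # [1, e]
dist.transforms[-1].codomain = interval(low, high)  # option 1, possible for ExpTransform
print(dist.support.check(torch.tensor([0.5, 2.0, 100.0])))
```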

rdgao, Sep 05 '22 16:09