sbi
pytorch distribution.support.check() fails for TransformedDistributions
This is not really an sbi bug but a PyTorch one. Still, it is good to be aware of it if you use custom prior distributions: in PyTorch `TransformedDistribution`s, `support` is computed as `support = self.transforms[-1].codomain`, which ignores any constraints of the base distribution and of any intermediate transformations. This is problematic if the base distribution is, e.g., `Uniform`: posterior samples that lie outside of the prior bounds are accepted when checked in `sbiutils.within_support()`. This is a known issue in PyTorch.
To fix it, one can either:
- manually reassign `transforms[-1].codomain` to be, e.g., the transformed bounds of the `Uniform` distribution (hacky but fast), or
- in `sbiutils.within_support()`, call `.log_prob()` to detect samples that raise an error, which happens (with argument validation enabled) if they are outside of the distribution's support (works in general, but ugly and slow).
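A minimal sketch of the second option, assuming argument validation is enabled (passed explicitly here via `validate_args=True`) so that the base `Uniform` raises a `ValueError` inside `log_prob()` for out-of-bounds values. `within_support_via_log_prob` is a hypothetical helper for illustration, not sbi's actual `within_support()`:

```python
import torch
from torch import ones, zeros
from torch.distributions import AffineTransform, TransformedDistribution, Uniform


def within_support_via_log_prob(dist, samples):
    """Hypothetical helper sketching option 2 (not sbi's implementation).

    Returns one boolean per sample. Relies on argument validation being
    enabled, so that log_prob() raises a ValueError for values outside the
    base distribution's support.
    """
    mask = []
    for sample in samples:  # one log_prob call per sample: simple but slow
        try:
            dist.log_prob(sample)
            mask.append(True)
        except ValueError:
            mask.append(False)
    return torch.tensor(mask)


base = Uniform(zeros(1), ones(1), validate_args=True)
prior = TransformedDistribution(
    base, AffineTransform(zeros(1), ones(1)), validate_args=True
)
samples = torch.tensor([[0.5], [100.0]])
mask = within_support_via_log_prob(prior, samples)
print(mask)  # expected: True for 0.5, False for 100.0
```

The per-sample loop and the exception handling are what make this approach slow; it is only meant to illustrate the idea.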
@michaeldeistler, anything else to add?
Thanks for creating this!
Just to reiterate: this error will occur, e.g., during `.train()` of SNPE-C (second round), when `prior.log_prob()` is called. It raises an error because the prior is evaluated outside of its support.
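A small sketch of why the two behave differently, assuming argument validation is enabled (set explicitly here): the support check only looks at the last transform's codomain, while `log_prob()` inverts the transform and hands the value to the base `Uniform`, which validates it against its own bounds:

```python
from torch import ones, zeros
from torch.distributions import AffineTransform, TransformedDistribution, Uniform

base = Uniform(zeros(1), ones(1), validate_args=True)
prior = TransformedDistribution(
    base, AffineTransform(zeros(1), ones(1)), validate_args=True
)
value = 100 * ones(1)  # clearly outside the Uniform(0, 1) bounds

# The support only reflects AffineTransform's codomain (the real line).
print(prior.support.check(value))  # tensor([True])

# log_prob() maps the value back to the base Uniform, which rejects it.
try:
    prior.log_prob(value)
    raised = False
except ValueError as err:
    raised = True
    print("log_prob raised:", err)
```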
And to reproduce:
```python
from torch import ones, zeros
from torch.distributions import AffineTransform, TransformedDistribution, Uniform

base = Uniform(zeros(1), ones(1))
dist = TransformedDistribution(base, AffineTransform(zeros(1), ones(1)))
dist.support.check(100 * ones(1))
# -> returns tensor([True])
```
Interestingly, if you try to set the codomain as suggested above (option 1), via

```python
from torch.distributions.constraints import interval

dist.transforms[-1].codomain = interval(0, 1)
```

it throws an error, perhaps appropriately: `AttributeError: can't set attribute`, since `AffineTransform.codomain` is a property without a setter.
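Since the codomain cannot be reassigned in this case, one possible workaround is to put the bounds on the distribution instead. The sketch below uses a hypothetical subclass `BoundedTransformedDistribution` (not part of PyTorch or sbi) that simply returns a user-supplied support; it assumes you know the transformed bounds yourself:

```python
from torch import ones, zeros
from torch.distributions import AffineTransform, TransformedDistribution, Uniform
from torch.distributions.constraints import interval


class BoundedTransformedDistribution(TransformedDistribution):
    """Hypothetical subclass: a TransformedDistribution with an explicit support.

    Useful when the last transform's codomain is read-only (as for
    AffineTransform) and therefore cannot be reassigned.
    """

    def __init__(self, base_distribution, transforms, support, validate_args=None):
        self._explicit_support = support  # store before the parent constructor runs
        super().__init__(base_distribution, transforms, validate_args=validate_args)

    @property
    def support(self):
        return self._explicit_support


base = Uniform(zeros(1), ones(1))
# AffineTransform(loc=2, scale=3) maps [0, 1] onto [2, 5].
dist = BoundedTransformedDistribution(
    base,
    AffineTransform(2 * ones(1), 3 * ones(1)),
    support=interval(2 * ones(1), 5 * ones(1)),
)
print(dist.support.check(100 * ones(1)))  # tensor([False])
print(dist.support.check(3 * ones(1)))  # tensor([True])
```

The explicit support is also what `log_prob()` validates against when argument validation is on, so out-of-bounds values are rejected at the transformed level too.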
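Another example, where setting the codomain is possible:

```python
import math

from torch import ones, zeros
from torch.distributions import ExpTransform, TransformedDistribution, Uniform
from torch.distributions.constraints import interval

base = Uniform(zeros(1), ones(1))
dist = TransformedDistribution(base, ExpTransform())
print(dist.support.check(100 * ones(1)))  # tensor([True]): ExpTransform's codomain is (0, inf)

# exp maps the Uniform(0, 1) bounds onto [1, e]; ExpTransform.codomain is a
# plain class attribute, so it can be overridden on this instance.
dist.transforms[-1].codomain = interval(1.0, math.e)
print(dist.support.check(100 * ones(1)))  # tensor([False])
```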
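Instead of hard-coding the transformed bounds, they could be computed by pushing the base bounds through the transforms. This is a sketch under the assumption that every transform is elementwise and monotonic (e.g. exp, affine, sigmoid), so the image of `[low, high]` is spanned by the images of the two endpoints; `transformed_interval` is a hypothetical helper:

```python
import torch
from torch import ones, zeros
from torch.distributions import ExpTransform, TransformedDistribution, Uniform
from torch.distributions.constraints import interval


def transformed_interval(low, high, transforms):
    """Hypothetical helper: map the bounds [low, high] through a chain of transforms.

    Assumes each transform is elementwise and monotonic. A decreasing
    transform swaps the endpoints, hence the elementwise min/max.
    """
    for t in transforms:
        a, b = t(low), t(high)
        low, high = torch.minimum(a, b), torch.maximum(a, b)
    return interval(low, high)


base = Uniform(zeros(1), ones(1))
dist = TransformedDistribution(base, ExpTransform())
dist.transforms[-1].codomain = transformed_interval(base.low, base.high, dist.transforms)
print(dist.support.check(100 * ones(1)))  # tensor([False])
print(dist.support.check(2 * ones(1)))  # tensor([True]), since 1 <= 2 <= e
```

Note that this mutates the transform instance, so any other distribution sharing the same `ExpTransform` object picks up the new codomain as well.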