pytorch distribution.support.check() fails for TransformedDistributions
This is not really an sbi bug but a pytorch one, but it is good to be aware of when using custom prior distributions:
In pytorch's TransformedDistribution, support is computed as support = self.transforms[-1].codomain. This ignores any constraints of the base distribution and of any intermediate transformations, which is problematic if the base distribution is bounded, e.g., Uniform. As a result, posterior samples that lie outside the prior bounds are accepted when checked in sbiutils.within_support(). This is a known issue in pytorch.
To fix it, one can either:
- manually reassign transforms[-1].codomain to be, e.g., the transformed bounds of the Uniform distribution (hacky but fast), or
- in sbiutils.within_support(), call .log_prob() to detect samples that raise an error, which happens when a sample is outside the distribution's support (works in general, but ugly and slow).
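A minimal sketch of the second option, assuming nothing about how sbiutils.within_support() is actually implemented. The helper name within_support_via_log_prob is hypothetical. It evaluates log_prob() one sample at a time, because a single out-of-support value makes the whole batched call raise, and it also treats a non-finite log-probability as out of support in case argument validation is switched off.

```python
import torch
from torch.distributions import AffineTransform, TransformedDistribution, Uniform


def within_support_via_log_prob(dist, samples):
    """Hypothetical helper: True for each sample whose log_prob is finite and raises no error."""
    mask = []
    for sample in samples:
        try:
            log_prob = dist.log_prob(sample)
            # With validation disabled, Uniform returns -inf outside its bounds.
            mask.append(bool(torch.isfinite(log_prob).all()))
        except ValueError:
            # With validation enabled, the base distribution rejects the value.
            mask.append(False)
    return torch.tensor(mask)


base = Uniform(torch.zeros(1), torch.ones(1))
dist = TransformedDistribution(base, AffineTransform(torch.zeros(1), torch.ones(1)))

samples = torch.tensor([[0.5], [100.0]])
mask = within_support_via_log_prob(dist, samples)
print(mask)
```

The per-sample loop is what makes this approach slow, which matches the "ugly and slow" caveat above.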
@michaeldeistler anything else to add?
Thanks for creating this!
Just to re-iterate: this error will occur, e.g., during .train() of SNPE-C (second round), when prior.log_prob() is called. This gives an error because it is evaluated outside of the prior support.
And to reproduce:
from torch import ones, zeros
from torch.distributions import AffineTransform, TransformedDistribution, Uniform
base = Uniform(zeros(1), ones(1))
dist = TransformedDistribution(base, AffineTransform(zeros(1), ones(1)))
dist.support.check(100*ones(1))
# -> returns tensor([True])
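Interestingly, if you try to set the codomain as suggested above (option 1), via
dist.transforms[-1].codomain = interval(0, 1)
(with interval imported from torch.distributions.constraints), it throws an error, perhaps appropriately: AttributeError: can't set attribute. This is presumably because codomain is a read-only property on AffineTransform.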
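For transforms where the codomain cannot be reassigned, one possible alternative is to override support on the distribution itself. This is neither of the two options listed above, only a sketch under assumptions: the class name BoundedTransformedDistribution and its support argument are hypothetical, and the bounds are supplied by hand rather than derived from the base distribution.

```python
import torch
from torch.distributions import AffineTransform, TransformedDistribution, Uniform
from torch.distributions.constraints import interval


class BoundedTransformedDistribution(TransformedDistribution):
    """Hypothetical wrapper: reports a user-supplied support instead of the last codomain."""

    def __init__(self, base_distribution, transforms, support):
        super().__init__(base_distribution, transforms)
        self._explicit_support = support

    @property
    def support(self):
        # Overrides the parent property, which only looks at transforms[-1].codomain.
        return self._explicit_support


base = Uniform(torch.zeros(1), torch.ones(1))
transform = AffineTransform(torch.zeros(1), torch.ones(1))
# Assumption: this transform (loc=0, scale=1) leaves the bounds [0, 1] unchanged.
bounded = BoundedTransformedDistribution(base, transform, interval(0.0, 1.0))

inside = bounded.support.check(0.5 * torch.ones(1))
outside = bounded.support.check(100 * torch.ones(1))
print(inside, outside)
```

The transform object is left untouched, so the same approach should apply regardless of whether its codomain is writable.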
Another example, where setting the codomain is possible:
from torch import ones, zeros
from torch.distributions import ExpTransform, TransformedDistribution, Uniform
from torch.distributions.constraints import interval
base = Uniform(zeros(1), ones(1))
dist = TransformedDistribution(base, ExpTransform())
# Before the reassignment: the support is the codomain of ExpTransform (positive reals).
print(dist.support.check(100*ones(1)))
# -> tensor([True])
dist.transforms[-1].codomain = interval(0, 1)
# After the reassignment: the check uses the new interval.
print(dist.support.check(100*ones(1)))
# -> tensor([False])
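The interval(0, 1) used here only demonstrates that the assignment works. The actual transformed bounds of Uniform(0, 1) under ExpTransform are [exp(0), exp(1)]. A rough sketch of computing them, assuming the transform is monotonically increasing so that the image of [low, high] is [transform(low), transform(high)]:

```python
import torch
from torch.distributions import ExpTransform, TransformedDistribution, Uniform
from torch.distributions.constraints import interval

base = Uniform(torch.zeros(1), torch.ones(1))
transform = ExpTransform()
dist = TransformedDistribution(base, transform)

# Assumption: the transform is monotonically increasing, so pushing the
# base bounds through it yields the bounds of the transformed variable.
lower = transform(base.low)   # exp(0) = 1
upper = transform(base.high)  # exp(1) ~ 2.718

# Works for ExpTransform because codomain is a plain attribute there.
dist.transforms[-1].codomain = interval(lower, upper)

inside = dist.support.check(torch.tensor([2.0]))
outside = dist.support.check(torch.tensor([100.0]))
print(inside, outside)
```

For a chain of several transforms, the bounds would have to be pushed through every transform in order, and a decreasing transform would swap the lower and upper bound.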