[FR] TransformModule.inv is not a TransformModule
Issue Description
Calling the .inv property of a TransformModule does not return an instance of TransformModule.

    from pyro.distributions import TransformModule
    from torch.distributions import Transform

    my_module = TransformModule()

    isinstance(my_module, TransformModule)      # True
    isinstance(my_module, Transform)            # True
    isinstance(my_module.inv, TransformModule)  # False
    isinstance(my_module.inv, Transform)        # True

As a result, the inverse of any TransformModule cannot be optimized directly, because its parameters are not accessible through the inverse.
Current solution
I am currently using an InverseTransform module that is itself an instance of TransformModule to invert my transformations without losing the parameter dependency:

    from torch.distributions import constraints
    from pyro.distributions import TransformModule


    class InverseTransform(TransformModule):
        """
        This code follows _InverseTransform from torch.distributions.transforms.
        The differences are that we extend TransformModule instead of Transform,
        override the __repr__() method to show the name of the original (inverted)
        module, and keep the volume_preserving attribute.
        """

        def __init__(self, transform):
            super().__init__(cache_size=transform._cache_size)
            # When transform is an nn.Module, this assignment registers it as a
            # submodule, so its parameters stay reachable from the inverse.
            self._inv = transform
            # Patch to carry over the volume_preserving attribute
            if hasattr(transform, "volume_preserving"):
                self.volume_preserving = transform.volume_preserving

        # Domain and codomain are swapped relative to the wrapped transform.
        @constraints.dependent_property
        def domain(self):
            return self._inv.codomain

        @constraints.dependent_property
        def codomain(self):
            return self._inv.domain

        @property
        def bijective(self):
            return self._inv.bijective

        @property
        def sign(self):
            return self._inv.sign

        @property
        def event_dim(self):
            return self._inv.event_dim

        @property
        def inv(self):
            return self._inv

        # Forward and inverse calls are swapped as well.
        def _inverse(self, y):
            return self._inv._call(y)

        def _call(self, x):
            return self._inv._inverse(x)

        def with_cache(self, cache_size=1):
            return self.inv.with_cache(cache_size).inv

        def __eq__(self, other):
            if not isinstance(other, InverseTransform):
                return False
            return self._inv == other._inv

        def __call__(self, x):
            return self._inv._inv_call(x)

        def log_abs_det_jacobian(self, x, y):
            return -self._inv.log_abs_det_jacobian(y, x)

        # Patch to show the representation of the inverted module
        def __repr__(self) -> str:
            return f"Inverse({self._inv!r})"

        def __hash__(self):
            return self._inv.__hash__()

Usage

    from pyro.distributions import TransformModule

    my_module = TransformModule()
    my_inverted_module = InverseTransform(my_module)

    isinstance(my_inverted_module, TransformModule)  # True

Changing the behavior of the .inv attribute would require overriding the TransformModule class.
Hi @mfederici, thanks for your interest in the normalizing flow sublibrary! :)
Yes, this is a bug we're aware of and something that needs to be fixed. I've moved NF development into a separate library and am not developing this part of the Pyro code further myself.
I think the short-term fix is either to implement the change as you have, or to get the parameters directly from my_module (they are the same for my_inverted_module).
The longer-term fix is to wait for our new library to be released, which I think is very likely to happen before the end of the year.
Would you be interested in a patch for this as a PR?