pyro

[FR] TransformModule.inv is not a TransformModule

Open mfederici opened this issue 3 years ago • 2 comments

Issue Description

Calling the .inv property of a TransformModule does not return an instance of a TransformModule.

from pyro.distributions import TransformModule
from torch.distributions import Transform

my_module = TransformModule()

isinstance(my_module, TransformModule)                 # True
isinstance(my_module, Transform)                       # True
isinstance(my_module.inv, TransformModule)             # False
isinstance(my_module.inv, Transform)                   # True

As a result, the inverse of a TransformModule cannot be optimized directly. The object returned by .inv is a plain torch Transform rather than an nn.Module, so the parameters cannot be reached through it.

Current solution

I am currently using an InverseTransform module, which is an instance of TransformModule, to invert my transformations without losing the dependency on their parameters:

from torch.distributions import constraints
from pyro.distributions import TransformModule

class InverseTransform(TransformModule):
    '''
    This code follows _InverseTransform from torch.distributions.transforms
    The differences are that we extend TransformModule instead of Transform,
    override the __repr__() method to show the name of the original (inverted) module
    and keep the volume_preserving attribute
    '''
    def __init__(self, transform):
        super(InverseTransform, self).__init__(cache_size=transform._cache_size)
        self._inv = transform
        
        # Patch to carry over the volume_preserving attribute
        if hasattr(transform, 'volume_preserving'):
            self.volume_preserving = transform.volume_preserving

    @constraints.dependent_property
    def domain(self):
        return self._inv.codomain

    @constraints.dependent_property
    def codomain(self):
        return self._inv.domain

    @property
    def bijective(self):
        return self._inv.bijective

    @property
    def sign(self):
        return self._inv.sign

    @property
    def event_dim(self):
        return self._inv.event_dim

    @property
    def inv(self):
        return self._inv

    def _inverse(self, y):
        return self._inv._call(y)

    def _call(self, x):
        return self._inv._inverse(x)

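    def with_cache(self, cache_size=1):
        # Re-wrap so the cached copy is still an InverseTransform (and a TransformModule);
        # calling .inv on the cached inner transform would return torch's _InverseTransform
        return InverseTransform(self._inv.with_cache(cache_size))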

    def __eq__(self, other):
        if not isinstance(other, InverseTransform):
            return False
        return self._inv == other._inv

    def __call__(self, x):
        return self._inv._inv_call(x)

    def log_abs_det_jacobian(self, x, y):
        return -self._inv.log_abs_det_jacobian(y, x)
    
    # Patch to show the representation of the inverted module
    def __repr__(self) -> str:
        return 'Inverse('+ self._inv.__repr__() +')'

    def __hash__(self):
        return self._inv.__hash__()

Usage

from pyro.distributions import TransformModule

my_module = TransformModule()
my_inverted_module = InverseTransform(my_module)

isinstance(my_inverted_module, TransformModule)             # True

Making the .inv property itself return a TransformModule would require overriding it in the TransformModule class.

mfederici • Oct 05 '21 16:10

Hi @mfederici, thanks for your interest in the normalizing flow sublibrary! :)

Yes, this is a bug we're aware of and something that needs to be fixed. I've moved NF development into a separate library and am not developing this part of the Pyro code further myself.

I think the short-term fix is either to implement the change as you have, or to get the parameters directly from my_module (they are the same parameters that my_inverted_module uses).

The longer-term fix is to wait until our new library has been released, which I think is very likely to happen before the end of the year.

stefanwebb • Oct 05 '21 19:10

Would you be interested in a patch for this as a PR?

felixdivo • Feb 10 '23 15:02