
Pattern for creating a TransformModule

Open · fehiepsi opened this issue 5 years ago · 7 comments

This issue discusses possible ways to provide something similar to Pyro's TransformModule. The motivation is to fold the construction of a flow's neural network (nn) into the construction of the flow transform itself, so that the code is cleaner.

Approach 1 (current)

```python
import numpyro

# create a flow nn (FlowNN and FlowTransform stand for flow-specific code)
nn_init, nn_apply = nn = FlowNN(flow_args)
# register params and bind them to nn_apply
nn_call = numpyro.module("nn", nn)  # or partial(nn_apply, params)
# create a transform from `nn_call`
transform = FlowTransform(nn_call)
```

These steps are explicit, and FlowTransform does not need to handle params at all.
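Below is a minimal sketch of the approach 1 pattern that runs without NumPyro. It assumes a stax-style `(init_fun, apply_fun)` network pair. `flow_nn` and this `FlowTransform` are hypothetical toys (an affine flow driven by a conditioning value), not NumPyro APIs. Params are bound with `functools.partial` in place of `numpyro.module`.

```python
from functools import partial
import math


def flow_nn(scale):
    """Hypothetical stand-in for FlowNN: a stax-style (init_fun, apply_fun) pair."""

    def init_fun():
        # Toy parameters; a real network would hold weight arrays.
        return {"w_shift": 0.5, "w_log_scale": 0.1 * scale}

    def apply_fun(params, context):
        # Map a conditioning value to the (shift, log_scale) of an affine flow.
        return params["w_shift"] * context, params["w_log_scale"] * context

    return init_fun, apply_fun


class FlowTransform:
    """Hypothetical affine flow: it gets a ready-to-call network and never sees params."""

    def __init__(self, nn_call):
        self.nn_call = nn_call

    def __call__(self, x, context):
        shift, log_scale = self.nn_call(context)
        return x * math.exp(log_scale) + shift

    def log_abs_det_jacobian(self, x, y, context):
        _, log_scale = self.nn_call(context)
        return log_scale  # d/dx [x * exp(s) + t] = exp(s), so log|det J| = s


# Step 1: create the network (init_fun, apply_fun) pair.
nn_init, nn_apply = flow_nn(scale=2.0)
# Step 2: bind params; this is the `partial(nn_apply, params)` alternative to numpyro.module.
params = nn_init()
nn_call = partial(nn_apply, params)
# Step 3: build the transform from the bound callable alone.
transform = FlowTransform(nn_call)
```

`FlowTransform` only ever receives `nn_call`, so it has no notion of parameters. Swapping `partial` for `numpyro.module` changes where the params come from without touching the transform.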

Approach 2

```python
import numpyro

# create a transform module; this is not a Transform
transform_nn = FlowTransform(flow_args)
# set params on transform_nn, converting it to a Transform
transform = numpyro.module("transform_nn", transform_nn)
```

On its own, transform_nn is not useful: it only becomes useful once params are registered for it. I guess something like

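```python
from functools import partial

class TransformModule(...):
    def __init__(self, flow_args):
        self._nn_init, self._nn_apply = FlowNN(flow_args)
        self.params = None  # set later by the module handler

    def logdet(self, x, y, *args, **kwargs):
        if self.params is None:
            raise ValueError("params have not been set by the module handler")
        nn_call = partial(self._nn_apply, self.params)
        ...
```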

and, in the module handler, setting

```python
transform.params = params
```

will work.
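The approach 2 pattern can be sketched the same way, again under toy assumptions. `flow_nn` is a hypothetical stand-in for `FlowNN`. `register_transform_module` is a hypothetical stand-in for the adjusted module handler, and it uses a plain dict instead of NumPyro's param store.

```python
from functools import partial
import math


def flow_nn():
    """Hypothetical stax-style (init_fun, apply_fun) pair producing (shift, log_scale)."""

    def init_fun():
        return {"shift": 0.5, "log_scale": 0.2}

    def apply_fun(params, context):
        return params["shift"] * context, params["log_scale"] * context

    return init_fun, apply_fun


class TransformModule:
    """Owns its network but cannot transform anything until params are attached."""

    def __init__(self):
        self._nn_init, self._nn_apply = flow_nn()
        self.params = None

    def _nn_call(self):
        if self.params is None:
            raise ValueError("params have not been set; register this module first")
        return partial(self._nn_apply, self.params)

    def __call__(self, x, context):
        shift, log_scale = self._nn_call()(context)
        return x * math.exp(log_scale) + shift

    def log_abs_det_jacobian(self, x, y, context):
        return self._nn_call()(context)[1]


def register_transform_module(name, module, param_store):
    """Hypothetical module handler: create params once, then attach them to the module."""
    if name not in param_store:
        param_store[name] = module._nn_init()
    module.params = param_store[name]  # the `transform.params = params` step
    return module


store = {}
transform = register_transform_module("transform_nn", TransformModule(), store)
```

The cost of this design shows up in `_nn_call`. The object is stateful, and every method has to guard against being used before the handler has attached `params`.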

Approach 3

Similar to approach 2, but FlowTransform(flow_args) returns a pair of functions, transform_init and transform_apply: transform_init creates the params, and transform_apply returns a Transform given those params. We will need to adjust the module handler slightly so that it can distinguish a usual nn layer from this nn transform layer.
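Here is a sketch of approach 3 under the same toy assumptions. `flow_transform` plays the role of `FlowTransform(flow_args)` and returns `(transform_init, transform_apply)`. `ToyAffineTransform` stands in for a NumPyro `Transform`. `transform_module` is a hypothetical handler that returns a Transform, rather than the `partial(nn_apply, params)` callable a usual nn layer yields.

```python
import math


class ToyAffineTransform:
    """Toy stand-in for a Transform: y = x * exp(log_scale) + shift."""

    def __init__(self, shift, log_scale):
        self.shift = shift
        self.log_scale = log_scale

    def __call__(self, x):
        return x * math.exp(self.log_scale) + self.shift

    def inv(self, y):
        return (y - self.shift) * math.exp(-self.log_scale)

    def log_abs_det_jacobian(self, x, y):
        return self.log_scale


def flow_transform(init_log_scale):
    """Hypothetical FlowTransform(flow_args) returning (transform_init, transform_apply)."""

    def transform_init():
        # Create the parameters for this layer.
        return {"shift": 0.0, "log_scale": init_log_scale}

    def transform_apply(params):
        # Build a concrete Transform from the given params; no mutable state involved.
        return ToyAffineTransform(params["shift"], params["log_scale"])

    return transform_init, transform_apply


def transform_module(name, pair, param_store):
    """Hypothetical handler for transform layers: returns a Transform, not a callable nn."""
    transform_init, transform_apply = pair
    if name not in param_store:
        param_store[name] = transform_init()
    return transform_apply(param_store[name])


store = {}
transform = transform_module("flow", flow_transform(init_log_scale=0.5), store)
y = transform(2.0)
```

`transform_apply` is a pure function of `params`, so the Transform it returns carries no mutable parameter state. This keeps the style closer to approach 1 than to approach 2.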

I prefer approaches 1 and 3 because they are explicit and more functional in style, but this issue is open to other suggestions.

fehiepsi, Jan 15 '20

Is this currently being worked on? It would be great to see normalising flows in numpyro! (Am I right in thinking this is a prerequisite for that, or is there a different approach?)

tcbegley, Dec 17 '20

@stefanwebb is redesigning the normalizing flows interface for PyTorch-Pyro in a new library. We might want to wait a bit to see what new abstractions he decides on.

fritzo, Dec 17 '20

Oh neat. I'll check it out. Thanks!

tcbegley, Dec 17 '20

Hi, I just wanted to know whether this is being worked on, or will be in the near future. I see the referenced repo has been archived, so I was wondering what numpyro's (and also pyro's) plans are for further development of normalizing flow methods. I would be interested in contributing to both.

jejjohnson, Sep 28 '21

Hi @jejjohnson, the new library is https://github.com/facebookincubator/flowtorch and is being developed by @stefanwebb @feynmanliang at Facebook. Perhaps they can comment on the open-source timeline and on invitation-only access for early adopters.

fritzo, Sep 28 '21

Hi @fritzo, thank you for the information! A side question: the flowtorch library will be for PyTorch (at least it was the last time I checked). Are there any plans to borrow a few of its methods for the numpyro library?

jejjohnson, Sep 28 '21

@jejjohnson I know of no plans to port Pyro normalizing flows to NumPyro, but contributions are always welcome 😄

fritzo, Sep 28 '21