numpyro
pattern for creating a TransformModule
This issue discusses possible ways to provide something similar to Pyro's TransformModule. The motivation is to fold the construction of a flow's neural network (nn) into the construction of the flow transform itself, so that user code is cleaner.
Current approach
# create a flow nn
nn_init, nn_apply = nn = FlowNN(flow_args)
# bind params to nn_apply
nn_call = numpyro.module("nn", nn) # or partial(nn_apply, params)
# create a transform from `nn_call`
transform = FlowTransform(nn_call)
These steps are explicit, and FlowTransform does not have to deal with parameters at all.
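A minimal runnable sketch of approach 1, assuming an affine coupling layer as the flow. `flow_nn` and `AffineCoupling` are hypothetical stand-ins for `FlowNN` and `FlowTransform`, plain NumPy replaces `jax.numpy` so the sketch runs without JAX or NumPyro, and `partial(nn_apply, params)` plays the role that `numpyro.module` would play inside a model. The point it illustrates is that the transform only ever sees an already-bound callable.

```python
from functools import partial

import numpy as np


def flow_nn(split_dim, out_dim):
    """Hypothetical stand-in for FlowNN: a stax-style (init, apply) pair."""
    def init(rng):
        # One linear layer producing shift and log-scale for the second half.
        W = rng.normal(scale=0.1, size=(split_dim, 2 * out_dim))
        b = np.zeros(2 * out_dim)
        return W, b

    def apply(params, x1):
        W, b = params
        shift, log_scale = np.split(x1 @ W + b, 2, axis=-1)
        return shift, np.tanh(log_scale)  # tanh keeps the scales bounded

    return init, apply


class AffineCoupling:
    """Hypothetical FlowTransform: takes a plain callable, knows nothing about params."""

    def __init__(self, nn_call, split_dim):
        self.nn_call = nn_call
        self.split_dim = split_dim

    def __call__(self, x):
        x1, x2 = x[..., : self.split_dim], x[..., self.split_dim :]
        shift, log_scale = self.nn_call(x1)
        return np.concatenate([x1, x2 * np.exp(log_scale) + shift], axis=-1)

    def inv(self, y):
        # The first half passes through unchanged, so the nn sees the same input.
        y1, y2 = y[..., : self.split_dim], y[..., self.split_dim :]
        shift, log_scale = self.nn_call(y1)
        return np.concatenate([y1, (y2 - shift) * np.exp(-log_scale)], axis=-1)

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # The Jacobian is block-triangular, so only the scales contribute.
        _, log_scale = self.nn_call(x[..., : self.split_dim])
        return log_scale.sum(-1)


# Approach 1: build the nn, bind params explicitly, pass the callable on.
nn_init, nn_apply = flow_nn(split_dim=2, out_dim=2)
params = nn_init(np.random.default_rng(0))
nn_call = partial(nn_apply, params)  # stands in for numpyro.module("nn", nn)
transform = AffineCoupling(nn_call, split_dim=2)

x = np.array([0.5, -1.0, 2.0, 0.3])
y = transform(x)
```

Because the params are bound before the transform is built, swapping in new params means building a new callable; the transform class itself never changes.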
Approach 2
# create a transform module (note: this is not a Transform)
transform_nn = FlowTransform(flow_args)
# attach params to transform_nn, converting it into a Transform
transform = numpyro.module("transform_nn", transform_nn)
On its own, transform_nn is not useful; it only becomes useful once we register params for it. I guess that something like
from functools import partial

class TransformModule(...):
    def __init__(self, flow_args):
        self._nn_init, self._nn_apply = FlowNN(flow_args)
        self.params = None

    def logdet(self, x, y, *args, **kwargs):
        if self.params is None:
            raise ValueError("params must be set before calling logdet")
        nn_call = partial(self._nn_apply, self.params)
        ...
together with the module handler setting
transform.params = params
will work.
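A sketch of approach 2 under the same assumptions as before (affine coupling flow, NumPy in place of `jax.numpy`). `AffineCouplingModule` owns its nn and starts with `params = None`; `register_module` is a hypothetical stand-in for what the module handler would do: initialize params once, store them, and set them on the module. None of these names are existing NumPyro API.

```python
import numpy as np


def flow_nn(split_dim, out_dim):
    """Hypothetical stand-in for FlowNN: a stax-style (init, apply) pair."""
    def init(rng):
        W = rng.normal(scale=0.1, size=(split_dim, 2 * out_dim))
        return W, np.zeros(2 * out_dim)

    def apply(params, x1):
        W, b = params
        shift, log_scale = np.split(x1 @ W + b, 2, axis=-1)
        return shift, np.tanh(log_scale)

    return init, apply


class AffineCouplingModule:
    """Hypothetical TransformModule: unusable until params are registered."""

    def __init__(self, split_dim, out_dim):
        self._nn_init, self._nn_apply = flow_nn(split_dim, out_dim)
        self.split_dim = split_dim
        self.params = None

    def _nn_call(self, x1):
        if self.params is None:
            raise ValueError("params are not set; register the module first")
        return self._nn_apply(self.params, x1)

    def __call__(self, x):
        x1, x2 = x[..., : self.split_dim], x[..., self.split_dim :]
        shift, log_scale = self._nn_call(x1)
        return np.concatenate([x1, x2 * np.exp(log_scale) + shift], axis=-1)

    def inv(self, y):
        y1, y2 = y[..., : self.split_dim], y[..., self.split_dim :]
        shift, log_scale = self._nn_call(y1)
        return np.concatenate([y1, (y2 - shift) * np.exp(-log_scale)], axis=-1)

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        _, log_scale = self._nn_call(x[..., : self.split_dim])
        return log_scale.sum(-1)


def register_module(name, module, rng, param_store):
    """Hypothetical handler step: init params once, then attach them to the module."""
    if name not in param_store:
        param_store[name] = module._nn_init(rng)
    module.params = param_store[name]
    return module


param_store = {}
module = AffineCouplingModule(split_dim=2, out_dim=2)
try:
    module(np.zeros(4))  # fails: no params registered yet
except ValueError as err:
    print(err)

transform = register_module("transform_nn", module, np.random.default_rng(0), param_store)
x = np.array([0.5, -1.0, 2.0, 0.3])
y = transform(x)
```

The same object is unusable before registration and usable after it, so correctness depends on hidden mutable state; that is the trade-off against the more explicit approaches 1 and 3.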
Approach 3
Similar to approach 2, but FlowTransform(flow_args) returns a pair of functions, transform_init and transform_apply: transform_init creates the params, and transform_apply returns a Transform given those params. We will need to adjust the behavior of the module handler a bit so that it can distinguish a usual nn layer from this nn transform layer.
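A sketch of approach 3, again assuming an affine coupling flow and NumPy in place of `jax.numpy`. `affine_coupling` is a hypothetical FlowTransform that returns `(transform_init, transform_apply)`; the handler changes the issue mentions are not implemented here, only the init/apply contract they would rely on.

```python
from functools import partial

import numpy as np


def flow_nn(split_dim, out_dim):
    """Hypothetical stand-in for FlowNN: a stax-style (init, apply) pair."""
    def init(rng):
        W = rng.normal(scale=0.1, size=(split_dim, 2 * out_dim))
        return W, np.zeros(2 * out_dim)

    def apply(params, x1):
        W, b = params
        shift, log_scale = np.split(x1 @ W + b, 2, axis=-1)
        return shift, np.tanh(log_scale)

    return init, apply


class AffineCoupling:
    """An ordinary transform built from an already-bound nn callable."""

    def __init__(self, nn_call, split_dim):
        self.nn_call = nn_call
        self.split_dim = split_dim

    def __call__(self, x):
        x1, x2 = x[..., : self.split_dim], x[..., self.split_dim :]
        shift, log_scale = self.nn_call(x1)
        return np.concatenate([x1, x2 * np.exp(log_scale) + shift], axis=-1)

    def inv(self, y):
        y1, y2 = y[..., : self.split_dim], y[..., self.split_dim :]
        shift, log_scale = self.nn_call(y1)
        return np.concatenate([y1, (y2 - shift) * np.exp(-log_scale)], axis=-1)


def affine_coupling(split_dim, out_dim):
    """Hypothetical approach-3 FlowTransform: returns (transform_init, transform_apply)."""
    nn_init, nn_apply = flow_nn(split_dim, out_dim)

    def transform_init(rng):
        # Creates the params; a module handler would store these.
        return nn_init(rng)

    def transform_apply(params):
        # Pure function of params: returns a fresh transform each call.
        return AffineCoupling(partial(nn_apply, params), split_dim)

    return transform_init, transform_apply


transform_init, transform_apply = affine_coupling(split_dim=2, out_dim=2)
params = transform_init(np.random.default_rng(0))
transform = transform_apply(params)
x = np.array([0.5, -1.0, 2.0, 0.3])
y = transform(x)
```

Since transform_apply has no hidden state, a handler would only need to recognize this (init, apply) pair as a transform layer and call transform_apply with the stored params instead of returning a bound nn.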
I prefer approaches 1 and 3 for their explicit, more functional style. But this issue is open to other suggestions.
Is this being worked on currently? It would be great to see normalising flows in numpyro! (Am I right in thinking this is a prerequisite for that, or is there a different approach?)
@stefanwebb is redesigning the normalizing flows interface for PyTorch-Pyro in a new library. We might want to wait a bit to see what new abstractions he decides on.
Oh neat. I'll check it out. Thanks!
Hi, I just wanted to know whether this is being worked on, or will be in the near future. I see the referenced repo has been archived, so I was wondering what numpyro's (and Pyro's) plans are for further development of normalizing flow methods. I would be interested in contributing to both.
Hi @jejjohnson, the new library is https://github.com/facebookincubator/flowtorch and is being developed by @stefanwebb @feynmanliang at Facebook. Perhaps they can comment on the open-source timeline and on invitation-only access for early adopters.
Hi @fritzo, thank you for the information! A side question: the flowtorch library will be for PyTorch (at least it was the last time I checked). Are there any plans to borrow a few of its methods for numpyro?
@jejjohnson I know of no plans to port Pyro normalizing flows to NumPyro, but contributions are always welcome 😄