`pyro.distributions.transforms.spline_autoregressive` not working with rational quadratic splines

Open: robsalomone opened this issue 4 years ago • 1 comment

Issue Description

The transform created by the helper function `pyro.distributions.transforms.spline_autoregressive` throws an error when quadratic splines (`order='quadratic'`) are used, even in a simple example. The problem does not occur with rational linear splines (`order='linear'`).

Environment

Mac OSX 11.3, Python 3.8.

Code Snippet

import torch
import pyro
import pyro.distributions as dist

p = 4  # dimensionality of the data
distZ = dist.Normal(torch.zeros(p), torch.ones(p))  # base distribution
T = pyro.distributions.transforms.spline_autoregressive(input_dim=p, hidden_dims=[25, 10], order='quadratic')
distX = dist.TransformedDistribution(distZ, [T])

Error Message

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-94-7c4e42845b4e> in <module>
----> 1 train_flow_model(dataset, distX, params=T.parameters())

<ipython-input-93-7c522f61b451> in train_flow_model(dataset, distX, params, steps, lr)
      5     for step in range(steps):
      6         optimizer.zero_grad()
----> 7         loss = -distX.log_prob(dataset).sum()
      8         loss.backward()
      9         optimizer.step()

~/opt/anaconda3/lib/python3.7/site-packages/torch/distributions/transformed_distribution.py in log_prob(self, value)
    141         y = value
    142         for transform in reversed(self.transforms):
--> 143             x = transform.inv(y)
    144             event_dim += transform.domain.event_dim - transform.codomain.event_dim
    145             log_prob = log_prob - _sum_rightmost(transform.log_abs_det_jacobian(x, y),

~/opt/anaconda3/lib/python3.7/site-packages/torch/distributions/transforms.py in __call__(self, x)
    247     def __call__(self, x):
    248         assert self._inv is not None
--> 249         return self._inv._inv_call(x)
    250 
    251     def log_abs_det_jacobian(self, x, y):

~/opt/anaconda3/lib/python3.7/site-packages/torch/distributions/transforms.py in _inv_call(self, y)
    159         if y is y_old:
    160             return x_old
--> 161         x = self._inverse(y)
    162         self._cached_x_y = x, y
    163         return x

~/opt/anaconda3/lib/python3.7/site-packages/pyro/distributions/transforms/spline_autoregressive.py in _inverse(self, y)
    116         for _ in range(input_dim):
    117             spline = self.spline.condition(x)
--> 118             x = spline._inverse(y)
    119 
    120         self._cache_log_detJ = spline._cache_log_detJ

~/opt/anaconda3/lib/python3.7/site-packages/pyro/distributions/transforms/spline.py in _inverse(self, y)
    293         otherwise performs the inversion afresh.
    294         """
--> 295         x, log_detJ = self.spline_op(y, inverse=True)
    296         self._cache_log_detJ = -log_detJ
    297         return x

~/opt/anaconda3/lib/python3.7/site-packages/pyro/distributions/transforms/spline.py in spline_op(self, x, **kwargs)
    310 
    311     def spline_op(self, x, **kwargs):
--> 312         w, h, d, l = self._params() if callable(self._params) else self._params
    313         y, log_detJ = _monotonic_rational_spline(x, w, h, d, l, bound=self.bound, **kwargs)
    314         return y, log_detJ

~/opt/anaconda3/lib/python3.7/site-packages/pyro/distributions/transforms/spline.py in _params(self, context)
    489             l = torch.sigmoid(l)
    490         elif self.order == "quadratic":
--> 491             w, h, d = self.nn(context)
    492             l = None
    493         else:

ValueError: too many values to unpack (expected 3)
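
The last frame of the traceback suggests a likely mechanism: the quadratic branch unpacks three values from `self.nn(context)`, while the network appears to return four, as the linear branch expects. The following is a minimal pure-Python sketch of that failure mode, not Pyro's actual code. `fake_nn` and `params` are hypothetical stand-ins, and the assumption that the helper builds a four-output network is inferred from the traceback rather than confirmed in this thread.

```python
# Hypothetical stand-in for the network built by the helper. ASSUMPTION:
# it always returns four parameter groups, matching the linear branch.
def fake_nn(context):
    return ("w", "h", "d", "l")


def params(order, nn, context=None):
    """Mimic the order-dependent unpacking visible in the traceback."""
    if order == "linear":
        w, h, d, l = nn(context)
    elif order == "quadratic":
        w, h, d = nn(context)  # fails when nn returns four values
        l = None
    else:
        raise ValueError(f"unknown order: {order}")
    return w, h, d, l


# The linear branch unpacks all four values without complaint.
linear_result = params("linear", fake_nn)

# The quadratic branch raises the same kind of ValueError as in the report.
try:
    params("quadratic", fake_nn)
    quadratic_error = None
except ValueError as exc:
    quadratic_error = str(exc)

print(linear_result)
print(quadratic_error)
```

If this reading is right, the mismatch lies between the number of outputs the network is configured to produce and the number the quadratic branch expects, which would explain why only `order='quadratic'` fails.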

robsalomone, Jul 09 '21 08:07

Hi @robsalomone, I believe normalizing flow development has moved, first to https://github.com/stefanwebb/flowtorch and then to https://github.com/facebookincubator/flowtorch .

@stefanwebb What would you recommend as the best place for current Pyro users to obtain normalizing flow implementations? (trying to avoid following google)

fritzo, Jul 09 '21 17:07