pyro
`pyro.distributions.transforms.spline_autoregressive` not working with rational quadratic splines
Issue Description
The helper functions for creating transforms in `pyro.distributions.transforms.spline_autoregressive` throw an error when rational quadratic splines (`order='quadratic'`) are used, even in a simple example. The problem does not occur with rational linear splines (`order='linear'`).
Environment
Mac OSX 11.3, Python 3.8.
Code Snippet
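import torch
import pyro
import pyro.distributions as dist
p = 4
distZ = dist.Normal(torch.zeros(p), torch.ones(p))
T = pyro.distributions.transforms.spline_autoregressive(input_dim=p, hidden_dims=[25, 10], order='quadratic')
distX = dist.TransformedDistribution(distZ, [T])
# The error is raised on the first log_prob call (in my case inside a training loop)
distX.log_prob(torch.randn(10, p))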
Error Message
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-94-7c4e42845b4e> in <module>
----> 1 train_flow_model(dataset, distX, params=T.parameters())
<ipython-input-93-7c522f61b451> in train_flow_model(dataset, distX, params, steps, lr)
5 for step in range(steps):
6 optimizer.zero_grad()
----> 7 loss = -distX.log_prob(dataset).sum()
8 loss.backward()
9 optimizer.step()
~/opt/anaconda3/lib/python3.7/site-packages/torch/distributions/transformed_distribution.py in log_prob(self, value)
141 y = value
142 for transform in reversed(self.transforms):
--> 143 x = transform.inv(y)
144 event_dim += transform.domain.event_dim - transform.codomain.event_dim
145 log_prob = log_prob - _sum_rightmost(transform.log_abs_det_jacobian(x, y),
~/opt/anaconda3/lib/python3.7/site-packages/torch/distributions/transforms.py in __call__(self, x)
247 def __call__(self, x):
248 assert self._inv is not None
--> 249 return self._inv._inv_call(x)
250
251 def log_abs_det_jacobian(self, x, y):
~/opt/anaconda3/lib/python3.7/site-packages/torch/distributions/transforms.py in _inv_call(self, y)
159 if y is y_old:
160 return x_old
--> 161 x = self._inverse(y)
162 self._cached_x_y = x, y
163 return x
~/opt/anaconda3/lib/python3.7/site-packages/pyro/distributions/transforms/spline_autoregressive.py in _inverse(self, y)
116 for _ in range(input_dim):
117 spline = self.spline.condition(x)
--> 118 x = spline._inverse(y)
119
120 self._cache_log_detJ = spline._cache_log_detJ
~/opt/anaconda3/lib/python3.7/site-packages/pyro/distributions/transforms/spline.py in _inverse(self, y)
293 otherwise performs the inversion afresh.
294 """
--> 295 x, log_detJ = self.spline_op(y, inverse=True)
296 self._cache_log_detJ = -log_detJ
297 return x
~/opt/anaconda3/lib/python3.7/site-packages/pyro/distributions/transforms/spline.py in spline_op(self, x, **kwargs)
310
311 def spline_op(self, x, **kwargs):
--> 312 w, h, d, l = self._params() if callable(self._params) else self._params
313 y, log_detJ = _monotonic_rational_spline(x, w, h, d, l, bound=self.bound, **kwargs)
314 return y, log_detJ
~/opt/anaconda3/lib/python3.7/site-packages/pyro/distributions/transforms/spline.py in _params(self, context)
489 l = torch.sigmoid(l)
490 elif self.order == "quadratic":
--> 491 w, h, d = self.nn(context)
492 l = None
493 else:
ValueError: too many values to unpack (expected 3)
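Reading the last frames, the quadratic branch of `_params` unpacks exactly three outputs (`w, h, d`) from the conditioning network, while the linear branch unpacks four (`w, h, d, l`). The error suggests that the network built by the helper always emits the four groups needed by the linear order. The pure-Python sketch below illustrates that mismatch; `spline_param_dims` and `unpack_quadratic` are hypothetical helpers (not Pyro API), and the group sizes are an assumption based on the usual rational-spline parameterization.

```python
def spline_param_dims(count_bins, order="linear"):
    """Hypothetical helper (not Pyro API): per-dimension sizes of the spline
    parameter groups the conditioning network must emit for a given order."""
    # bin widths and heights, plus derivatives at the interior knots
    dims = [count_bins, count_bins, count_bins - 1]
    if order == "linear":
        # rational linear splines add one lambda per bin
        dims.append(count_bins)
    elif order != "quadratic":
        raise ValueError(f"unknown spline order: {order!r}")
    return dims


def unpack_quadratic(nn_outputs):
    """Mimics the quadratic branch of `_params`: expects exactly w, h, d."""
    w, h, d = nn_outputs
    return w, h, d


# Four groups sized for the linear order, fed to the quadratic unpacking
linear_outputs = tuple([0.0] * n for n in spline_param_dims(8, "linear"))
try:
    unpack_quadratic(linear_outputs)
except ValueError as err:
    print("ValueError:", err)

# Three groups sized for the quadratic order unpack cleanly
quadratic_outputs = tuple([0.0] * n for n in spline_param_dims(8, "quadratic"))
w, h, d = unpack_quadratic(quadratic_outputs)
print(len(w), len(h), len(d))  # 8 8 7
```

If this reading is right, the fix would be for the helper to choose the network's `param_dims` based on `order` rather than always using the linear layout.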
Hi @robsalomone, I believe normalizing flow development has moved to https://github.com/stefanwebb/flowtorch, and from there to https://github.com/facebookincubator/flowtorch.
@stefanwebb, what would you recommend as the best place for current Pyro users to get normalizing flow implementations? (I'd rather not rely on whatever a Google search turns up.)