[bug] JitTrace_ELBO fails with AutoNormalizingFlow
Issue Description
Hi all, I'm having issues trying to use normalizing flows with the JIT and Pyro's SVI. The code runs fine if I use the standard Trace_ELBO (or even TraceEnum_ELBO in models with discrete variables), but it fails if I replace the ELBO with its JIT-compiled version. The issue seems to be caused by the use of weak references (weakref) when checking unconstrained parameter values. I was wondering if anyone else is seeing the same issue, and whether a fix is on the horizon.
Environment
- OS and python version: Ubuntu 20.04, Python 3.9.7.
- PyTorch version: 1.12.1
- Pyro version: 1.8.1+06911dc
Code Snippet
Here is a minimal example. The following code runs normally if I use Trace_ELBO as the loss function for SVI, but fails with JitTrace_ELBO.
from functools import partial

import pyro
import pyro.optim
import torch
from pyro import distributions
from pyro.distributions import transforms
from pyro.infer import SVI, JitTrace_ELBO
from pyro.infer.autoguide import AutoNormalizingFlow
from tqdm import trange


def test_model(y):
    m = pyro.sample("m", distributions.MultivariateNormal(torch.zeros(2), torch.eye(2)))
    n = y.shape[0]
    with pyro.plate("observations_plate", n):
        r = pyro.sample("y", distributions.MultivariateNormal(m, 0.01 * torch.eye(2)), obs=y)
    return r


if __name__ == "__main__":
    pyro.set_rng_seed(0)
    pyro.clear_param_store()

    transform = partial(transforms.iterated, 1, transforms.block_autoregressive)
    guide = AutoNormalizingFlow(test_model, transform)
    svi = SVI(test_model, guide, pyro.optim.Adam(dict(lr=5e-3)), loss=JitTrace_ELBO())

    n_steps = 1000
    t_iter = trange(n_steps)
    n_data = 100
    test_y = torch.randn(n_data, 2)
    for t in t_iter:
        loss = svi.step(test_y)
        t_iter.set_postfix(loss=loss)
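For reference, the working variant only swaps the loss (a sketch of what I mean; everything else stays exactly as above):

from pyro.infer import Trace_ELBO

# Same model, guide, and optimizer as above; only the loss changes.
svi = SVI(test_model, guide, pyro.optim.Adam(dict(lr=5e-3)), loss=Trace_ELBO())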
Console output
0%| | 0/1000 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/home/rafael/opt/anaconda3/envs/latest_torch/lib/python3.9/site-packages/pyro/infer/svi.py", line 145, in step
loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
File "/home/rafael/opt/anaconda3/envs/latest_torch/lib/python3.9/site-packages/pyro/infer/trace_elbo.py", line 250, in loss_and_grads
loss, surrogate_loss = self.loss_and_surrogate_loss(
File "/home/rafael/opt/anaconda3/envs/latest_torch/lib/python3.9/site-packages/pyro/infer/trace_elbo.py", line 239, in loss_and_surrogate_loss
return self._loss_and_surrogate_loss(*args, **kwargs)
File "/home/rafael/opt/anaconda3/envs/latest_torch/lib/python3.9/site-packages/pyro/ops/jit.py", line 107, in __call__
self.compiled[key] = torch.jit.trace(
File "/home/rafael/opt/anaconda3/envs/latest_torch/lib/python3.9/site-packages/torch/jit/_trace.py", line 795, in trace
traced = torch._C._create_function_from_trace(
File "/home/rafael/opt/anaconda3/envs/latest_torch/lib/python3.9/site-packages/pyro/ops/jit.py", line 96, in compiled
assert constrained_param.unconstrained() is unconstrained_param
AssertionError
When I inspect the runtime state with the debugger, constrained_param.unconstrained() and unconstrained_param match in value, but they are distinct objects: constrained_param.unconstrained() dereferences a weak reference (weakref) to the Parameter held in the param store, while unconstrained_param is a plain Tensor (a view with grad_fn=<ViewBackward0>), so the identity assertion fails. A small toy illustration of the mismatch follows the dumped values below.
Values
name: "AutoNormalizingFlow._prototype_tensor"
constrained_param.unconstrained():
Parameter containing:
tensor([[-0.1316, 0.0000],
[ 0.0864, 0.0000],
[ 0.7393, 0.0000],
[-0.7574, 0.0000],
[-0.5140, 0.0000],
[-0.2067, 0.0000],
[-0.3183, 0.0000],
[ 0.7055, 0.0000],
[-0.5021, -0.3566],
[-0.5412, -0.7255],
[-0.4522, 0.6658],
[ 0.3456, 0.3754],
[ 0.0407, -0.3971],
[ 0.1310, -0.7232],
[-0.5597, -0.3993],
[ 0.4887, 0.4542]], requires_grad=True)
unconstrained_param:
tensor([[-0.1316, 0.0000],
[ 0.0864, 0.0000],
[ 0.7393, 0.0000],
[-0.7574, 0.0000],
[-0.5140, 0.0000],
[-0.2067, 0.0000],
[-0.3183, 0.0000],
[ 0.7055, 0.0000],
[-0.5021, -0.3566],
[-0.5412, -0.7255],
[-0.4522, 0.6658],
[ 0.3456, 0.3754],
[ 0.0407, -0.3971],
[ 0.1310, -0.7232],
[-0.5597, -0.3993],
[ 0.4887, 0.4542]], grad_fn=<ViewBackward0>)
Pointer addresses:
weakref.ref(constrained_param.unconstrained())=<weakref at 0x7fa425962180; to 'Parameter' at 0x7fa425959040>
weakref.ref(unconstrained_param)=<weakref at 0x7fa42408e090; to 'Tensor' at 0x7fa426555090>
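To illustrate what I think is going on (a standalone toy example, not Pyro internals): a view of a Parameter matches it in value but is a different Tensor object, so an identity check like the one in pyro/ops/jit.py fails even though the values agree:

import torch

p = torch.nn.Parameter(torch.randn(16, 2))
v = p.view(16, 2)          # same values and storage, but a new Tensor object (a view)

print(torch.equal(p, v))   # True  -- values match
print(v is p)              # False -- identity check fails, like the assert above
print(v.grad_fn)           # <ViewBackward0 ...>, as seen for unconstrained_param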
Hi @rafaol, tl;dr: don't use the jit; it isn't useful.
I have found that JitTrace_ELBO only works on very simple models and provides little or no speed improvement. I believe this is because the PyTorch team's goal is more to execute models outside of a Python runtime, in contrast to the JAX team's goals, which are more performance-oriented. Historically, torch.jit.trace used to slightly speed up some models, but the jit has gradually grown slower, so it is no longer useful in Pyro. I've also found that torch.jit changes so often across PyTorch releases that I can't rely on my jitted models remaining jittable across minor PyTorch releases.
Thanks @fritzo! I was actually planning to use the JIT on a larger and more complex model (a random-effects/hierarchical model), which had a significant (10x) speed-up in SVI when I tried the JIT with a simpler guide (AutoMultivariateNormal and AutoLaplaceApproximation). I wish I could get the same kind of speed-ups with normalizing flows, since they're far more flexible and expressive, but I'll have a look at other options within Pyro.
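For context, this is roughly the guide swap I mean (a sketch assuming the same test_model as in my snippet above, not a verified benchmark):

from pyro.infer.autoguide import AutoMultivariateNormal

# Simpler guide that did work for me under JitTrace_ELBO, unlike AutoNormalizingFlow.
guide = AutoMultivariateNormal(test_model)
svi = SVI(test_model, guide, pyro.optim.Adam(dict(lr=5e-3)), loss=JitTrace_ELBO())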