
[bug] JitTrace_ELBO fails with AutoNormalizingFlow

Status: Open. rafaol opened this issue 3 years ago • 2 comments

Issue Description

Hi all, I'm having trouble using normalizing flows with JIT in Pyro's SVI. The code runs fine with the standard Trace_ELBO (or even TraceEnum_ELBO in models with discrete variables), but it fails when I replace the ELBO with its JIT-compiled version. The failure seems to come from an identity check (is) on unconstrained parameter values, which are tracked through weak references (weakref). I was wondering if anyone else has hit the same issue, and whether a fix is on the horizon.

Environment

  • OS and python version: Ubuntu 20.04, Python 3.9.7.
  • PyTorch version: 1.12.1
  • Pyro version: 1.8.1+06911dc

Code Snippet

Here is a minimal example. The following code runs normally if I use Trace_ELBO as the loss function for SVI, but fails with JitTrace_ELBO.

from functools import partial

import pyro
import pyro.optim
import torch
from pyro import distributions
from pyro.distributions import transforms
from pyro.infer import SVI, JitTrace_ELBO
from pyro.infer.autoguide import AutoNormalizingFlow
from tqdm import trange


def test_model(y):
    m = pyro.sample("m", distributions.MultivariateNormal(torch.zeros(2), torch.eye(2)))
    n = y.shape[0]
    with pyro.plate("observations_plate", n):
        r = pyro.sample("y", distributions.MultivariateNormal(m, 0.01*torch.eye(2)), obs=y)
    return r


if __name__ == "__main__":
    pyro.set_rng_seed(0)

    pyro.clear_param_store()

    transform = partial(transforms.iterated, 1, transforms.block_autoregressive)
    guide = AutoNormalizingFlow(test_model, transform)

    svi = SVI(test_model, guide, pyro.optim.Adam(dict(lr=5e-3)), loss=JitTrace_ELBO())

    n_steps = 1000
    t_iter = trange(n_steps)
    n_data = 100

    test_y = torch.randn(n_data, 2)

    for t in t_iter:
        loss = svi.step(test_y)
        t_iter.set_postfix(loss=loss)

Console output

  0%|          | 0/1000 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/rafael/opt/anaconda3/envs/latest_torch/lib/python3.9/site-packages/pyro/infer/svi.py", line 145, in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
  File "/home/rafael/opt/anaconda3/envs/latest_torch/lib/python3.9/site-packages/pyro/infer/trace_elbo.py", line 250, in loss_and_grads
    loss, surrogate_loss = self.loss_and_surrogate_loss(
  File "/home/rafael/opt/anaconda3/envs/latest_torch/lib/python3.9/site-packages/pyro/infer/trace_elbo.py", line 239, in loss_and_surrogate_loss
    return self._loss_and_surrogate_loss(*args, **kwargs)
  File "/home/rafael/opt/anaconda3/envs/latest_torch/lib/python3.9/site-packages/pyro/ops/jit.py", line 107, in __call__
    self.compiled[key] = torch.jit.trace(
  File "/home/rafael/opt/anaconda3/envs/latest_torch/lib/python3.9/site-packages/torch/jit/_trace.py", line 795, in trace
    traced = torch._C._create_function_from_trace(
  File "/home/rafael/opt/anaconda3/envs/latest_torch/lib/python3.9/site-packages/pyro/ops/jit.py", line 96, in compiled
    assert constrained_param.unconstrained() is unconstrained_param
AssertionError

When I inspect the runtime state in the debugger, I see that constrained_param.unconstrained() and unconstrained_param have the same values but are different objects: constrained_param.unconstrained() dereferences a weakref to a Parameter (a leaf tensor with requires_grad=True), while unconstrained_param is a plain Tensor. Their autograd history also differs: unconstrained_param has grad_fn=<ViewBackward0>, whereas the Parameter has no grad_fn.
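The assertion in the traceback compares objects with is, which tests identity rather than contents. The stdlib-only sketch below mimics that pattern with two hypothetical classes (Source and Constrained are illustrative names, not Pyro's API): a constrained object keeps a weakref.ref back to its source, and an equal-valued copy passes == but fails is.

```python
import weakref


class Source:
    """Hypothetical stand-in for an unconstrained parameter."""

    def __init__(self, values):
        self.values = list(values)

    def __eq__(self, other):
        # Value equality: compare contents, not identity.
        return isinstance(other, Source) and self.values == other.values


class Constrained:
    """Hypothetical stand-in for a constrained value that remembers its source."""

    def __init__(self, source):
        # Keep only a weak reference, so this object does not keep the source alive.
        self.unconstrained = weakref.ref(source)


def check(constrained, candidate):
    """Return (values_match, same_object) for a candidate unconstrained value."""
    original = constrained.unconstrained()  # dereference; None if the source was freed
    return original == candidate, original is candidate


param = Source([-0.1316, 0.0864])
c = Constrained(param)

print(check(c, param))  # (True, True): the very same object
copy = Source(param.values)  # equal contents, different object
print(check(c, copy))  # (True, False): this is what trips an `is` assertion
```

The point is only that matching printed values say nothing about identity; any code path that hands the check an equal but distinct object will fail it.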

Values

name: "AutoNormalizingFlow._prototype_tensor"
constrained_param.unconstrained():
Parameter containing:
tensor([[-0.1316,  0.0000],
        [ 0.0864,  0.0000],
        [ 0.7393,  0.0000],
        [-0.7574,  0.0000],
        [-0.5140,  0.0000],
        [-0.2067,  0.0000],
        [-0.3183,  0.0000],
        [ 0.7055,  0.0000],
        [-0.5021, -0.3566],
        [-0.5412, -0.7255],
        [-0.4522,  0.6658],
        [ 0.3456,  0.3754],
        [ 0.0407, -0.3971],
        [ 0.1310, -0.7232],
        [-0.5597, -0.3993],
        [ 0.4887,  0.4542]], requires_grad=True)
unconstrained_param:
tensor([[-0.1316,  0.0000],
        [ 0.0864,  0.0000],
        [ 0.7393,  0.0000],
        [-0.7574,  0.0000],
        [-0.5140,  0.0000],
        [-0.2067,  0.0000],
        [-0.3183,  0.0000],
        [ 0.7055,  0.0000],
        [-0.5021, -0.3566],
        [-0.5412, -0.7255],
        [-0.4522,  0.6658],
        [ 0.3456,  0.3754],
        [ 0.0407, -0.3971],
        [ 0.1310, -0.7232],
        [-0.5597, -0.3993],
        [ 0.4887,  0.4542]], grad_fn=<ViewBackward0>)

Pointer addresses:

  • weakref.ref(constrained_param.unconstrained()) = <weakref at 0x7fa425962180; to 'Parameter' at 0x7fa425959040>
  • weakref.ref(unconstrained_param) = <weakref at 0x7fa42408e090; to 'Tensor' at 0x7fa426555090>
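The two dumps above print identical numbers, yet the assertion fails. A minimal PyTorch sketch (assuming only a standard torch install; the 16x2 shape is copied from the dump, the values are random) shows one way that happens: a view of a Parameter has equal contents and shares its storage, but it is a distinct, non-leaf Tensor with its own grad_fn, so an is check against the Parameter is False. This illustrates the symptom only; it does not reproduce the JIT failure itself.

```python
import torch

# Stand-in for the guide's prototype parameter (shape from the dump, values random).
param = torch.nn.Parameter(torch.randn(16, 2))

# A view over the same storage, analogous to the value the traced function received.
view = param.view(16, 2)

same_values = torch.equal(view, param)  # contents match
same_object = view is param  # identity does not
shares_storage = view.data_ptr() == param.data_ptr()  # same underlying memory

print(type(param).__name__, param.grad_fn)  # leaf Parameter: no grad_fn
print(type(view).__name__, view.grad_fn is not None)  # non-leaf Tensor with a grad_fn
print(same_values, same_object, shares_storage)
```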

rafaol · Sep 01 '22 04:09

Hi @rafaol, tl;dr: don't use the jit; it isn't useful.

I have found that JitTrace_ELBO only works on very simple models and provides little or no speed improvement. I believe this is because the PyTorch team's goals are more about executing models outside of a Python runtime, in contrast to the JAX team's goals, which are more performance-oriented. torch.jit.trace used to slightly speed up some models, but the jit has gradually grown slower, so it is no longer useful in Pyro. I've also found that torch.jit changes so often that I can't rely on my jitted models remaining jittable across minor PyTorch releases.

fritzo · Sep 01 '22 13:09
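Whether the jit helps is easy to check empirically on a given model. The helper below is a minimal, stdlib-only sketch (median_step_time is a hypothetical name, not part of Pyro) that times any step function after a few untimed warm-up calls. The warm-up matters because the traceback above shows JitTrace_ELBO calling torch.jit.trace the first time it runs, and that one-off cost should not be counted as per-step time.

```python
import time
from statistics import median


def median_step_time(step, *args, n_warmup=5, n_steps=50):
    """Return the median wall-clock time (seconds) of step(*args).

    The first n_warmup calls are run but not timed, so one-off costs such as
    tracing/compilation on the first call do not skew the result.
    """
    for _ in range(n_warmup):
        step(*args)
    timings = []
    for _ in range(n_steps):
        start = time.perf_counter()
        step(*args)
        timings.append(time.perf_counter() - start)
    return median(timings)


if __name__ == "__main__":
    # Toy workload standing in for svi.step(test_y); swap in real SVI objects
    # built with loss=Trace_ELBO() and loss=JitTrace_ELBO() to compare them.
    data = list(range(10_000))
    print(f"{median_step_time(sorted, data):.6f} s per call")
```

To compare the two losses, build two SVI objects that differ only in the loss argument and pass each svi.step together with test_y. Both read and update Pyro's global param store, so the timings are comparable but the loss values from the second run will start from the first run's parameters.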

Thanks @fritzo! I was actually planning to use JIT on a larger and more complex model (a random-effects/hierarchical model), which got a significant (10x) speedup in SVI when I tried JIT with simpler guides (AutoMultivariateNormal and AutoLaplaceApproximation). I wish I could get the same kind of speedup with normalizing flows, since they're far more flexible and expressive, but I'll look at other options within Pyro.

rafaol · Sep 01 '22 15:09