
RuntimeError: Graph::copy() with pyro on pytorch 1.12.1 (GPU)

Open Balandat opened this issue 3 years ago • 10 comments

Issue Description

When running NUTS on a pretty basic model using 1.12.1+cu113, the following exception is raised (for a full repro see the colab link below):

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-9-f2613a8adf9b> in <module>
     21     disable_progbar=disable_progbar,
     22 )
---> 23 mcmc.run()

12 frames
/usr/local/lib/python3.7/dist-packages/pyro/infer/mcmc/util.py in _potential_fn_jit(self, skip_jit_warnings, jit_options, params)
    292 
    293         if self._compiled_fn:
--> 294             return self._compiled_fn(*vals)
    295 
    296         with pyro.validation_enabled(False):

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: Graph::copy() encountered a use of a value 133 not in scope. Run lint!

This seems to be an issue only on pytorch 1.12.1; when downgrading pytorch to 1.11 this error disappears.

Note: This issue was originally raised in https://github.com/facebook/Ax/issues/1108, and there is some more discussion there. This issue has a minimal repro reduced to the pyro model.

Environment

For any bugs, please provide the following:

  • Environment: Google Colab (though reproducible in other environments)
  • PyTorch version: 1.12.1+cu113
  • Pyro version: 1.8.2

Code Snippet

https://colab.research.google.com/drive/1mRY01XShwIb06aNj-UgKRd7C4av1jqTU (make sure to run on GPU runtime)

Balandat avatar Sep 07 '22 04:09 Balandat

@Balandat as you're aware, this is almost certainly a torch jit issue. unfortunately we find torch jit to be quite brittle from release to release. apart from downgrading torch, i expect passing jit_compile=False would also fix things. i don't know how much speed-up jit_compile=True gives you anyway; it's been many release cycles since we've seen torch jit deliver large speed-ups

martinjankowiak avatar Sep 07 '22 13:09 martinjankowiak

I can confirm that jit_compile=False fixes things - I'll look into what the speedup from this is.

Balandat avatar Sep 07 '22 14:09 Balandat

cc @dme65

Balandat avatar Sep 07 '22 16:09 Balandat

I can confirm that jit_compile=False fixes things - I'll look into what the speedup from this is.

Disabling jit worked for @xingchenwan as well who was running into the same issue. I recall jit giving pretty decent speed-ups for a very small number of datapoints, but that was also over a year ago so it makes sense to take another look and see if jit makes a difference on the latest versions.

dme65 avatar Sep 07 '22 17:09 dme65
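To "take another look and see if jit makes a difference", a small timing harness like the one below could be used to compare a sampler run with `jit_compile=True` against one with `jit_compile=False`. This is a generic stdlib sketch, not from the thread; the `run_mcmc_jit` / `run_mcmc_nojit` names in the usage comment are hypothetical placeholders for closures that construct and run the two MCMC variants.

```python
import time

def time_it(fn, repeats=3):
    """Run fn several times and return the best wall-clock time.

    Taking the minimum over repeats reduces noise from warm-up
    effects and background load on a shared machine like Colab.
    """
    best = float("inf")
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn()
        best = min(best, time.perf_counter() - t0)
    return best

# Hypothetical usage (run_mcmc_jit / run_mcmc_nojit are assumed to wrap
# mcmc.run() with jit_compile=True and jit_compile=False respectively):
#     print("jit:   ", time_it(run_mcmc_jit))
#     print("no jit:", time_it(run_mcmc_nojit))
```

Note that for a fair comparison the JIT-compiled run should be timed after compilation has already happened once, since the first call pays the tracing/compilation cost.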

jit_compile=False

Where would I pass this as an argument in the example notebook on Colab? Sorry for a potentially dumb question...

winf-hsos avatar Sep 08 '22 07:09 winf-hsos

    nuts = NUTS(
        sample,
        jit_compile=False,  # <========
        full_mass=True,  # not sure why this is on in the notebook;
                         # this isn't necessarily expected to work well in the higher-dimensional case
        ignore_jit_warnings=True,
        max_tree_depth=max_tree_depth,
    )

martinjankowiak avatar Sep 08 '22 11:09 martinjankowiak

Actually it also works on one of my env with pytorch=1.12.0, cuda=11.0, GTX 1080Ti, so it could just be the 1.12.1 update (or something else).

Also it doesn't seem that disabling jit makes the code much slower.

xingchenwan avatar Sep 08 '22 11:09 xingchenwan

this isn't necessarily expected to work well in the higher dimensional case

I was wondering about this: is this primarily because of the cost of inverting, Cholesky, or something else?

feynmanliang avatar Sep 20 '22 16:09 feynmanliang

this isn't necessarily expected to work well in the higher dimensional case

I was wondering about this: is this primarily because of the cost of inverting, Cholesky, or something else?

Martin's point is that it's hard to estimate the D x D mass matrix from a small number of datapoints if D is large.

dme65 avatar Sep 20 '22 16:09 dme65

@feynmanliang in general this is an issue that is both about computation (extra linear algebra) and statistical efficiency (the ability to estimate a covariance matrix from finite samples).

in this particular case computing the potential energy is already O(N^3), so the additional linear algebra costs are probably moderate unless D is very large. the bigger problem is statistical estimation: even if i gave you 100 perfect posterior samples, your estimate of a 300 x 300 covariance matrix would be noisy. in practice we get approximate posterior samples during warm-up (they're also correlated with one another), and there are feedback loops between the adapted mass matrix and the samples obtained in warm-up. so getting good mass matrix estimates is expected to be hard and would generally require very long warm-up periods

martinjankowiak avatar Sep 20 '22 18:09 martinjankowiak
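The statistical-estimation point above can be illustrated with a small stdlib-only sketch (not from the thread): draw N iid standard-normal vectors, whose true covariance is the identity, and measure how far the sample covariance is from the truth. With N fixed, the per-entry-normalized error grows as the dimension D grows, which is the difficulty `full_mass=True` faces in higher dimensions.

```python
import random

def sample_cov_error(d, n, seed=0):
    """Normalized Frobenius error of the sample covariance of n iid
    standard-normal d-vectors (true covariance = identity)."""
    rng = random.Random(seed)
    xs = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
    means = [sum(x[j] for x in xs) / n for j in range(d)]
    err2 = 0.0
    for i in range(d):
        for j in range(d):
            # unbiased sample covariance entry C[i, j]
            cij = sum((x[i] - means[i]) * (x[j] - means[j]) for x in xs) / (n - 1)
            true = 1.0 if i == j else 0.0
            err2 += (cij - true) ** 2
    return (err2 / d) ** 0.5  # normalize so errors are comparable across d

# with the number of samples fixed at 100, the estimation error grows
# markedly with the dimension (roughly like sqrt(d / n)):
small = sample_cov_error(d=5, n=100)
large = sample_cov_error(d=50, n=100)
```

And this is the easy case: here the samples are perfect and independent, whereas warm-up samples are approximate and correlated, making the real adaptation problem harder still.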