RuntimeError: Graph::copy() with pyro on pytorch 1.12.1 (GPU)
Issue Description
When running NUTS on a pretty basic model using pytorch 1.12.1+cu113, the following exception is raised (for a full repro see the colab link below):
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-9-f2613a8adf9b> in <module>
21 disable_progbar=disable_progbar,
22 )
---> 23 mcmc.run()
12 frames
/usr/local/lib/python3.7/dist-packages/pyro/infer/mcmc/util.py in _potential_fn_jit(self, skip_jit_warnings, jit_options, params)
292
293 if self._compiled_fn:
--> 294 return self._compiled_fn(*vals)
295
296 with pyro.validation_enabled(False):
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: Graph::copy() encountered a use of a value 133 not in scope. Run lint!
This seems to be an issue only on pytorch 1.12.1; when downgrading pytorch to 1.11 this error disappears.
Note: This issue was originally raised in https://github.com/facebook/Ax/issues/1108, and there is some more discussion there. This issue has a minimal repro reduced to the pyro model.
Environment
For any bugs, please provide the following:
- Platform: Google Colab (though reproducible in other environments)
- PyTorch version: 1.12.1+cu113
- Pyro version: 1.8.2
Code Snippet
https://colab.research.google.com/drive/1mRY01XShwIb06aNj-UgKRd7C4av1jqTU (make sure to run on GPU runtime)
@Balandat as you're aware, this is almost certainly a torch jit issue. unfortunately we find torch jit to be quite brittle from release to release. apart from downgrading torch, i expect passing jit_compile=False would also fix things. i don't know how much speed-up jit_compile=True gives you anyway. it's been many release cycles since we've seen torch jit deliver large speed-ups.
I can confirm that jit_compile=False fixes things - I'll look into what the speedup from this is.
cc @dme65
Disabling jit worked for @xingchenwan as well who was running into the same issue. I recall jit giving pretty decent speed-ups for a very small number of datapoints, but that was also over a year ago so it makes sense to take another look and see if jit makes a difference on the latest versions.
Where would I pass jit_compile=False as an argument in the example notebook on Colab? Sorry for a potentially dumb question...
nuts = NUTS(
sample,
jit_compile=False, # <========
full_mass=True, # not sure why this is on in the notebook;
# this isn't necessarily expected to work well in the higher dimensional case
ignore_jit_warnings=True,
max_tree_depth=max_tree_depth,
)
Actually it also works in one of my envs with pytorch=1.12.0, cuda=11.0, GTX 1080Ti, so it could just be the 1.12.1 update (or something else).
Also it doesn't seem that disabling jit makes the code much slower.
this isn't necessarily expected to work well in the higher dimensional case
I was wondering about this: is this primarily because of the cost of inverting, Cholesky, or something else?
Martin's point is that it's hard to estimate the D x D mass matrix from a small number of datapoints if D is large.
@feynmanliang in general this is an issue that is both about computation (extra linear algebra) and statistical efficiency (the ability to estimate a covariance matrix from finite samples).
in this particular case computing the potential energy is already O(N^3), so the additional linear algebra costs are probably moderate unless D is very large. the bigger problem is statistical estimation: even if i gave you 100 perfect posterior samples, your estimate of a 300 x 300 covariance matrix would be noisy. in practice we get approximate posterior samples during warm-up (they're also correlated with one another), and there are feedback loops between the adapted mass matrix and the samples obtained in warm-up. so getting good mass matrix estimates is expected to be hard and would generally require very long warm-up periods.
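The statistical-efficiency point is easy to illustrate numerically. The dimensions below (D = 10 vs D = 300, n = 100 samples) are my own toy numbers chosen to echo the example above: with n < D the empirical covariance of even perfect i.i.d. Gaussian samples is rank-deficient and far from the truth.

```python
# Toy illustration (hypothetical numbers): estimating a D x D covariance
# from n i.i.d. samples of N(0, I_D) degrades badly as D grows past n.
import numpy as np

rng = np.random.default_rng(0)

def cov_error(d, n):
    x = rng.standard_normal((n, d))   # n perfect samples from N(0, I_d)
    emp = np.cov(x, rowvar=False)     # empirical d x d covariance estimate
    # operator-norm error relative to the true covariance (the identity)
    return np.linalg.norm(emp - np.eye(d), ord=2)

print(cov_error(10, 100))    # modest error: n comfortably exceeds d
print(cov_error(300, 100))   # much larger error: emp is rank-deficient (n < d)
```

In the actual warm-up phase the situation is worse still, since the samples are approximate and correlated rather than perfect i.i.d. draws.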