gpytorch
Bayesian GPs with pyro (NUTS) - example notebook crashes when jit_compile=True
I'm trying to run your fully Bayesian GP example.
The notebook runs OK as-is. As you may expect, sampling is much slower when I increase the size of the training dataset. I've tried to enable JIT compilation in the Pyro NUTS sampler:
nuts_kernel = NUTS(pyro_model, adapt_step_size=True, jit_compile=True)
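For context, the pyro_model being sampled is the one from the notebook (reproduced from the traceback below; model and mll are the exact GP and its marginal log likelihood defined earlier, and the trailing return is my assumption about the elided line):

def pyro_model(x, y):
    # draw hyperparameters from the GPyTorch priors registered on the model
    model.pyro_sample_from_prior()
    output = model(x)
    # add the GP marginal log likelihood as a factor in the Pyro trace
    loss = mll.pyro_factor(output, y)
    return y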
After this change, the NUTS sampler crashes:
Warmup: 0%| | 0/300 [00:00, ?it/s]
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-4-56cc00113944> in <module>()
26 nuts_kernel = NUTS(pyro_model, adapt_step_size=True, jit_compile=True)
27 mcmc_run = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup_steps, disable_progbar=smoke_test)
---> 28 mcmc_run.run(train_x, train_y)
/usr/local/lib/python3.6/dist-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
9 def _context_wrap(context, fn, *args, **kwargs):
10 with context:
---> 11 return fn(*args, **kwargs)
12
13
/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/api.py in run(self, *args, **kwargs)
378 with optional(pyro.validation_enabled(not self.disable_validation),
379 self.disable_validation is not None):
--> 380 for x, chain_id in self.sampler.run(*args, **kwargs):
381 if num_samples[chain_id] == 0:
382 num_samples[chain_id] += 1
/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/api.py in run(self, *args, **kwargs)
167 for sample in _gen_samples(self.kernel, self.warmup_steps, self.num_samples, hook_w_logging,
168 i if self.num_chains > 1 else None,
--> 169 *args, **kwargs):
170 yield sample, i # sample, chain_id
171 self.kernel.cleanup()
/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/api.py in _gen_samples(kernel, warmup_steps, num_samples, hook, chain_id, *args, **kwargs)
109
110 def _gen_samples(kernel, warmup_steps, num_samples, hook, chain_id, *args, **kwargs):
--> 111 kernel.setup(warmup_steps, *args, **kwargs)
112 params = kernel.initial_params
113 # yield structure (key, value.shape) of params
/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/hmc.py in setup(self, warmup_steps, *args, **kwargs)
304 if self.initial_params:
305 z = {k: v.detach() for k, v in self.initial_params.items()}
--> 306 z_grads, potential_energy = potential_grad(self.potential_fn, z)
307 else:
308 z_grads, potential_energy = {}, self.potential_fn(self.initial_params)
/usr/local/lib/python3.6/dist-packages/pyro/ops/integrator.py in potential_grad(potential_fn, z)
80 return grads, z_nodes[0].new_tensor(float('nan'))
81 else:
---> 82 raise e
83
84 grads = grad(potential_energy, z_nodes)
/usr/local/lib/python3.6/dist-packages/pyro/ops/integrator.py in potential_grad(potential_fn, z)
73 node.requires_grad_(True)
74 try:
---> 75 potential_energy = potential_fn(z)
76 # deal with singular matrices
77 except RuntimeError as e:
/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/util.py in _potential_fn_jit(self, skip_jit_warnings, jit_options, params)
287 if skip_jit_warnings:
288 _pe_jit = ignore_jit_warnings()(_pe_jit)
--> 289 self._compiled_fn = torch.jit.trace(_pe_jit, vals, **jit_options)
290
291 result = self._compiled_fn(*vals)
/usr/local/lib/python3.6/dist-packages/torch/jit/__init__.py in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
978 var_lookup_fn,
979 strict,
--> 980 _force_outplace)
981
982 # Check the trace against new traces created from user-specified inputs
/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/util.py in _pe_jit(*zi)
283 def _pe_jit(*zi):
284 params = dict(zip(names, zi))
--> 285 return self._potential_fn(params)
286
287 if skip_jit_warnings:
/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/util.py in _potential_fn(self, params)
259 cond_model = poutine.condition(self.model, params_constrained)
260 model_trace = poutine.trace(cond_model).get_trace(*self.model_args,
--> 261 **self.model_kwargs)
262 log_joint = self.trace_prob_evaluator.log_prob(model_trace)
263 for name, t in self.transforms.items():
/usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
185 Calls this poutine and returns its trace instead of the function's return value.
186 """
--> 187 self(*args, **kwargs)
188 return self.msngr.get_trace()
/usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
169 exc = exc_type(u"{}\n{}".format(exc_value, shapes))
170 exc = exc.with_traceback(traceback)
--> 171 raise exc from None
172 self.msngr.trace.add_node("_RETURN", name="_RETURN", type="return", value=ret)
173 return ret
/usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
163 args=args, kwargs=kwargs)
164 try:
--> 165 ret = self.fn(*args, **kwargs)
166 except (ValueError, RuntimeError):
167 exc_type, exc_value, traceback = sys.exc_info()
/usr/local/lib/python3.6/dist-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
9 def _context_wrap(context, fn, *args, **kwargs):
10 with context:
---> 11 return fn(*args, **kwargs)
12
13
/usr/local/lib/python3.6/dist-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
9 def _context_wrap(context, fn, *args, **kwargs):
10 with context:
---> 11 return fn(*args, **kwargs)
12
13
/usr/local/lib/python3.6/dist-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
9 def _context_wrap(context, fn, *args, **kwargs):
10 with context:
---> 11 return fn(*args, **kwargs)
12
13
<ipython-input-4-56cc00113944> in pyro_model(x, y)
19
20 def pyro_model(x, y):
---> 21 model.pyro_sample_from_prior()
22 output = model(x)
23 loss = mll.pyro_factor(output, y)
/usr/local/lib/python3.6/dist-packages/gpytorch/module.py in pyro_sample_from_prior(self)
318 parameters of the model that have GPyTorch priors registered to them.
319 """
--> 320 return _pyro_sample_from_prior(module=self, memo=None, prefix="")
321
322 def local_load_samples(self, samples_dict, memo, prefix):
/usr/local/lib/python3.6/dist-packages/gpytorch/module.py in _pyro_sample_from_prior(module, memo, prefix)
427 for mname, module_ in module.named_children():
428 submodule_prefix = prefix + ("." if prefix else "") + mname
--> 429 _pyro_sample_from_prior(module=module_, memo=memo, prefix=submodule_prefix)
430
431
/usr/local/lib/python3.6/dist-packages/gpytorch/module.py in _pyro_sample_from_prior(module, memo, prefix)
421 )
422 memo.add(prior)
--> 423 prior = prior.expand(closure().shape)
424 value = pyro.sample(prefix + ("." if prefix else "") + prior_name, prior)
425 setting_closure(value)
/usr/local/lib/python3.6/dist-packages/gpytorch/module.py in closure()
226
227 def closure():
--> 228 return getattr(self, param_or_closure)
229
230 if setting_closure is not None:
/usr/local/lib/python3.6/dist-packages/gpytorch/likelihoods/gaussian_likelihood.py in noise(self)
83 @property
84 def noise(self) -> Tensor:
---> 85 return self.noise_covar.noise
86
87 @noise.setter
/usr/local/lib/python3.6/dist-packages/gpytorch/likelihoods/noise_models.py in noise(self)
33 @property
34 def noise(self):
---> 35 return self.raw_noise_constraint.transform(self.raw_noise)
36
37 @noise.setter
/usr/local/lib/python3.6/dist-packages/gpytorch/constraints/constraints.py in transform(self, tensor)
174
175 def transform(self, tensor):
--> 176 transformed_tensor = self._transform(tensor) if self.enforced else tensor
177 return transformed_tensor
178
RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient
Tensor:
-1.2059
[ torch.FloatTensor{1} ]
Trace Shapes:
Param Sites:
Sample Sites:
I've done some googling and found https://github.com/pyro-ppl/pyro/issues/2292, which seems to indicate that I failed to properly register a prior, perhaps for the noise_covar.noise of my Gaussian likelihood. Is that right? In your example, I do see a noise prior being registered, namely
likelihood.register_prior("noise_prior", UniformPrior(0.05, 0.3), "noise")
If so, how do I register the missing prior? Or am I looking at this the wrong way? Thanks!
To be honest, I'm not totally sure what's going on here because I'm not familiar enough with the exact internals of Pyro.
This might have something to do with the fact that we register priors to transformed versions of the parameters, rather than the parameters directly?
> we register priors to transformed versions of the parameters, rather than the parameters directly
Hmm... it seems that the priors we register in this example are applied to the original (untransformed / constrained) versions of the parameters, right? E.g. we require the lengthscale to be positive by assigning it a uniform prior over [0.01, 0.5]. I'm guessing Pyro maps these to an unconstrained space pre-inference, right?
model.mean_module.register_prior("mean_prior", UniformPrior(-1, 1), "constant")
model.covar_module.base_kernel.register_prior("lengthscale_prior", UniformPrior(0.01, 0.5), "lengthscale")
model.covar_module.base_kernel.register_prior("period_length_prior", UniformPrior(0.05, 2.5), "period_length")
model.covar_module.register_prior("outputscale_prior", UniformPrior(1, 2), "outputscale")
likelihood.register_prior("noise_prior", UniformPrior(0.05, 0.3), "noise")
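As a sanity check on my guess about the unconstrained mapping, here's a minimal sketch using the standard torch.distributions biject_to machinery, which is what I understand Pyro's HMC/NUTS to rely on (this is my own illustration, not code from either library):

import torch
from torch.distributions import Uniform, biject_to

prior = Uniform(0.01, 0.5)
to_constrained = biject_to(prior.support)  # bijection from R onto (0.01, 0.5)
z = torch.tensor(0.0)                      # unconstrained latent that HMC actually moves
theta = to_constrained(z)                  # back on the prior's support, here ~0.255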
Lengthscale, period_length, outputscale, and noise are all "transformed" parameters in GPyTorch. We store an unconstrained version of each parameter (e.g. referred to as raw_lengthscale) and then transform it (usually via the softplus function) to the positive-valued parameter (e.g. referred to as lengthscale).
The priors are in some sense an additional constraint, but they are applied to the transformed parameter (e.g. lengthscale) rather than the untransformed parameter (e.g. raw_lengthscale).
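Concretely, the raw/constrained relationship looks like this (a minimal sketch with a stock RBF kernel; the default Positive() constraint transforms via softplus):

import torch
import gpytorch

kernel = gpytorch.kernels.RBFKernel()
print(kernel.raw_lengthscale)             # the unconstrained nn.Parameter we actually store
print(kernel.raw_lengthscale_constraint)  # Positive(), i.e. softplus by default
print(kernel.lengthscale)                 # the constrained value exposed as a property

# lengthscale is just the constraint transform applied to raw_lengthscale
assert torch.allclose(
    torch.nn.functional.softplus(kernel.raw_lengthscale), kernel.lengthscale
)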
@neerajprad @fehiepsi any idea what might be going on here? Should be easy to reproduce (?) using the notebook linked in the issue.
@mihai-spire incidentally, if you're interested in speed you might try using JAX + NumPyro, although there's no built-in support for GPs so it won't be quite as plug-and-play.
The issue is that somewhere during JIT tracing, we are inserting a tensor (most likely a constant) with requires_grad=True, whereas torch.jit expects all such tensors to be arguments instead. As @mihai-spire pointed out in the Pyro issue, this can happen due to a code path that caches tensors with requires_grad=True on the first invocation and inserts them later during tracing. Maybe there is some caching that happens within the transforms, or even earlier (the backtrace should provide some clue)? @jacobrgardner can probably speak to that. If that is indeed the issue, one solution would be to provide an option that disables caching during JIT tracing.
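For reference, a minimal standalone reproduction of this failure mode (hypothetical code, not taken from GPyTorch or Pyro; it just shows how torch.jit.trace reacts to a closed-over tensor with requires_grad=True):

import torch

# a tensor cached outside the traced function on some earlier invocation
cached = torch.tensor([-1.2059], requires_grad=True)

def potential(z):
    # `cached` is not an input to the traced function, so the tracer must
    # insert it as a constant, which is disallowed for grad-requiring tensors
    return (z * cached).sum()

# RuntimeError: Cannot insert a Tensor that requires grad as a constant.
# Consider making it a parameter or input, or detaching the gradient
torch.jit.trace(potential, torch.randn(1))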
Hi, I have encountered the same issue. Has anyone figured out a way around it?
+1. It would be great to get some more clarity on the issue here. Using JIT should greatly speed things up.
This issue is basically superseded by #1578, as the bugs are caused by the same problem, and a fix to #1578 will resolve this issue as well. I've confirmed that jit_compile works with at least the fix I have so far.
I get the following error when I run the notebook with jit_compile=True:
~/opt/anaconda3/envs/dev/lib/python3.7/site-packages/pyro/infer/mcmc/util.py in _potential_fn_jit(self, skip_jit_warnings, jit_options, params)
292
293 if self._compiled_fn:
--> 294 return self._compiled_fn(*vals)
295
296 with pyro.validation_enabled(False):
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: Unsupported value kind: Tensor
The example works with jit_compile=False.