pyro
pyro copied to clipboard
PyTorch 1.10 throws new JIT errors
This issue tracks new errors raised by the PyTorch 1.10 JIT, for example:
pytest tests/infer/test_jit.py::test_dirichlet_bernoulli -k Jit -vx --runxfail
__ test_dirichlet_bernoulli[JitTraceEnum_ELBO-False] __
Elbo = <class 'pyro.infer.traceenum_elbo.JitTraceEnum_ELBO'>, vectorized = False
@pytest.mark.parametrize("vectorized", [False, True])
@pytest.mark.parametrize(
"Elbo",
[
TraceEnum_ELBO,
JitTraceEnum_ELBO,
],
)
def test_dirichlet_bernoulli(Elbo, vectorized):
pyro.clear_param_store()
data = torch.tensor([1.0] * 6 + [0.0] * 4)
def model1(data):
concentration0 = constant([10.0, 10.0])
f = pyro.sample("latent_fairness", dist.Dirichlet(concentration0))[1]
for i in pyro.plate("plate", len(data)):
pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])
def model2(data):
concentration0 = constant([10.0, 10.0])
f = pyro.sample("latent_fairness", dist.Dirichlet(concentration0))[1]
pyro.sample(
"obs", dist.Bernoulli(f).expand_by(data.shape).to_event(1), obs=data
)
model = model2 if vectorized else model1
def guide(data):
concentration_q = pyro.param(
"concentration_q", constant([15.0, 15.0]), constraint=constraints.positive
)
pyro.sample("latent_fairness", dist.Dirichlet(concentration_q))
elbo = Elbo(
num_particles=7, strict_enumeration_warning=False, ignore_jit_warnings=True
)
optim = Adam({"lr": 0.0005, "betas": (0.90, 0.999)})
svi = SVI(model, guide, optim, elbo)
for step in range(40):
> svi.step(data)
tests/infer/test_jit.py:462:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
pyro/infer/svi.py:145: in step
loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
pyro/infer/traceenum_elbo.py:564: in loss_and_grads
differentiable_loss = self.differentiable_loss(model, guide, *args, **kwargs)
pyro/infer/traceenum_elbo.py:561: in differentiable_loss
return self._differentiable_loss(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <pyro.ops.jit.CompiledFunction object at 0x7ff5225e9400>
args = (tensor([1., 1., 1., 1., 1., 1., 0., 0., 0., 0.]),)
kwargs = {'_guide_id': 140690819139104, '_model_id': 140690812334288}
key = (1, (('_guide_id', 140690819139104), ('_model_id', 140690812334288)))
unconstrained_params = [tensor([2.7072, 2.7090], requires_grad=True)]
params_and_args = [tensor([2.7072, 2.7090], requires_grad=True), tensor([1., 1., 1., 1., 1., 1., 0., 0., 0., 0.])]
param_capture = <pyro.poutine.trace_messenger.TraceMessenger object at 0x7ff5225c4898>
def __call__(self, *args, **kwargs):
key = _hashable_args_kwargs(args, kwargs)
# if first time
if key not in self.compiled:
# param capture
with poutine.block():
with poutine.trace(param_only=True) as first_param_capture:
self.fn(*args, **kwargs)
self._param_names = list(set(first_param_capture.trace.nodes.keys()))
unconstrained_params = tuple(
pyro.param(name).unconstrained() for name in self._param_names
)
params_and_args = unconstrained_params + args
weakself = weakref.ref(self)
def compiled(*params_and_args):
self = weakself()
unconstrained_params = params_and_args[: len(self._param_names)]
args = params_and_args[len(self._param_names) :]
constrained_params = {}
for name, unconstrained_param in zip(
self._param_names, unconstrained_params
):
constrained_param = pyro.param(
name
) # assume param has been initialized
assert constrained_param.unconstrained() is unconstrained_param
constrained_params[name] = constrained_param
return poutine.replay(self.fn, params=constrained_params)(
*args, **kwargs
)
if self.ignore_warnings:
compiled = ignore_jit_warnings()(compiled)
with pyro.validation_enabled(False):
time_compilation = self.jit_options.pop("time_compilation", False)
with optional(timed(), time_compilation) as t:
self.compiled[key] = torch.jit.trace(
compiled, params_and_args, **self.jit_options
)
if time_compilation:
self.compile_time = t.elapsed
else:
unconstrained_params = [
# FIXME this does unnecessary transform work
pyro.param(name).unconstrained() for name in self._param_names
]
params_and_args = unconstrained_params + list(args)
with poutine.block(hide=self._param_names):
with poutine.trace(param_only=True) as param_capture:
> ret = self.compiled[key](*params_and_args)
E RuntimeError: The following operation failed in the TorchScript interpreter.
E Traceback of TorchScript (most recent call last):
E RuntimeError: Unsupported value kind: Tensor
pyro/ops/jit.py:121: RuntimeError
It looks like an inserted constant tensor fails the `insertableTensor` check because it requires grad. I've spent a couple of hours debugging but haven't been able to isolate the error.
I can reproduce this with a completely different model, and it only happens on PyTorch 1.10.
For me, it also only happens when I call `torch.jit.freeze` on the model.