pyro

PyTorch 1.10 throws new JIT errors

Open · fritzo opened this issue 4 years ago · 1 comment

This issue tracks new errors raised by the PyTorch 1.10 JIT, e.g.:

pytest tests/infer/test_jit.py::test_dirichlet_bernoulli -k Jit -vx --runxfail
__ test_dirichlet_bernoulli[JitTraceEnum_ELBO-False] __

Elbo = <class 'pyro.infer.traceenum_elbo.JitTraceEnum_ELBO'>, vectorized = False

    @pytest.mark.parametrize("vectorized", [False, True])
    @pytest.mark.parametrize(
        "Elbo",
        [
            TraceEnum_ELBO,
            JitTraceEnum_ELBO,
        ],
    )
    def test_dirichlet_bernoulli(Elbo, vectorized):
        pyro.clear_param_store()
        data = torch.tensor([1.0] * 6 + [0.0] * 4)

        def model1(data):
            concentration0 = constant([10.0, 10.0])
            f = pyro.sample("latent_fairness", dist.Dirichlet(concentration0))[1]
            for i in pyro.plate("plate", len(data)):
                pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

        def model2(data):
            concentration0 = constant([10.0, 10.0])
            f = pyro.sample("latent_fairness", dist.Dirichlet(concentration0))[1]
            pyro.sample(
                "obs", dist.Bernoulli(f).expand_by(data.shape).to_event(1), obs=data
            )

        model = model2 if vectorized else model1

        def guide(data):
            concentration_q = pyro.param(
                "concentration_q", constant([15.0, 15.0]), constraint=constraints.positive
            )
            pyro.sample("latent_fairness", dist.Dirichlet(concentration_q))

        elbo = Elbo(
            num_particles=7, strict_enumeration_warning=False, ignore_jit_warnings=True
        )
        optim = Adam({"lr": 0.0005, "betas": (0.90, 0.999)})
        svi = SVI(model, guide, optim, elbo)
        for step in range(40):
>           svi.step(data)

tests/infer/test_jit.py:462:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
pyro/infer/svi.py:145: in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
pyro/infer/traceenum_elbo.py:564: in loss_and_grads
    differentiable_loss = self.differentiable_loss(model, guide, *args, **kwargs)
pyro/infer/traceenum_elbo.py:561: in differentiable_loss
    return self._differentiable_loss(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <pyro.ops.jit.CompiledFunction object at 0x7ff5225e9400>
args = (tensor([1., 1., 1., 1., 1., 1., 0., 0., 0., 0.]),)
kwargs = {'_guide_id': 140690819139104, '_model_id': 140690812334288}
key = (1, (('_guide_id', 140690819139104), ('_model_id', 140690812334288)))
unconstrained_params = [tensor([2.7072, 2.7090], requires_grad=True)]
params_and_args = [tensor([2.7072, 2.7090], requires_grad=True), tensor([1., 1., 1., 1., 1., 1., 0., 0., 0., 0.])]
param_capture = <pyro.poutine.trace_messenger.TraceMessenger object at 0x7ff5225c4898>

    def __call__(self, *args, **kwargs):
        key = _hashable_args_kwargs(args, kwargs)

        # if first time
        if key not in self.compiled:
            # param capture
            with poutine.block():
                with poutine.trace(param_only=True) as first_param_capture:
                    self.fn(*args, **kwargs)

            self._param_names = list(set(first_param_capture.trace.nodes.keys()))
            unconstrained_params = tuple(
                pyro.param(name).unconstrained() for name in self._param_names
            )
            params_and_args = unconstrained_params + args
            weakself = weakref.ref(self)

            def compiled(*params_and_args):
                self = weakself()
                unconstrained_params = params_and_args[: len(self._param_names)]
                args = params_and_args[len(self._param_names) :]
                constrained_params = {}
                for name, unconstrained_param in zip(
                    self._param_names, unconstrained_params
                ):
                    constrained_param = pyro.param(
                        name
                    )  # assume param has been initialized
                    assert constrained_param.unconstrained() is unconstrained_param
                    constrained_params[name] = constrained_param
                return poutine.replay(self.fn, params=constrained_params)(
                    *args, **kwargs
                )

            if self.ignore_warnings:
                compiled = ignore_jit_warnings()(compiled)
            with pyro.validation_enabled(False):
                time_compilation = self.jit_options.pop("time_compilation", False)
                with optional(timed(), time_compilation) as t:
                    self.compiled[key] = torch.jit.trace(
                        compiled, params_and_args, **self.jit_options
                    )
                if time_compilation:
                    self.compile_time = t.elapsed
        else:
            unconstrained_params = [
                # FIXME this does unnecessary transform work
                pyro.param(name).unconstrained() for name in self._param_names
            ]
            params_and_args = unconstrained_params + list(args)

        with poutine.block(hide=self._param_names):
            with poutine.trace(param_only=True) as param_capture:
>               ret = self.compiled[key](*params_and_args)
E               RuntimeError: The following operation failed in the TorchScript interpreter.
E               Traceback of TorchScript (most recent call last):
E               RuntimeError: Unsupported value kind: Tensor

pyro/ops/jit.py:121: RuntimeError

It looks like an inserted constant tensor fails TorchScript's `insertableTensor` check because it requires grad. I've spent a couple of hours debugging but haven't been able to isolate the error.

— fritzo, Nov 09 '21 15:11

I can reproduce this with a completely different model, and it only happens on PyTorch 1.10.

For me, it also only happens when I call `torch.jit.freeze` on the model.

— Linux-cpp-lisp, Nov 18 '21 20:11