
[Bug] JIT compile for ExactGP fails with TracingCheckError for 500+ observations

Status: Open. SebastianCallh opened this issue 3 years ago • 8 comments

🐛 Bug

Running the example in TorchScript_Exact_Models.ipynb with more than ~500 observations causes JIT tracing (torch.jit.trace) to fail with torch.jit._trace.TracingCheckError: Tracing failed sanity checks! ERROR: Graphs differed across invocations!

To reproduce

Run the code in TorchScript_Exact_Models.ipynb, but change train_x = torch.linspace(0, 1, 100) to train_x = torch.linspace(0, 1, 500). Alternatively, run the following smaller example, which keeps 100 training points but traces the model on 500 test points.

import torch
import gpytorch


class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


train_x = torch.linspace(0, 1, 100)
train_y = torch.sin(train_x)
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPModel(train_x, train_y, likelihood)

model.train()
likelihood.train()


class MeanVarModelWrapper(torch.nn.Module):
    def __init__(self, gp):
        super().__init__()
        self.gp = gp

    def forward(self, x):
        output_dist = self.gp(x)
        return output_dist.mean, output_dist.variance


with torch.no_grad(), gpytorch.settings.fast_pred_var(), gpytorch.settings.trace_mode():
    model.eval()
    test_x = torch.linspace(0, 1, 500)
    pred = model(test_x)  # Do precomputation
    traced_model = torch.jit.trace(MeanVarModelWrapper(model), test_x)
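Which code path prediction takes appears to depend on the size of the joint train/test covariance. The trace below concatenates the 100 training inputs with the test inputs (the aten::cat at exact_gp.py:302), and the TracerWarning at exact_prediction_strategies.py:253 shows a branch on joint_covar.size(-1) <= settings.max_eager_kernel_size.value(). The plain-Python sketch below is a simplified model of that branch; joint_covar_size and uses_eager_path are hypothetical helpers, not GPyTorch API, and the threshold is passed in explicitly because its default value is not stated in this report.

```python
def joint_covar_size(n_train: int, n_test: int) -> int:
    # At prediction time the train and test inputs are concatenated along
    # dim -2 (the aten::cat in the trace), so the joint covariance is
    # (n_train + n_test) x (n_train + n_test).
    return n_train + n_test


def uses_eager_path(n_train: int, n_test: int, max_eager_kernel_size: int) -> bool:
    # Hypothetical mirror of the branch logged at exact_prediction_strategies.py:253:
    # joint covariances at or below the threshold are evaluated eagerly,
    # larger ones stay lazy. The threshold is an explicit argument because the
    # real default (gpytorch.settings.max_eager_kernel_size) is not given here.
    return joint_covar_size(n_train, n_test) <= max_eager_kernel_size
```

With an illustrative threshold of 512, uses_eager_path(100, 100, 512) is True while uses_eager_path(100, 500, 512) is False (600 > 512), which would be consistent with the failure only appearing once the combined number of points passes roughly 500.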

Stack trace/error message
/home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/likelihoods/likelihood.py:321: RuntimeWarning: name_prefix is only used for likehoods that are integrated with Pyro.
 warnings.warn("name_prefix is only used for likehoods that are integrated with Pyro.", RuntimeWarning)
/home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/likelihoods/likelihood.py:312: RuntimeWarning: num_data is only used for likehoods that are integrated with Pyro.
 warnings.warn("num_data is only used for likehoods that are integrated with Pyro.", RuntimeWarning)
/home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/models/exact_prediction_strategies.py:253: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
 if joint_covar.size(-1) <= settings.max_eager_kernel_size.value():
/home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/utils/broadcasting.py:43: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
 if n != shape_b[-2]:
/home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:371: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
 if not x1_.size(-1) == x2_.size(-1):
/home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:303: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
 postprocess = torch.tensor(postprocess)
/home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:50: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
 return self._postprocess(res) if postprocess else res
/home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:289: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
 if res.shape != self.shape:
/home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/utils/broadcasting.py:16: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
 non_singleton_sizes = tuple(size for size in size_by_dim if size != 1)
/home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/utils/broadcasting.py:18: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
 if any(size != non_singleton_sizes[0] for size in non_singleton_sizes):
/home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:315: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
 if postprocess:
/home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:262: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
 if res.shape != expected_shape:
/home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/distributions/multivariate_normal.py:258: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
 if variance.lt(min_variance).any():
Traceback (most recent call last):
 File "/home/sebastian/nv/ml-cycle-life/modelling/scripts/gp_trace_example.py", line 40, in <module>
   traced_model = torch.jit.trace(MeanVarModelWrapper(model), test_x)
 File "/home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/torch/jit/_trace.py", line 741, in trace
   return trace_module(
 File "/home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/torch/jit/_trace.py", line 983, in trace_module
   _check_trace(
 File "/home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
   return func(*args, **kwargs)
 File "/home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/torch/jit/_trace.py", line 526, in _check_trace
   raise TracingCheckError(*diag_info)
torch.jit._trace.TracingCheckError: Tracing failed sanity checks!
ERROR: Graphs differed across invocations!
       Graph diff:
                 graph(%self.1 : __torch__.MeanVarModelWrapper,
                       %x : Tensor):
                   %gp : __torch__.ExactGPModel = prim::GetAttr[name="gp"](%self.1)
                   %covar_module : __torch__.gpytorch.kernels.scale_kernel.ScaleKernel = prim::GetAttr[name="covar_module"](%gp)
                   %raw_outputscale : Tensor = prim::GetAttr[name="raw_outputscale"](%covar_module)
                   %gp.5 : __torch__.ExactGPModel = prim::GetAttr[name="gp"](%self.1)
                   %covar_module.3 : __torch__.gpytorch.kernels.scale_kernel.ScaleKernel = prim::GetAttr[name="covar_module"](%gp.5)
                   %base_kernel : __torch__.gpytorch.kernels.rbf_kernel.RBFKernel = prim::GetAttr[name="base_kernel"](%covar_module.3)
                   %raw_lengthscale : Tensor = prim::GetAttr[name="raw_lengthscale"](%base_kernel)
                   %gp.3 : __torch__.ExactGPModel = prim::GetAttr[name="gp"](%self.1)
                   %covar_module.1 : __torch__.gpytorch.kernels.scale_kernel.ScaleKernel = prim::GetAttr[name="covar_module"](%gp.3)
                   %raw_outputscale_constraint : __torch__.gpytorch.constraints.constraints.Positive = prim::GetAttr[name="raw_outputscale_constraint"](%covar_module.1)
                   %_transform : __torch__.torch.nn.modules.activation.Softplus = prim::GetAttr[name="_transform"](%raw_outputscale_constraint)
                   %gp.1 : __torch__.ExactGPModel = prim::GetAttr[name="gp"](%self.1)
                   %mean_module : __torch__.gpytorch.means.constant_mean.ConstantMean = prim::GetAttr[name="mean_module"](%gp.1)
                   %constant : Tensor = prim::GetAttr[name="constant"](%mean_module)
                   %16 : int = prim::Constant[value=-1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/models/exact_gp.py:245:0
                   %input.1 : Tensor = aten::unsqueeze(%x, %16) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/models/exact_gp.py:245:0
                   %train_input : Tensor = prim::Constant[value=<Tensor>]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/models/exact_gp.py:292:0
                   %19 : Tensor[] = prim::ListConstruct(%train_input, %input.1)
                   %20 : int = prim::Constant[value=-2]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/models/exact_gp.py:302:0
                   %input.3 : Tensor = aten::cat(%19, %20) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/models/exact_gp.py:302:0
                   %22 : int = prim::Constant[value=0]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/means/constant_mean.py:19:0
                   %23 : int = aten::size(%input.3, %22) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/means/constant_mean.py:19:0
                   %24 : Tensor = prim::NumToTensor(%23)
                   %25 : int = aten::Int(%24)
                   %26 : int[] = prim::ListConstruct(%25)
                   %27 : bool = prim::Constant[value=0]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/means/constant_mean.py:19:0
                   %mean.1 : Tensor = aten::expand(%constant, %26, %27) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/means/constant_mean.py:19:0
                   %29 : int = prim::Constant[value=0]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/distributions/multivariate_normal.py:46:0
                   %30 : int = aten::size(%mean.1, %29) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/distributions/multivariate_normal.py:46:0
                   %31 : Tensor = prim::NumToTensor(%30)
                   %32 : Tensor = prim::Constant[value={100}]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/models/exact_gp.py:315:0
                   %33 : int = prim::Constant[value=1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/models/exact_gp.py:315:0
                   %34 : Tensor = aten::sub(%31, %32, %33) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/models/exact_gp.py:315:0
                   %35 : int = aten::Int(%34)
                   %36 : int = prim::Constant[value=0]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/models/exact_prediction_strategies.py:251:0
                   %37 : int = prim::Constant[value=100]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/models/exact_prediction_strategies.py:251:0
                   %38 : int = prim::Constant[value=9223372036854775807]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/models/exact_prediction_strategies.py:251:0
                   %39 : int = prim::Constant[value=1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/models/exact_prediction_strategies.py:251:0
                   %test_mean : Tensor = aten::slice(%mean.1, %36, %37, %38, %39) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/models/exact_prediction_strategies.py:251:0
                   %41 : int = prim::Constant[value=0]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:98:0
                   %42 : int = prim::Constant[value=100]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:98:0
                   %43 : int = prim::Constant[value=9223372036854775807]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:98:0
                   %44 : int = prim::Constant[value=1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:98:0
                   %45 : Tensor = aten::slice(%input.3, %41, %42, %43, %44) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:98:0
                   %46 : int = prim::Constant[value=1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:98:0
                   %47 : int = prim::Constant[value=0]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:98:0
                   %48 : int = prim::Constant[value=9223372036854775807]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:98:0
                   %49 : int = prim::Constant[value=1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:98:0
                   %x1.7 : Tensor = aten::slice(%45, %46, %47, %48, %49) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:98:0
                   %51 : int = prim::Constant[value=0]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:98:0
                   %52 : int = prim::Constant[value=100]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:98:0
                   %53 : int = prim::Constant[value=9223372036854775807]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:98:0
                   %54 : int = prim::Constant[value=1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:98:0
                   %55 : Tensor = aten::slice(%input.3, %51, %52, %53, %54) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:98:0
                   %56 : int = prim::Constant[value=1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:98:0
                   %57 : int = prim::Constant[value=0]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:98:0
                   %58 : int = prim::Constant[value=9223372036854775807]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:98:0
                   %59 : int = prim::Constant[value=1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:98:0
                   %x1.1 : Tensor = aten::slice(%55, %56, %57, %58, %59) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:98:0
                   %61 : int = prim::Constant[value=0]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:112:0
                   %62 : int = prim::Constant[value=0]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:112:0
                   %63 : int = prim::Constant[value=100]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:112:0
                   %64 : int = prim::Constant[value=1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:112:0
                   %65 : Tensor = aten::slice(%input.3, %61, %62, %63, %64) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:112:0
                   %66 : int = prim::Constant[value=1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:112:0
                   %67 : int = prim::Constant[value=0]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:112:0
                   %68 : int = prim::Constant[value=9223372036854775807]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:112:0
                   %69 : int = prim::Constant[value=1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:112:0
                   %x2.1 : Tensor = aten::slice(%65, %66, %67, %68, %69) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:112:0
                   %71 : Tensor = prim::Constant[value=<Tensor>]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/models/exact_prediction_strategies.py:280:0
                   %72 : int = prim::Constant[value=-1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/models/exact_prediction_strategies.py:280:0
                   %other.1 : Tensor = aten::unsqueeze(%71, %72) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/models/exact_prediction_strategies.py:280:0
                   %211 : int = prim::Constant[value=20](), scope: __module.gp.covar_module.raw_outputscale_constraint._transform # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/torch/nn/modules/activation.py:806:0
                   %212 : int = prim::Constant[value=1](), scope: __module.gp.covar_module.raw_outputscale_constraint._transform # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/torch/nn/modules/activation.py:806:0
                   %213 : Tensor = aten::softplus(%raw_lengthscale, %212, %211), scope: __module.gp.covar_module.raw_outputscale_constraint._transform # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/torch/nn/modules/activation.py:806:0
                   %x1.3 : Tensor = aten::div(%x1.1, %213) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/rbf_kernel.py:81:0
                   %214 : int = prim::Constant[value=20](), scope: __module.gp.covar_module.raw_outputscale_constraint._transform # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/torch/nn/modules/activation.py:806:0
                   %215 : int = prim::Constant[value=1](), scope: __module.gp.covar_module.raw_outputscale_constraint._transform # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/torch/nn/modules/activation.py:806:0
                   %216 : Tensor = aten::softplus(%raw_lengthscale, %215, %214), scope: __module.gp.covar_module.raw_outputscale_constraint._transform # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/torch/nn/modules/activation.py:806:0
                   %x2.3 : Tensor = aten::div(%x2.1, %216) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/rbf_kernel.py:82:0
                   %78 : int = prim::Constant[value=-2]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:29:0
                   %79 : int[] = prim::ListConstruct(%78)
                   %80 : bool = prim::Constant[value=1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:29:0
                   %81 : NoneType = prim::Constant()
                   %adjustment : Tensor = aten::mean(%x1.3, %79, %80, %81) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:29:0
                   %83 : int = prim::Constant[value=1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:30:0
                   %x1.5 : Tensor = aten::sub(%x1.3, %adjustment, %83) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:30:0
                   %85 : int = prim::Constant[value=1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:31:0
                   %x2.5 : Tensor = aten::sub(%x2.3, %adjustment, %85) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:31:0
                   %87 : int = prim::Constant[value=2]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:34:0
                   %88 : Tensor = aten::pow(%x1.5, %87) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:34:0
                   %89 : int = prim::Constant[value=-1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:34:0
                   %90 : int[] = prim::ListConstruct(%89)
                   %91 : bool = prim::Constant[value=1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:34:0
                   %92 : NoneType = prim::Constant()
                   %x1_norm : Tensor = aten::sum(%88, %90, %91, %92) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:34:0
                   %94 : int = prim::Constant[value=6]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:35:0
                   %95 : int = prim::Constant[value=0]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:35:0
                   %96 : Device = prim::Constant[value="cpu"]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:35:0
                   %97 : bool = prim::Constant[value=0]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:35:0
                   %98 : NoneType = prim::Constant()
                   %x1_pad : Tensor = aten::ones_like(%x1_norm, %94, %95, %96, %97, %98) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:35:0
                   %100 : int = prim::Constant[value=2]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:39:0
                   %101 : Tensor = aten::pow(%x2.5, %100) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:39:0
                   %102 : int = prim::Constant[value=-1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:39:0
                   %103 : int[] = prim::ListConstruct(%102)
                   %104 : bool = prim::Constant[value=1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:39:0
                   %105 : NoneType = prim::Constant()
                   %x2_norm : Tensor = aten::sum(%101, %103, %104, %105) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:39:0
                   %107 : int = prim::Constant[value=6]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:40:0
                   %108 : int = prim::Constant[value=0]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:40:0
                   %109 : Device = prim::Constant[value="cpu"]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:40:0
                   %110 : bool = prim::Constant[value=0]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:40:0
                   %111 : NoneType = prim::Constant()
                   %x2_pad : Tensor = aten::ones_like(%x2_norm, %107, %108, %109, %110, %111) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:40:0
                   %113 : Tensor = prim::Constant[value={-2}]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:41:0
                   %114 : Tensor = aten::mul(%x1.5, %113) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:41:0
                   %115 : Tensor[] = prim::ListConstruct(%114, %x1_norm, %x1_pad)
                   %116 : int = prim::Constant[value=-1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:41:0
                   %x1_ : Tensor = aten::cat(%115, %116) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:41:0
                   %118 : Tensor[] = prim::ListConstruct(%x2.5, %x2_pad, %x2_norm)
                   %119 : int = prim::Constant[value=-1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:42:0
                   %x2_ : Tensor = aten::cat(%118, %119) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:42:0
                   %121 : int = prim::Constant[value=-2]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:43:0
                   %122 : int = prim::Constant[value=-1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:43:0
                   %123 : Tensor = aten::transpose(%x2_, %121, %122) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:43:0
                   %res.1 : Tensor = aten::matmul(%x1_, %123) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:43:0
                   %125 : int = prim::Constant[value=0]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:49:0
                   %dist_mat.1 : Tensor = aten::clamp_min_(%res.1, %125) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:49:0
                   %127 : Tensor = prim::Constant[value={-2}]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/rbf_kernel.py:9:0
                   %dist_mat.3 : Tensor = aten::div_(%dist_mat.1, %127) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/rbf_kernel.py:9:0
                   %orig_output.1 : Tensor = aten::exp_(%dist_mat.3) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/rbf_kernel.py:9:0
                   %217 : int = prim::Constant[value=20](), scope: __module.gp.covar_module.raw_outputscale_constraint._transform # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/torch/nn/modules/activation.py:806:0
                   %218 : int = prim::Constant[value=1](), scope: __module.gp.covar_module.raw_outputscale_constraint._transform # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/torch/nn/modules/activation.py:806:0
                   %outputscales.1 : Tensor = aten::softplus(%raw_outputscale, %218, %217), scope: __module.gp.covar_module.raw_outputscale_constraint._transform # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/torch/nn/modules/activation.py:806:0
                   %131 : int = prim::Constant[value=1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/scale_kernel.py:100:0
                   %132 : int = prim::Constant[value=1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/scale_kernel.py:100:0
                   %133 : int[] = prim::ListConstruct(%131, %132)
                   %outputscales.3 : Tensor = aten::view(%outputscales.1, %133) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/scale_kernel.py:100:0
                   %135 : Tensor = aten::mul(%orig_output.1, %outputscales.3) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/scale_kernel.py:101:0
               -   %136 : Tensor = ^Matmul(<gpytorch.lazy.lazy_tensor_representation_tree.LazyTensorRepresentationTree object at 0x7f0e23f9d790>)(%other.1, %135) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_tensor.py:1337:0
               ?                                                                                                                        ^^^^^^
               +   %136 : Tensor = ^Matmul(<gpytorch.lazy.lazy_tensor_representation_tree.LazyTensorRepresentationTree object at 0x7f0e1c2b0610>)(%other.1, %135) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_tensor.py:1337:0
               ?                                                                                                                       ++ ^^^^
                   %137 : int = prim::Constant[value=-1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/models/exact_prediction_strategies.py:280:0
                   %res.3 : Tensor = aten::squeeze(%136, %137) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/models/exact_prediction_strategies.py:280:0
                   %139 : int = prim::Constant[value=1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/models/exact_prediction_strategies.py:281:0
                   %predictive_mean : Tensor = aten::add(%res.3, %test_mean, %139) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/models/exact_prediction_strategies.py:281:0
                   %other : Tensor = prim::Constant[value=<Tensor>]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_tensor.py:1330:0
               -   %covar_inv_quad_form_root : Tensor = ^Matmul(<gpytorch.lazy.lazy_tensor_representation_tree.LazyTensorRepresentationTree object at 0x7f0e23f9d310>)(%other, %135) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_tensor.py:1337:0
               ?                                                                                                                                                 ^^
               +   %covar_inv_quad_form_root : Tensor = ^Matmul(<gpytorch.lazy.lazy_tensor_representation_tree.LazyTensorRepresentationTree object at 0x7f0e23f9d850>)(%other, %135) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_tensor.py:1337:0
               ?                                                                                                                                                 ^^
                   %143 : int = prim::Constant[value=-1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/models/exact_prediction_strategies.py:338:0
                   %144 : int = prim::Constant[value=-2]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/models/exact_prediction_strategies.py:338:0
                   %145 : Tensor = aten::transpose(%covar_inv_quad_form_root, %143, %144) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/models/exact_prediction_strategies.py:338:0
                   %146 : Tensor = prim::Constant[value={-1}]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/models/exact_prediction_strategies.py:338:0
                   %147 : Tensor = aten::mul(%145, %146) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/models/exact_prediction_strategies.py:338:0
                   %148 : int = prim::Constant[value=-2]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:198:0
                   %149 : int = aten::size(%x1.7, %148) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:198:0
                   %150 : Tensor = prim::NumToTensor(%149)
                   %151 : Tensor = prim::Constant[value={1}]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:198:0
                   %num_rows : Tensor = aten::mul(%150, %151) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:198:0
                   %153 : int = aten::Int(%num_rows)
                   %154 : int[] = prim::ListConstruct(%35)
                   %155 : Tensor = aten::view(%predictive_mean, %154) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/models/exact_gp.py:322:0
                   %156 : int = prim::Constant[value=0]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/models/exact_gp.py:322:0
                   %mean : Tensor = aten::contiguous(%155, %156) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/models/exact_gp.py:322:0
                   %158 : int = prim::Constant[value=0]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/distributions/multivariate_normal.py:46:0
                   %159 : int = aten::size(%mean, %158) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/distributions/multivariate_normal.py:46:0
                   %160 : Tensor = prim::NumToTensor(%159)
                   %161 : int = aten::Int(%160)
                   %162 : int = aten::Int(%160)
                   %220 : int = prim::Constant[value=20](), scope: __module.gp.covar_module.raw_outputscale_constraint._transform # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/torch/nn/modules/activation.py:806:0
                   %221 : int = prim::Constant[value=1](), scope: __module.gp.covar_module.raw_outputscale_constraint._transform # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/torch/nn/modules/activation.py:806:0
                   %222 : Tensor = aten::softplus(%raw_lengthscale, %221, %220), scope: __module.gp.covar_module.raw_outputscale_constraint._transform # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/torch/nn/modules/activation.py:806:0
                   %x1 : Tensor = aten::div(%x1.7, %222) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/rbf_kernel.py:81:0
                   %223 : int = prim::Constant[value=20](), scope: __module.gp.covar_module.raw_outputscale_constraint._transform # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/torch/nn/modules/activation.py:806:0
                   %224 : int = prim::Constant[value=1](), scope: __module.gp.covar_module.raw_outputscale_constraint._transform # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/torch/nn/modules/activation.py:806:0
                   %225 : Tensor = aten::softplus(%raw_lengthscale, %224, %223), scope: __module.gp.covar_module.raw_outputscale_constraint._transform # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/torch/nn/modules/activation.py:806:0
                   %166 : int = prim::Constant[value=0]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:314:0
                   %167 : int = aten::size(%x1, %166) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:314:0
                   %168 : Tensor = prim::NumToTensor(%167)
                   %169 : int = aten::Int(%168)
                   %170 : int[] = prim::ListConstruct(%169)
                   %171 : int = prim::Constant[value=6]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:314:0
                   %172 : NoneType = prim::Constant()
                   %173 : Device = prim::Constant[value="cpu"]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:314:0
                   %174 : bool = prim::Constant[value=0]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:314:0
                   %dist_mat.5 : Tensor = aten::zeros(%170, %171, %172, %173, %174) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/kernel.py:314:0
                   %176 : Tensor = prim::Constant[value={-2}]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/rbf_kernel.py:9:0
                   %dist_mat : Tensor = aten::div_(%dist_mat.5, %176) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/rbf_kernel.py:9:0
                   %orig_output : Tensor = aten::exp_(%dist_mat) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/rbf_kernel.py:9:0
                   %226 : int = prim::Constant[value=20](), scope: __module.gp.covar_module.raw_outputscale_constraint._transform # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/torch/nn/modules/activation.py:806:0
                   %227 : int = prim::Constant[value=1](), scope: __module.gp.covar_module.raw_outputscale_constraint._transform # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/torch/nn/modules/activation.py:806:0
                   %outputscales.5 : Tensor = aten::softplus(%raw_outputscale, %227, %226), scope: __module.gp.covar_module.raw_outputscale_constraint._transform # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/torch/nn/modules/activation.py:806:0
                   %180 : int = prim::Constant[value=-1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/scale_kernel.py:97:0
                   %outputscales : Tensor = aten::unsqueeze(%outputscales.5, %180) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/scale_kernel.py:97:0
                   %res : Tensor = aten::mul(%orig_output, %outputscales) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/kernels/scale_kernel.py:98:0
                   %183 : int[] = prim::ListConstruct(%153)
                   %184 : Tensor = aten::view(%res, %183) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:270:0
                   %185 : int = prim::Constant[value=0]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:270:0
                   %186 : Tensor = aten::contiguous(%184, %185) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:270:0
                   %187 : int = prim::Constant[value=0]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/sum_lazy_tensor.py:95:0
                   %188 : Tensor = aten::contiguous(%186, %187) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/sum_lazy_tensor.py:95:0
                   %189 : Tensor = prim::Constant[value={0}]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/sum_lazy_tensor.py:95:0
                   %190 : int = prim::Constant[value=1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/sum_lazy_tensor.py:95:0
                   %191 : Tensor = aten::add(%188, %189, %190) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/sum_lazy_tensor.py:95:0
                   %192 : int = prim::Constant[value=-1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/matmul_lazy_tensor.py:108:0
                   %193 : int = prim::Constant[value=-2]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/matmul_lazy_tensor.py:108:0
                   %194 : Tensor = aten::transpose(%147, %192, %193) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/matmul_lazy_tensor.py:108:0
                   %195 : Tensor = aten::mul(%covar_inv_quad_form_root, %194) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/matmul_lazy_tensor.py:108:0
                   %196 : int = prim::Constant[value=-1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/matmul_lazy_tensor.py:108:0
                   %197 : int[] = prim::ListConstruct(%196)
                   %198 : bool = prim::Constant[value=0]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/matmul_lazy_tensor.py:108:0
                   %199 : NoneType = prim::Constant()
                   %200 : Tensor = aten::sum(%195, %197, %198, %199) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/matmul_lazy_tensor.py:108:0
                   %201 : int = prim::Constant[value=0]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/sum_lazy_tensor.py:95:0
                   %202 : Tensor = aten::contiguous(%200, %201) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/sum_lazy_tensor.py:95:0
                   %203 : int = prim::Constant[value=1]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/sum_lazy_tensor.py:95:0
                   %diag.1 : Tensor = aten::add(%191, %202, %203) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/lazy/sum_lazy_tensor.py:95:0
                   %205 : int[] = prim::ListConstruct(%162)
                   %diag : Tensor = aten::view(%diag.1, %205) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/distributions/multivariate_normal.py:250:0
                   %207 : int[] = prim::ListConstruct(%161)
                   %208 : bool = prim::Constant[value=0]() # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/distributions/multivariate_normal.py:251:0
                   %variance : Tensor = aten::expand(%diag, %207, %208) # /home/sebastian/.virtualenvs/gpytorch-debug/lib/python3.9/site-packages/gpytorch/distributions/multivariate_normal.py:251:0
                   %210 : (Tensor, Tensor) = prim::TupleConstruct(%mean, %variance)
                   return (%210)
       First diverging operator:
       Node diff:
               - %gp : __torch__.ExactGPModel = prim::GetAttr[name="gp"](%self.1)
               + %gp : __torch__.___torch_mangle_15.ExactGPModel = prim::GetAttr[name="gp"](%self.1)
               ?                +++++++++++++++++++

Expected Behavior

I expect torch.jit.trace to succeed for exact GP models with more than ~500 observations.

System information


  • GPyTorch version: 1.5.1
  • PyTorch version: 1.10.0
  • OS: Ubuntu 21.04


SebastianCallh, Oct 29 '21

@jacobrgardner any thoughts?

gpleiss, Nov 08 '21

This is probably related to the total size of train and test crossing the max_eager_kernel_size boundary: https://github.com/cornellius-gp/gpytorch/blob/ade5db8df1f8e4acbc64ec1dd8809e58506827f1/gpytorch/settings.py#L395

Basically, a while back we added an optimization: unless the kernel matrix we're constructing is larger than 512x512, we evaluate the kernel eagerly rather than lazily, to avoid the extra Python overhead of lazy tensors. Since we concatenate the train and test points, I'd expect the boundary to be exactly 512 total points (train + test); beyond that, we start evaluating kernels lazily to avoid unnecessary computation.
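To make the arithmetic concrete, here is a small pure-Python sketch of the check described above. It assumes the default threshold of 512 from the comment and mirrors a "joint size <= max_eager_kernel_size means eager" comparison; uses_lazy_kernel is a hypothetical helper for illustration, not a gpytorch function.

```python
# Hypothetical helper; 512 mirrors the default max_eager_kernel_size mentioned above
DEFAULT_MAX_EAGER_KERNEL_SIZE = 512


def uses_lazy_kernel(n_train: int, n_test: int,
                     max_eager_kernel_size: int = DEFAULT_MAX_EAGER_KERNEL_SIZE) -> bool:
    """Return True if prediction would take the lazily evaluated kernel path."""
    # Train and test inputs are concatenated, so the joint kernel is
    # (n_train + n_test) x (n_train + n_test); only its side length is compared.
    joint_size = n_train + n_test
    return joint_size > max_eager_kernel_size


for n_train, n_test in [(100, 100), (100, 412), (100, 413), (100, 500)]:
    print(f"train={n_train} test={n_test} lazy={uses_lazy_kernel(n_train, n_test)}")
```

With these numbers, 100 + 412 = 512 stays on the eager path, while 100 + 413 = 513 and the repro's 100 + 500 = 600 switch to the lazy path.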

Let me look into whether there's an easy solution.

jacobrgardner, Nov 10 '21

Hey @jacobrgardner, I'm currently facing the same issue as @SebastianCallh (with my own dataset, but I'm posting here since it's the same issue). Have either of you found a solution? Thanks.

GStechschulte, Apr 04 '22

Hi @GStechschulte, no, I'm still stuck on this. It's a real showstopper for me as well, since I can't deploy a model without JIT-compiling it. Any news, @jacobrgardner?

SebastianCallh, May 19 '22

Hey @SebastianCallh, in the coming summer months I will have time to go on a deep dive into this issue. I will keep this thread updated.

GStechschulte, May 19 '22

Hi! It's been a while. I've tested the code on newer releases (up to 1.9.1) and still get the same error, so I figured I'd check in and see whether there has been any internal progress on this.

SebastianCallh, Feb 15 '23

I also experience the very same issue.

ypuzikov, Feb 17 '23

Here are my five cents on this issue.

I think what @jacobrgardner mentioned above about the size crossing max_eager_kernel_size is exactly what's happening. The torch.jit.trace error occurs because, with check_trace=True (the default), torch traces MeanVarModelWrapper(model) a second time as a sanity check and gets a different graph. I compared the text representations of the two graphs; the difference is the memory address of the linear_operator representation-tree object we create:

_17 = ^Matmul(<linear_operator.operators.linear_operator_representation_tree.LinearOperatorRepresentationTree object at 0x000001D652BCBF40>)(other, _16)
_17 = ^Matmul(<linear_operator.operators.linear_operator_representation_tree.LinearOperatorRepresentationTree object at 0x000001E7D4B6A8F0>)(other, _16)
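Why would a memory address end up in the graph at all? The ^Matmul node receives the representation-tree object as a non-tensor argument, and the graph text prints it via its repr, which for a plain Python object includes the object's address. A pure-Python sketch of that mechanism, assuming (as the output above suggests) that the graph text embeds the repr; RepresentationTree and graph_line are hypothetical stand-ins, not gpytorch or torch APIs:

```python
class RepresentationTree:
    """Hypothetical stand-in for LinearOperatorRepresentationTree."""


def graph_line(tree: RepresentationTree) -> str:
    # Mimic a traced-graph line that prints a Python object argument via repr()
    return f"_17 = ^Matmul({tree!r})(other, _16)"


# Each model call builds a fresh tree, so the first trace and the sanity-check
# re-trace each embed a different object (both alive at the same time).
tree_first = RepresentationTree()
tree_check = RepresentationTree()

line_first = graph_line(tree_first)
line_check = graph_line(tree_check)
print(line_first)
print(line_check)
print(line_first == line_check)  # False: the object addresses differ
```

Printing the same object twice would give identical lines; it is the fresh object per trace that makes the two graph strings diverge.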

When the input size is small, on the other hand, the code uses a plain torch tensor instead of a linear_operator, and tracing works fine.
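The same failure can be sketched with plain PyTorch and no gpytorch. This is a minimal sketch under the assumption that torch.jit.trace's default check_trace=True re-traces the function and compares the graph text, as the error in this issue indicates; Tree, Matmul, dense_like and lazy_like are hypothetical names chosen for illustration. A custom autograd.Function that receives a fresh Python object on every call should fail the sanity check, while a function built only from tensor ops traces cleanly.

```python
import torch


class Tree:
    """Hypothetical stand-in for a representation-tree object (plain Python object)."""


class Matmul(torch.autograd.Function):
    """Toy autograd.Function that, like the ^Matmul node in the graph,
    receives a non-tensor Python object as its first argument."""

    @staticmethod
    def forward(ctx, tree, lhs, rhs):
        ctx.save_for_backward(lhs, rhs)
        return lhs @ rhs

    @staticmethod
    def backward(ctx, grad_output):
        lhs, rhs = ctx.saved_tensors
        # No gradient flows to the Python object argument
        return None, grad_output @ rhs.transpose(-1, -2), lhs.transpose(-1, -2) @ grad_output


def dense_like(x):
    # Only regular tensor ops: every trace produces the same graph text
    return x @ x.transpose(-1, -2)


def lazy_like(x):
    # A fresh Tree per call: each trace embeds a different object address
    return Matmul.apply(Tree(), x, x.transpose(-1, -2))


x = torch.randn(3, 2)
traced_dense = torch.jit.trace(dense_like, x)  # passes the sanity check

try:
    torch.jit.trace(lazy_like, x)
    lazy_trace_failed = False
except torch.jit.TracingCheckError:
    lazy_trace_failed = True
print("lazy_like failed the trace check:", lazy_trace_failed)
```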

Whether a linear_operator is used or not is decided here: https://github.com/cornellius-gp/gpytorch/blob/1c743fafe9d6de7814bdbd90722020fd486f1a55/gpytorch/models/exact_prediction_strategies.py#L276

# For efficiency - we can make things more efficient
if joint_covar.size(-1) <= settings.max_eager_kernel_size.value():
    test_covar = joint_covar[..., self.num_train :, :].to_dense()
    test_test_covar = test_covar[..., self.num_train :]
    test_train_covar = test_covar[..., : self.num_train]
else:
    test_test_covar = joint_covar[..., self.num_train :, self.num_train :]
    test_train_covar = joint_covar[..., self.num_train :, : self.num_train]

If the total number of train + test points is at most settings.max_eager_kernel_size.value() (note the <=), the test covariance is materialized as a dense torch tensor via .to_dense(); otherwise, the covariance blocks remain a lazy LazyEvaluatedKernelTensor.
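A small dense-tensor sketch of the two branches, assuming the repro's 100 train and 500 test points and a plain RBF kernel with lengthscale 1 and outputscale 1 standing in for joint_covar. It shows the block shapes the slicing produces, and that the eager branch (take the test rows, then split the columns) selects the same blocks as slicing both dimensions directly, as the lazy branch does; the branches differ only in whether the result is a dense tensor or a lazy object.

```python
import torch

n_train, n_test = 100, 500
train_x = torch.linspace(0, 1, n_train)
test_x = torch.linspace(0, 1, n_test)
joint_x = torch.cat([train_x, test_x])  # train points first, then test points

# Dense RBF kernel (lengthscale 1) as a stand-in for joint_covar
sq_dist = (joint_x.unsqueeze(-1) - joint_x.unsqueeze(-2)) ** 2
joint_covar = torch.exp(-0.5 * sq_dist)  # (600, 600)

# Eager branch: take the test rows, then split the columns
test_covar = joint_covar[..., n_train:, :]     # (500, 600)
test_test_covar = test_covar[..., n_train:]    # (500, 500)
test_train_covar = test_covar[..., :n_train]   # (500, 100)

# Lazy-branch indexing, applied here to the same dense matrix
lazy_test_test = joint_covar[..., n_train:, n_train:]
lazy_test_train = joint_covar[..., n_train:, :n_train]

print(test_covar.shape, test_test_covar.shape, test_train_covar.shape)
print(torch.equal(test_test_covar, lazy_test_test), torch.equal(test_train_covar, lazy_test_train))
```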

As for possible solutions, one brute-force workaround is to raise max_eager_kernel_size to at least the total number of train and test points, so that the eager path is always taken:

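train_size, test_size = train_x.size(0), 500  # sizes from the repro above
with torch.no_grad(), gpytorch.settings.fast_pred_var(), gpytorch.settings.trace_mode(), gpytorch.settings.max_eager_kernel_size(train_size + test_size):
    model.eval()
    test_x = torch.linspace(0, 1, test_size)
    pred = model(test_x)  # Do precomputation
    traced_model = torch.jit.trace(MeanVarModelWrapper(model), test_x)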

Another, equally brute-force option is to change the exact_prediction function so that it always uses the dense torch-tensor path when gpytorch.settings.trace_mode.on() is true.
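This second workaround amounts to changing the branch condition in exact_prediction_strategies.py so that tracing always takes the dense path. Here it is as a pure-Python predicate; use_dense_test_covar is a hypothetical name, and the gpytorch-level change in the docstring is an assumption, not an upstream patch.

```python
def use_dense_test_covar(joint_size: int, max_eager_kernel_size: int, trace_mode_on: bool) -> bool:
    """Hypothetical branch condition for exact_prediction: take the dense
    (.to_dense()) path whenever trace mode is on, regardless of size.

    In gpytorch terms this would read (sketch only, not an upstream patch):
        if settings.trace_mode.on() or joint_covar.size(-1) <= settings.max_eager_kernel_size.value():
    """
    return trace_mode_on or joint_size <= max_eager_kernel_size


print(use_dense_test_covar(200, 512, trace_mode_on=False))  # small problem: dense, as today
print(use_dense_test_covar(600, 512, trace_mode_on=False))  # large problem: lazy, as today
print(use_dense_test_covar(600, 512, trace_mode_on=True))   # large problem while tracing: dense
```

The trade-off is that every traced model would build dense test covariances, however large the problem.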

But neither approach takes advantage of LazyEvaluatedKernelTensor, and I'm not sure how to make lazy tensors (and the linear_operator machinery behind them) work with torch tracing.

CY-Zhang, Aug 21 '23