
[Bug] CG terminated in 1000 iterations with average residual norm XXX which is larger than the tolerance of 1 specified by gpytorch.settings.cg_tolerance.

Songloading opened this issue 2 years ago

🐛 Bug

I am not sure whether this is a bug or just a mistake on my part. I did notice a post (issue #1129) very similar to my scenario. In short, the error pops up whenever I run validation, specifically on this line:

valid_loss = -mll(prediction_val, torch.tensor(Y_val))

I've also checked my inputs as mentioned in #1129, and I think both the train and validation sets are well-conditioned. Two things came to my mind:

  1. I've used the same loss function for both training and validation, though I don't think that should cause any error.
  2. I used sklearn's PolynomialFeatures to create interactions between inputs, which might hit the issue mentioned in #1129 (duplicate/highly similar data); a quick check is sketched right after this list.
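
One way to check point 2 is to look at the smallest pairwise distance between rows of the expanded feature matrix, since near-duplicate rows make the kernel matrix nearly singular. This is just a sketch; X_train comes from my setup, and degree=2 is only an example:

	import numpy as np
	from scipy.spatial.distance import pdist
	from sklearn.preprocessing import PolynomialFeatures

	# expand the inputs the same way as in training
	X_poly = PolynomialFeatures(degree=2).fit_transform(X_train.values)

	# a near-zero minimum distance means (almost) duplicated rows,
	# which makes the kernel matrix nearly singular
	print('smallest pairwise distance:', pdist(X_poly).min())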

To reproduce

	import numpy as np
	import torch

	# model, likelihood, mll, optimizer, opt, site, and the data splits
	# (X_train, Y_train, X_val, Y_val) are defined earlier in my script.
	min_valid_loss = np.inf

	for j in range(opt.epoch):
		model.train()
		likelihood.train()
		optimizer.zero_grad()
		output = model(torch.tensor(np.float32(X_train.values)))
		loss = -mll(output, torch.tensor(Y_train).contiguous())
		loss.backward()
		if j % 10 == 0:
			print(f'{site}: Iter {j}/{opt.epoch} training - Loss: {loss.item()}')
		optimizer.step()

		# validate during the last 10% of epochs
		if j >= int(opt.epoch * 0.9):
			with torch.no_grad():
				model.eval()
				likelihood.eval()
				prediction_val = model(torch.tensor(np.float32(X_val.values)))
				valid_loss = -mll(prediction_val, torch.tensor(Y_val))  # <- warning raised here
				if min_valid_loss > valid_loss:
					min_valid_loss = valid_loss
					## make prediction


**Stack trace/error message**

NumericalWarning: CG terminated in 1000 iterations with average residual norm 36.0455322265625 which is larger than the tolerance of 1 specified by gpytorch.settings.cg_tolerance. If performance is affected, consider raising the maximum number of CG iterations by running code in a gpytorch.settings.max_cg_iterations(value) context.
  warnings.warn(
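
As the warning itself suggests, the CG iteration cap can be raised by wrapping the call in a settings context (the value 2000 below is just an example):

	import gpytorch

	with gpytorch.settings.max_cg_iterations(2000):
		valid_loss = -mll(prediction_val, torch.tensor(Y_val))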

Expected Behavior

I expect the loss to be computed during validation without any error.

System information

GPyTorch Version - 1.6.0
PyTorch Version - 1.11.0
Computer OS - Ubuntu 18.04

Songloading avatar Jun 22 '22 15:06 Songloading

A small bug in your code: you should be calling mll on model.likelihood(model(x_data)). Your code calls mll on model(x_data), which means you are not taking observational noise into account. This is likely making everything very ill-conditioned.

Additionally, you probably should not be calling mll in eval mode (only in training mode); this is likely what is making your data horribly conditioned. Your posterior covariance matrix will likely be very ill-conditioned, especially if your posterior variance is overall very low. A better-behaved validation loss is

-prediction_val.to_data_independent_dist().log_prob(torch.tensor(Y_val))

This does not take into account the joint probabilities in your posterior predictive, but I have found that it works well in practice.

^^ Also, again, make sure that prediction_val calls model.likelihood(model(...)), not just model(...).
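
Putting the two suggestions together, a validation step might look roughly like this (a sketch reusing the variable names from the issue; to_data_independent_dist() drops the cross-covariances and scores each point under its own marginal):

	with torch.no_grad():
		model.eval()
		likelihood.eval()
		# predictive posterior: pass the latent GP through the likelihood
		prediction_val = model.likelihood(model(torch.tensor(np.float32(X_val.values))))
		# per-point negative log likelihood, averaged over the validation set
		valid_loss = -prediction_val.to_data_independent_dist().log_prob(torch.tensor(Y_val)).mean()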

gpleiss avatar Jul 26 '22 19:07 gpleiss

Hi @gpleiss, thank you for your answer.

However, it is a little confusing to me whether to include the noise or not when training the model (that is, whether to do loss = -mll(model(...)) or loss = -mll(model.likelihood(model(...)))), as most examples in the docs use the first approach in the training loop (such as cell 10 in the tutorial and the one with KeOps kernels).

For actual predictions (the posterior predictive), there is no doubt that observational noise should be added to the predictions (hence model.likelihood(model(…)) should be used), but while training with the mll loss I am unsure.

Am I missing out on anything? Thanks!

julioasotodv avatar Sep 12 '22 00:09 julioasotodv

However, it is a little confusing to me whether to include the noise or not when training the model (that is, whether to do loss = -mll(model(...)) or loss = -mll(model.likelihood(model(...))))

It should be loss = -mll(model(...)). To be honest, I'm not sure why we made this design decision 5 years ago... I think it was to be consistent with the VariationalELBO objective function.

For actual predictions (posterior predictive), no doubt that observational noise should be added to preds (hence model.likelihood(model(…)) should be used)

Correct. model.likelihood(model(…)) should be used for the predictive posterior, and model(…) should be used for the latent posterior.
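
For concreteness, a minimal sketch of the two (test_x is a placeholder for your test inputs):

	model.eval()
	likelihood.eval()
	with torch.no_grad():
		latent_post = model(test_x)                   # latent posterior p(f* | D), no noise
		predictive_post = likelihood(model(test_x))   # predictive posterior p(y* | D), with noise
		# the predictive confidence region is wider than the latent one
		lower, upper = predictive_post.confidence_region()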

gpleiss avatar Sep 22 '22 13:09 gpleiss

However, it is a little confusing to me whether to include the noise or not when training the model (that is, whether to do loss = -mll(model(...)) or loss = -mll(model.likelihood(model(...))))

It should be loss = -mll(model(...)). To be honest, I'm not sure why we made this design decision 5 years ago... I think it was to be consistent with the VariationalELBO objective function.

Thank you @gpleiss for the clarification! Actually, I looked at the code and realized that the mll already applies the likelihood internally, so the likelihood's gradients are correctly "backpropagated" when calling loss = -mll(model(...)), and they would (incorrectly) be "backpropagated" twice if loss = -mll(model.likelihood(model(…))) were used.
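
For reference, ExactMarginalLogLikelihood.forward boils down to roughly the following (a simplified sketch; the real code also adds prior and extra loss terms):

	def forward(self, function_dist, target):
		# pass the latent GP output through the likelihood to get the marginal p(y | X)
		output = self.likelihood(function_dist)
		# evaluate the marginal log likelihood of the observed targets
		res = output.log_prob(target)
		# (prior terms are added here in the real code)
		return res.div_(target.size(-1))  # normalize by the number of data points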

Thanks a lot for the confirmation! I really like gpytorch 😊

julioasotodv avatar Sep 22 '22 19:09 julioasotodv