
[Feature Request] Simultaneous validation prediction during exact marginal log likelihood calculation

Open • Turakar opened this issue 2 years ago • 2 comments

🚀 Feature Request

Accelerate validation prediction during training by reusing the linear solve that CG or Cholesky already computes for the marginal log likelihood.

Motivation

Is your feature request related to a problem? Please describe.

Maximizing the exact marginal log likelihood is not guaranteed to prevent overfitting, so one might want to monitor performance on a hold-out dataset to perform early stopping. Currently, this requires a fully separate prediction step, which can be computationally expensive.

Pitch

It would be nice to reuse computations from the log likelihood calculation. Parts that could be reused (please complete this list with your ideas!):

  • Training covariance matrix
  • Solve of the training covariance matrix (including observation noise) against the mean-centered observed values ($a = K^{-1}(y - \mu)$)
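To make the reuse concrete, here is the standard exact GP algebra in textbook notation (not copied from the GPyTorch source). Assume $\hat{K} = K_{XX} + \sigma^2 I$ is the noisy training covariance over $n$ training points, $\mu_X$ and $\mu_{X_*}$ are the prior means at the training and validation inputs, and $X_*$ are the validation inputs. The same vector $a$ appears in both the marginal log likelihood and the validation predictive mean:

$$
a = \hat{K}^{-1}(y - \mu_X)
$$

$$
\log p(y \mid X) = -\tfrac{1}{2}(y - \mu_X)^\top a - \tfrac{1}{2}\log\lvert \hat{K} \rvert - \tfrac{n}{2}\log 2\pi
$$

$$
\mathbb{E}[f_* \mid y] = \mu_{X_*} + K_{X_* X}\, a
$$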

Describe the solution you'd like

  1. Extend LinearOperator.inv_quad_logdet() to also return the matrix solve without gradients (ctx.mark_non_differentiable).
  2. Extend the forward pass in ExactGP to take additional validation data (or pass this data in the constructor). In the forward pass, the only additional cost is evaluating the validation prior mean and the cross-covariance between validation and training inputs, plus one matmul with the solve.
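A minimal NumPy sketch of the two steps above, written outside of GPyTorch. The names `inv_quad_logdet_with_solve` and `exact_mll_and_val_mean` are hypothetical and are not existing GPyTorch or LinearOperator APIs; the RBF kernel, fixed hyperparameters, constant prior mean, and synthetic data are likewise assumptions for illustration. One Cholesky factorization of the noisy training covariance yields the inverse quadratic term, the log determinant, and the solve $a$; the validation mean then only costs the cross-covariance and a matrix-vector product.

```python
import numpy as np


def rbf_kernel(x1, x2, lengthscale=1.0, outputscale=1.0):
    # Squared-exponential kernel on 1-D inputs (illustrative choice).
    sqdist = (x1[:, None] - x2[None, :]) ** 2
    return outputscale * np.exp(-0.5 * sqdist / lengthscale**2)


def inv_quad_logdet_with_solve(K, rhs):
    # Hypothetical analogue of inv_quad_logdet() that also returns the solve.
    L = np.linalg.cholesky(K)  # K = L @ L.T, L lower triangular
    # K^{-1} rhs = L^{-T} (L^{-1} rhs); a general solver is used for brevity.
    solve = np.linalg.solve(L.T, np.linalg.solve(L, rhs))
    inv_quad = rhs @ solve
    logdet = 2.0 * np.sum(np.log(np.diag(L)))
    return inv_quad, logdet, solve


def exact_mll_and_val_mean(train_x, train_y, val_x, noise=0.1, mean_const=0.0):
    # Hypothetical "forward pass": exact log marginal likelihood plus the
    # validation predictive mean, sharing a single factorization.
    n = train_x.shape[0]
    K = rbf_kernel(train_x, train_x) + noise * np.eye(n)
    resid = train_y - mean_const
    inv_quad, logdet, a = inv_quad_logdet_with_solve(K, resid)
    mll = -0.5 * (inv_quad + logdet + n * np.log(2.0 * np.pi))
    # Extra validation cost: cross-covariance and one matrix-vector product.
    val_mean = mean_const + rbf_kernel(val_x, train_x) @ a
    return mll, val_mean


# Usage on synthetic data (assumed for illustration).
rng = np.random.default_rng(0)
train_x = np.linspace(0.0, 5.0, 40)
train_y = np.sin(train_x) + 0.1 * rng.standard_normal(40)
val_x = np.linspace(0.25, 4.75, 10)

mll, val_mean = exact_mll_and_val_mean(train_x, train_y, val_x)
# Hold-out error that could drive early stopping.
val_rmse = np.sqrt(np.mean((val_mean - np.sin(val_x)) ** 2))
print(mll, val_rmse)
```

In a PyTorch version, the returned solve would be detached from the graph (as with ctx.mark_non_differentiable), so that the validation metric does not add backward cost to the training step.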

Describe alternatives you've considered

Cache the relevant values using @cached. I prefer the explicit approach, as caching always raises the question of when to populate or release the cache, which can make memory usage unpredictable.

Are you willing to open a pull request?

I already have a working implementation for a special case. However, I want to take care of #2288 first, and this proposal might need some design discussion, as it will definitely change the API of some frequently used functions.

Additional context

It might be possible to extend this idea from exact to variational GPs. Note that, in its current form, this approach only yields the predictive mean, not a variance estimate.
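The standard exact GP predictive covariance shows why the single solve against the training targets is not enough for variances. In textbook notation (an assumption here), with $\hat{K} = K_{XX} + \sigma^2 I$ the noisy training covariance and $X_*$ the validation inputs, the covariance requires $\hat{K}^{-1} K_{X X_*}$, i.e. an additional solve with one right-hand side per validation point:

$$
\operatorname{Cov}[f_* \mid y] = K_{X_* X_*} - K_{X_* X}\, \hat{K}^{-1} K_{X X_*}
$$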

Turakar, Apr 20 '23 11:04

@Turakar I know that @JonathanWenger has plans for a prediction strategy refactor that would allow for some of this re-use.

gpleiss, May 26 '23 14:05

cc @SebastianAment, @sdaulton, @dme65 for any potential ideas here

Balandat, May 27 '23 00:05