
Scheduled Reduction in Output Scale Issue

Open adam-hartshorne opened this issue 11 months ago • 1 comment

In this new paper, the authors argue that the diffusion parameter (the output scale in probdiffeq) should be set on a decreasing schedule, rather than learned as it is in Fenrir.

Diffusion Tempering Improves Parameter Estimation with Probabilistic Integrators for Ordinary Differential Equations https://arxiv.org/pdf/2402.12231.pdf
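For concreteness, here is a minimal sketch of what I mean by a decreasing schedule. The endpoints, the number of rounds, and the log-linear decay are illustrative placeholders, not the exact schedule from the paper:

```python
import jax.numpy as jnp

# Decay the output scale ("diffusion") log-linearly over optimisation rounds,
# starting large and ending small. The endpoints (1e3 -> 1e-3) and the number
# of rounds are illustrative placeholders, not the paper's values.
num_rounds = 10
output_scale_schedule = jnp.logspace(3.0, -3.0, num_rounds)

for output_scale in output_scale_schedule:
    # ... run a few optimisation steps of the parameter-estimation objective
    # with the output scale held fixed at `output_scale` ...
    pass
```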

Now, I am only interested in the terminal values, so I usually calculate log_marginal_likelihood_terminal_values.

If I follow the schedule described in the paper and I keep the standard_deviation parameter for log_marginal_likelihood_terminal_values constant throughout, all is good.
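Roughly, the working (constant standard deviation) setup looks like the sketch below. The solve wrapper is a placeholder, and the keyword names simply mirror the call in the traceback further down, so treat them as assumptions rather than a checked probdiffeq example:

```python
import jax.numpy as jnp
from probdiffeq.solvers import solution


def solve_terminal(theta, output_scale):
    # Placeholder: run the probabilistic IVP solve for ODE parameters `theta`
    # at the given output scale and return whatever posterior object
    # log_marginal_likelihood_terminal_values expects.
    raise NotImplementedError


def negative_log_likelihood(theta, output_scale, data_terminal, obs_std=1e-2):
    # Keyword names below are assumptions based on the traceback, not a
    # verified probdiffeq signature.
    posterior = solve_terminal(theta, output_scale)
    ell = solution.log_marginal_likelihood_terminal_values(
        data_terminal,
        standard_deviation=jnp.asarray(obs_std),
        posterior=posterior,
    )
    return -ell
```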

However, I usually also optimise the standard_deviation of the observations as part of the overall learning process (as is standard for GP regression). This is where I now run into sporadic NaN issues. Specifically:

```
  line 254, in __call__
    ell = solution.log_marginal_likelihood_terminal_values(
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/probdiffeq/solvers/solution.py", line 98, in log_marginal_likelihood_terminal_values
    _corrected, logpdf = _condition_and_logpdf(rv, u, model)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/probdiffeq/solvers/solution.py", line 103, in _condition_and_logpdf
    observed, conditional = impl.conditional.revert(rv, model)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/probdiffeq/impl/isotropic/_conditional.py", line 40, in revert
    r_ext_p, (r_bw_p, gain) = cholesky_util.revert_conditional(
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/probdiffeq/util/cholesky_util.py", line 98, in revert_conditional
    R = control_flow.cond(
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/probdiffeq/backend/control_flow.py", line 58, in cond
    return jax.lax.cond(use_true_func, true_func, false_func, *operands)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: FloatingPointError: invalid value (nan) encountered in jit(cond).
```

The values of the learned standard_deviation term and the output_scale at the time of the crash aren't doing anything crazy. My current setup is far too complicated for a quick MVE, so I am wondering if there is any advice about exactly what this part of the code is doing, and perhaps how I can figure out how to avoid the dreaded NaN.
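For what it's worth, the two things I am currently trying are sketched below: surfacing the first NaN inside jit (presumably what produced the FloatingPointError above) and keeping the learned standard deviation bounded away from zero. Both are generic JAX measures rather than anything probdiffeq-specific:

```python
import jax

# Raise on the first NaN produced inside jit-compiled code, to localise where
# it first appears (this is presumably what produced the FloatingPointError
# in the traceback above).
jax.config.update("jax_debug_nans", True)


def observation_std(raw_std, floor=1e-4):
    # Learn an unconstrained parameter and map it to a standard deviation that
    # stays strictly positive and bounded away from zero, so the observation
    # covariance never becomes (numerically) singular. The floor value is an
    # arbitrary illustrative choice.
    return floor + jax.nn.softplus(raw_std)
```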

adam-hartshorne avatar Mar 01 '24 01:03 adam-hartshorne

My initial guess would be the following: the part of the code you linked computes a QR decomposition of a stack of covariance matrices, and this (or any) QR decomposition is not reverse-mode differentiable as soon as one of the covariance matrices is rank deficient.

This is why the call to QR is embedded into a cond, which checks whether one of the covariances is exactly zero. Maybe you do not have an exactly-zero covariance, but one that is rank deficient in a more subtle way.
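As a toy illustration of that failure mode (independent of probdiffeq's internals): differentiating through a QR decomposition at a rank-deficient matrix typically yields non-finite gradients, because the derivative rule solves a triangular system with zeros on the diagonal of R. Exact behaviour may depend on the JAX version:

```python
import jax
import jax.numpy as jnp


def sum_of_r(a):
    # Any scalar function of the R factor will do for this demonstration.
    _q, r = jnp.linalg.qr(a)
    return jnp.sum(r)


a_full_rank = jnp.eye(2)
a_rank_deficient = jnp.array([[1.0, 0.0], [0.0, 0.0]])  # rank 1

print(jax.grad(sum_of_r)(a_full_rank))       # finite entries
print(jax.grad(sum_of_r)(a_rank_deficient))  # contains NaNs / Infs
```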

What do you think?

pnkraemer avatar Mar 01 '24 07:03 pnkraemer

Hi :) This issue has not received any updates for a while, and without any open todos, I'll close it. But please feel free to reopen if something is still up!

pnkraemer avatar Oct 16 '24 06:10 pnkraemer