probdiffeq
Scheduled Reduction in Output Scale Issue
In this new paper, the authors argue that the diffusion parameter (the output scale in probdiffeq) should be set on a decreasing schedule, rather than learned as in Fenrir:
Diffusion Tempering Improves Parameter Estimation with Probabilistic Integrators for Ordinary Differential Equations https://arxiv.org/pdf/2402.12231.pdf
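For illustration, I think of the schedule roughly like this (a sketch only, assuming a geometric decay; the exact schedule and its endpoints are defined in the paper, so the numbers below are placeholders):

```python
import jax.numpy as jnp

# Hypothetical tempering schedule: start with a large output scale (diffusion)
# and decrease it geometrically over a fixed number of outer optimisation
# rounds. The concrete values are placeholders, not the paper's settings.
def output_scale_schedule(scale_start=1e3, scale_end=1e0, num_rounds=10):
    return jnp.geomspace(scale_start, scale_end, num_rounds)

for output_scale in output_scale_schedule():
    # run the probabilistic solve / parameter optimisation at this fixed scale
    ...
```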
Now, I am only interested in the terminal values, so I usually calculate log_marginal_likelihood_terminal_values.
If I follow the schedule described in the paper and I keep the standard_deviation parameter for log_marginal_likelihood_terminal_values constant throughout, all is good.
However, I usually also optimise the standard_deviation of the observations as part of the overall learning process (standard for GP regression). This is where I now run into sporadic NaN issues. Specifically,
line 254, in __call__
ell = solution.log_marginal_likelihood_terminal_values(
File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/probdiffeq/solvers/solution.py", line 98, in log_marginal_likelihood_terminal_values
_corrected, logpdf = _condition_and_logpdf(rv, u, model)
File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/probdiffeq/solvers/solution.py", line 103, in _condition_and_logpdf
observed, conditional = impl.conditional.revert(rv, model)
File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/probdiffeq/impl/isotropic/_conditional.py", line 40, in revert
r_ext_p, (r_bw_p, gain) = cholesky_util.revert_conditional(
File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/probdiffeq/util/cholesky_util.py", line 98, in revert_conditional
R = control_flow.cond(
File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/probdiffeq/backend/control_flow.py", line 58, in cond
return jax.lax.cond(use_true_func, true_func, false_func, *operands)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: FloatingPointError: invalid value (nan) encountered in jit(cond).
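For context, the relevant part of my loss looks roughly like this (heavily simplified; the argument names are from memory and may not match the installed probdiffeq version exactly):

```python
import jax
import jax.numpy as jnp
from probdiffeq.solvers import solution  # same module as in the traceback above

# Hypothetical, simplified loss: the observation standard deviation is learned
# jointly with the other parameters; the exact keyword arguments of
# log_marginal_likelihood_terminal_values may differ between versions.
def loss(log_std, posterior, data_terminal):
    std = jnp.exp(log_std)  # keep the learned standard deviation positive
    ell = solution.log_marginal_likelihood_terminal_values(
        data_terminal, standard_deviation=std, posterior=posterior
    )
    return -ell

grad_fn = jax.grad(loss)  # the NaN appears in this reverse-mode differentiation
```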
The values of the learned standard_deviation term and the output scale at the time of the crash aren't doing anything crazy. My current setup is far too complicated for a quick MVE, so I am wondering if there is any advice about exactly what this part of the code is doing and perhaps how I can try to figure out how to avoid the dreaded NaN.
My initial guess would be the following: the part of the code you linked computes a QR decomposition of a stack of covariance matrices, and this (or any) QR decomposition is not reverse-mode differentiable as soon as one of the covariance matrices is rank deficient.
This is why the call to QR is embedded in a cond, which checks whether one of the covariances is exactly zero. Maybe you do not have an exactly-zero covariance, but one that is rank-deficient in a more subtle way.
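As a toy illustration of the mechanism I have in mind (this is not probdiffeq's code, just jnp.linalg.qr applied to a rank-deficient matrix under reverse-mode differentiation):

```python
import jax
import jax.numpy as jnp

def loss(a):
    # QR of a rank-deficient matrix: the derivative rule divides by the
    # diagonal of R, which contains (near-)zeros here.
    q, r = jnp.linalg.qr(a)
    return jnp.sum(r ** 2)

a = jnp.array([[1.0, 1.0], [1.0, 1.0], [0.0, 0.0]])  # rank 1, so R is singular
print(jax.grad(loss)(a))  # typically contains NaN/Inf due to the rank deficiency
```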
What do you think?
Hi :) This issue has not received any updates for a while, and without any open todos, I'll close it. But please feel free to reopen if something is still up!