lineax

LSMR init overflow

Open GJBoth opened this issue 1 month ago • 4 comments

Thanks for implementing the LSMR solver! I've been using it for matrix-free Levenberg-Marquardt and it seems to work really well. There's one issue with the initialisation, however: one of the initial values is set to 1e100, which leads to overflows. The value seems to come from the scipy implementation, which runs in 64-bit by default, whereas JAX runs in 32-bit by default (the largest finite float32 is about 3.4e38). It doesn't seem to affect the result of the solver, though.
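A minimal illustration of the overflow described above, assuming JAX's default 32-bit mode: 1e100 is far beyond the float32 range, so casting it to float32 yields `inf`. NumPy is used for the cast only so the overflow warning can be silenced explicitly.

```python
import warnings

import jax.numpy as jnp
import numpy as np

# Largest finite value representable in float32 (about 3.4e38).
f32_max = float(jnp.finfo(jnp.float32).max)
print(f32_max)  # 3.4028234663852886e+38

# 1e100 is far above that limit, so casting it to float32 overflows to +inf.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", RuntimeWarning)  # silence "overflow encountered in cast"
    minrbar_f32 = np.float32(1e100)

print(minrbar_f32)  # inf
```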

Simply changing this line to:

minrbar=jnp.inf

solves the issue for me.
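Why `jnp.inf` is a natural starting value: assuming `minrbar` is a running minimum, as in scipy's LSMR (where it feeds the condition-number estimate), `inf` is the identity element of `min`, so the first real value simply replaces it. The sketch below uses a hypothetical array of `rhobar` values as a stand-in for the solver's iterates and checks that starting from `inf` or from the float32 maximum gives the same result.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the per-iteration rhobar values an LSMR-style loop sees.
rhobar_values = jnp.array([3.0, 0.5, 2.0, 0.25, 4.0], dtype=jnp.float32)


def running_min(init):
    # Track the smallest value seen so far, starting from `init`.
    def body(i, current_min):
        return jnp.minimum(current_min, rhobar_values[i])

    return jax.lax.fori_loop(0, rhobar_values.shape[0], body, init)


with_inf = running_min(jnp.array(jnp.inf, dtype=jnp.float32))
with_finfo_max = running_min(jnp.finfo(jnp.float32).max)
print(with_inf, with_finfo_max)  # 0.25 0.25
```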

GJBoth · Nov 12 '25 19:11

This looks reasonable to me. @f0uriest and @PTNobel, do you concur? I'm not sure whether this would then require a jnp.where or some other safety mechanism in the body of the solver, or whether a smaller numerical value than 1e100 would be preferable here.
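One way to check whether a `jnp.where` is needed is to look at how `minrbar` is consumed. Assuming lineax follows scipy's LSMR here, where the condition estimate has the form `max(maxrbar, rhotemp) / min(minrbar, rhotemp)` and `maxrbar` starts at 0, an infinite `minrbar` is always paired with a `min`, so it never reaches the division. This is a sketch of that expression only, with an arbitrary `rhotemp`, not lineax's actual solver body.

```python
import jax.numpy as jnp


def cond_estimate(maxrbar, minrbar, rhotemp):
    # Same form as scipy's LSMR condition estimate (assumed to carry over to lineax):
    #   condA = max(maxrbar, rhotemp) / min(minrbar, rhotemp)
    return jnp.maximum(maxrbar, rhotemp) / jnp.minimum(minrbar, rhotemp)


rhotemp = jnp.float32(2.0)      # arbitrary finite value for illustration
maxrbar = jnp.float32(0.0)      # scipy initialises maxrbar to 0
minrbar = jnp.float32(jnp.inf)  # the proposed initial value

cond = cond_estimate(maxrbar, minrbar, rhotemp)
print(cond, bool(jnp.isfinite(cond)))  # 1.0 True
```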

johannahaffner · Nov 13 '25 21:11

Yeah, I'm 99.9% sure using jnp.inf would be fine; I don't think it needs a jnp.where or anything similar. The only possible issue I can see is that, with jax.debug_infs enabled, you might get a false positive because of the hard-coded value. We could also use jnp.finfo(x.dtype).max?
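The `jnp.finfo(x.dtype).max` alternative keeps the initial value finite in whatever floating dtype the solver runs in. A minimal sketch, where `initial_minrbar` is a hypothetical helper name (not part of lineax) and only float16/float32 are shown so the example runs without enabling x64:

```python
import jax.numpy as jnp


def initial_minrbar(dtype):
    # Hypothetical helper: the largest finite value of `dtype`, used as the
    # starting point of a running minimum instead of jnp.inf or 1e100.
    return jnp.asarray(jnp.finfo(dtype).max, dtype=dtype)


for dtype in (jnp.float16, jnp.float32):
    sentinel = initial_minrbar(dtype)
    # Each sentinel is finite and matches the dtype it was built for.
    print(jnp.dtype(dtype).name, sentinel, bool(jnp.isfinite(sentinel)))
```

Unlike a fixed constant such as 1e100, this value can never overflow on construction, because it is by definition representable in the chosen dtype.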

f0uriest · Nov 13 '25 22:11

This just reminded me of this issue: https://github.com/patrick-kidger/optimistix/pull/186. Let's try to avoid using a jnp.* constant as a default value. How about we set minrbar to the highest value that does not overflow in 32-bit: 1e38 if it is a float, and 1e9 if it is an integer.
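A quick check of the proposed constants against the 32-bit limits, assuming JAX's default dtypes (float32 and int32 with x64 disabled): float32 tops out near 3.4e38 and int32 at 2147483647, so 1e38 and 1e9 both fit.

```python
import jax.numpy as jnp

# Limits of the 32-bit types JAX uses by default (x64 disabled).
f32_max = float(jnp.finfo(jnp.float32).max)  # about 3.4e38
i32_max = int(jnp.iinfo(jnp.int32).max)      # 2147483647

# The proposed defaults sit below those limits...
assert 1e38 < f32_max
assert 10**9 < i32_max

# ...so they convert to 32-bit types without overflowing.
float_default = jnp.asarray(1e38, dtype=jnp.float32)
int_default = jnp.asarray(10**9, dtype=jnp.int32)
print(bool(jnp.isfinite(float_default)), int(int_default))  # True 1000000000
```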

johannahaffner · Nov 13 '25 22:11

I think any big number should be fine ...

PTNobel · Nov 14 '25 02:11