LSMR init overflow
Thanks for implementing the LSMR solver - I've been using it for matrix-free Levenberg-Marquardt and it seems to work really well! There's one issue with the initialisation, however: one of the initial values gets initialised as 1e100, leading to overflows. The origin of this value seems to be the scipy implementation, which runs in 64-bit by default - but JAX runs in 32-bit by default. It doesn't seem to affect the result of the solver, though.
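For context, a quick way to see the behaviour under JAX's default 32-bit floats (just an illustrative snippet, not solver code):

```python
import jax.numpy as jnp

# Under JAX's default 32-bit config, 1e100 is not representable in float32,
# so the constant silently rounds to inf.
x = jnp.asarray(1e100)
print(x.dtype)                      # float32
print(x)                            # inf
print(jnp.finfo(jnp.float32).max)  # 3.4028235e+38
```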
Simply changing this line to:
minrbar = jnp.inf
solves the issue for me.
This looks reasonable to me - @f0uriest and @PTNobel, do you concur? I'm not sure whether this would then require a jnp.where or some other safety mechanism in the body of the solver, or whether a smaller numerical value than 1e100 would be preferable here.
Yeah, I'm 99.9% sure using jnp.inf would be fine; I don't think it needs any where etc. The only possible issue I could see is that with jax.debug_infs enabled you might get a false positive because of the hard-coded value? Could also use jnp.finfo(x.dtype).max?
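Something along these lines is what I have in mind - a minimal sketch, assuming the relevant vector's dtype is available where minrbar is initialised (names are illustrative, not the actual solver code):

```python
import jax.numpy as jnp

def init_minrbar(b):
    # Largest finite value representable in b's dtype, instead of a
    # hard-coded 1e100; stays finite in both 32-bit and 64-bit mode.
    return jnp.finfo(b.dtype).max

b = jnp.zeros(3)        # float32 under JAX's default config
print(init_minrbar(b))  # 3.4028235e+38
```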
This just reminded me of this issue: https://github.com/patrick-kidger/optimistix/pull/186 - let's try to avoid using a jnp.something default value. How about we set minrbar to a large literal that does not overflow in 32-bit: 1e38 if this is a float, and 1e9 if it is an integer?
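Roughly something like this - a sketch only, with large_default as an illustrative helper name rather than anything in the codebase:

```python
import jax.numpy as jnp

def large_default(dtype):
    # Plain Python literals that stay representable in 32-bit:
    # 1e38 for floats (float32 max is ~3.4e38), 10**9 for ints (int32 max is ~2.1e9).
    if jnp.issubdtype(dtype, jnp.floating):
        return 1e38
    return 10**9

print(large_default(jnp.float32))  # 1e+38
print(large_default(jnp.int32))    # 1000000000
```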
I think any big number should be fine ...