jaxopt
diag(J^T J) can be computed more efficiently
In jaxopt's Levenberg-Marquardt (LM) solver, max(diag(J^T J)) is used to set the damping factor. Following option 2 in https://github.com/google/jax/issues/19711, this quantity can be computed more efficiently than the current implementation does. I found this after running into out-of-memory (OOM) errors with jaxopt's LM solver.
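A minimal sketch of one way to get diag(J^T J) without forming the n x n matrix J^T J, assuming only its diagonal (and then its maximum) is needed. It uses the identity diag(J^T J)_i = ||J e_i||^2, i.e. the squared norm of the i-th Jacobian column, and computes each column with a single forward-mode JVP. This is an illustration, not necessarily the exact approach labeled option 2 in the JAX issue and not jaxopt's code; `residual_fn`, `diag_jtj_naive`, `diag_jtj_columns` and `tau` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def residual_fn(x):
    # Hypothetical residual function R^3 -> R^4, used only for illustration.
    return jnp.array([
        x[0] ** 2 - x[1],
        jnp.sin(x[1]) + x[2],
        x[0] * x[2],
        x[2] ** 3 - 1.0,
    ])


def diag_jtj_naive(fun, x):
    # Materializes J (m x n) and then J^T J (n x n) just to read its diagonal.
    jac = jax.jacfwd(fun)(x)
    return jnp.diag(jac.T @ jac)


def diag_jtj_columns(fun, x):
    # diag(J^T J)_i = ||J e_i||^2. Each column J e_i comes from one JVP and is
    # reduced to a scalar immediately, so neither J nor J^T J is stored in full.
    def col_sq_norm(i):
        e_i = jnp.zeros_like(x).at[i].set(1.0)  # i-th standard basis vector
        _, jvp_out = jax.jvp(fun, (x,), (e_i,))
        return jnp.sum(jvp_out ** 2)

    # lax.map runs sequentially over the indices, keeping peak memory low.
    return jax.lax.map(col_sq_norm, jnp.arange(x.shape[0]))


x0 = jnp.array([1.0, 2.0, 3.0])
tau = 1e-3  # hypothetical damping scale
diag_cols = diag_jtj_columns(residual_fn, x0)
damping = tau * jnp.max(diag_cols)
```

Using `jax.vmap` instead of `jax.lax.map` over `col_sq_norm` would be faster, but it evaluates all n columns at once and so needs memory comparable to the full Jacobian; the sequential map trades speed for a smaller memory footprint.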