jaxopt

diag(J^T J) can be computed more efficiently

Open · Joshuaalbert opened this issue · 0 comments

In the Levenberg-Marquardt (LM) method, max(diag(J^T J)) is used to set the damping factor. Following option 2 in https://github.com/google/jax/issues/19711, it can be computed more efficiently than it currently is. I discovered this when I hit out-of-memory (OOM) errors with jaxopt's LM.
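For illustration, here is a minimal sketch of why the diagonal does not need the full n x n product: diag(J^T J)_i = sum_k J_ki^2, i.e. the squared norm of the i-th column of J. The sketch assumes a residual function r: R^n -> R^m and compares three ways of getting this quantity. It is not jaxopt's implementation, and it is not confirmed to be exactly what option 2 in the linked JAX issue proposes. `residual_fn`, `diag_jtj_dense`, `diag_jtj_colnorms` and `max_diag_jtj_matrix_free` are hypothetical names chosen for this example.

```python
import jax
import jax.numpy as jnp


def residual_fn(x):
    # Hypothetical residual r: R^2 -> R^3, used only for illustration.
    return jnp.stack([x[0] - 1.0, 10.0 * (x[1] - x[0] ** 2), x[0] * x[1]])


def diag_jtj_dense(residual_fn, x):
    # Baseline: materialize J (m x n), then J^T J (n x n), then take the diagonal.
    jac = jax.jacfwd(residual_fn)(x)
    return jnp.diag(jac.T @ jac)


def diag_jtj_colnorms(residual_fn, x):
    # diag(J^T J)_i = sum_k J_ki^2: squared column norms, no n x n product.
    jac = jax.jacfwd(residual_fn)(x)
    return jnp.sum(jac ** 2, axis=0)


def max_diag_jtj_matrix_free(residual_fn, x):
    # Never materializes J: one JVP per column of J, keeping only a running max.
    n = x.shape[0]

    def body(i, running_max):
        e_i = jnp.zeros_like(x).at[i].set(1.0)           # i-th basis vector
        _, col = jax.jvp(residual_fn, (x,), (e_i,))      # col = J @ e_i
        return jnp.maximum(running_max, jnp.sum(col ** 2))

    return jax.lax.fori_loop(0, n, body, jnp.array(0.0, dtype=x.dtype))


x0 = jnp.array([1.0, 2.0])
print(diag_jtj_colnorms(residual_fn, x0))         # J columns: [1,-20,2], [0,10,1]
print(max_diag_jtj_matrix_free(residual_fn, x0))
```

The trade-off in this sketch: the dense version allocates both J and the n x n matrix J^T J; the column-norm version only needs J; the matrix-free version needs memory on the order of m + n but runs n sequential JVPs, trading speed for a smaller peak allocation.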

Joshuaalbert, Feb 13 '24 23:02