jaxopt
LevenbergMarquardt and pytrees
I propose replacing the JAX NumPy operations in LevenbergMarquardt with the corresponding ones in tree_utils to address issues #505 and #579. With this change, the snippet in issue #505 appears to run correctly, both with and without geodesic acceleration (using the solver solve_cg).
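As a rough illustration of the kind of substitution involved, the sketch below applies an update and an inner product to parameters stored as a dict rather than a flat array. The helpers tree_vdot and tree_add_scalar_mul are written here with jax.tree_util only as stand-ins; they are assumptions for illustration, not the code in this PR.

```python
import jax
import jax.numpy as jnp


def tree_vdot(a, b):
    """Inner product of two pytrees that share the same structure."""
    products = jax.tree_util.tree_map(jnp.vdot, a, b)
    return sum(jax.tree_util.tree_leaves(products))


def tree_add_scalar_mul(a, scalar, b):
    """Leaf-wise a + scalar * b for two pytrees with the same structure."""
    return jax.tree_util.tree_map(lambda x, y: x + scalar * y, a, b)


# Parameters as a pytree (a dict), where a plain jnp expression on
# `params` itself would not work.
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
step = {"w": jnp.array([0.5, -1.0]), "b": jnp.array(2.0)}

new_params = tree_add_scalar_mul(params, -0.1, step)  # params - 0.1 * step
sq_norm = tree_vdot(step, step)                       # squared l2 norm of step
```

Because every operation maps over leaves, the same code also accepts a flat array, which is a pytree with a single leaf.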
However, the QR, LU, and Cholesky solvers still fail, since they require flattened versions of both the Jacobian and the parameters.
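One possible workaround for the factorization-based solvers, sketched here under the assumption that flattening at the solver boundary is acceptable, is to ravel the pytree with jax.flatten_util.ravel_pytree, solve the damped normal equations on the flat vector, and unravel the result. The residual function, data, and damping value are made up for the example, and this is not how the solver in this PR is implemented.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from jax.scipy.linalg import cho_factor, cho_solve


def residual(params, x, y):
    """Residuals of a toy linear model with pytree (dict) parameters."""
    return params["slope"] * x + params["intercept"] - y


x = jnp.array([0.0, 1.0, 2.0, 3.0])
y = 2.0 * x + 1.0
params = {"slope": jnp.array(0.0), "intercept": jnp.array(0.0)}

# Flatten the parameters once; `unravel` maps a flat vector back to the dict.
flat_params, unravel = ravel_pytree(params)


def flat_residual(flat):
    return residual(unravel(flat), x, y)


# Dense Jacobian with respect to the flat parameter vector.
jac = jax.jacfwd(flat_residual)(flat_params)
res = flat_residual(flat_params)

# One damped Gauss-Newton step: (J^T J + damping * I) delta = -J^T r.
damping = 1e-3  # illustrative value
jtj = jac.T @ jac
lhs = jtj + damping * jnp.eye(jtj.shape[0])
delta = cho_solve(cho_factor(lhs), -jac.T @ res)

new_params = unravel(flat_params + delta)
```

The cost is a dense Jacobian of shape (number of residuals, number of parameters), so this only suits problems where that matrix fits in memory.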
Regarding the initial value of damping_factor: computing it as self.damping_parameter * jnp.max(jtj_diag) requires materializing the full identity matrix. For large problems like the one in issue #579, perhaps it would be useful to let the user choose an initial damping_factor without calculating jtj_diag? (This is what the original paper by Marquardt does: https://www.jstor.org/stable/2098941, p. 438.)
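To make the suggestion concrete, here is a minimal sketch of what such an option could look like. The function initial_damping_factor and its user_value argument are hypothetical names, not part of jaxopt, and the fallback branch assumes a dense Jacobian is available purely to keep the example short.

```python
import jax.numpy as jnp


def initial_damping_factor(damping_parameter, jac=None, user_value=None):
    """Return the starting damping factor (hypothetical helper).

    If `user_value` is given, it is used directly and diag(J^T J) is never
    computed. Otherwise fall back to damping_parameter * max(diag(J^T J)).
    """
    if user_value is not None:
        return jnp.asarray(user_value)
    # diag(J^T J) equals the column-wise sum of squares of J.
    jtj_diag = jnp.sum(jac * jac, axis=0)
    return damping_parameter * jnp.max(jtj_diag)


jac = jnp.array([[1.0, 2.0], [3.0, 4.0]])

default_damping = initial_damping_factor(1e-2, jac=jac)       # scaled by diag(J^T J)
manual_damping = initial_damping_factor(1e-2, user_value=5.0)  # skips diag(J^T J)
```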