
Second-order updates

Open · BabaYara opened this issue 8 months ago · 8 comments

Hello Patrick,

I hope you're doing well. I've been working with the Equinox package, and I appreciate the effort that's gone into it.

I'm currently implementing second-order (Newton-style, based on a second-order Taylor expansion) optimization of neural networks. This requires computing both the gradient and the Hessian of the loss with respect to the flattened parameters, from which the model update is defined as below.
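For reference, the update in the outline is Newton's method. Writing θ for the flattened parameters, g for the gradient and H for the Hessian of the loss at θ (notation assumed here), minimising the second-order Taylor model gives the step below, assuming H is positive definite so that the minimiser exists:

$$
L(\theta + \Delta) \approx L(\theta) + g^\top \Delta + \tfrac{1}{2}\,\Delta^\top H \Delta,
\qquad g = \nabla_\theta L(\theta),\quad H = \nabla_\theta^2 L(\theta)
$$

$$
\nabla_\Delta\!\left[g^\top \Delta + \tfrac{1}{2}\,\Delta^\top H \Delta\right] = g + H\Delta = 0
\;\Longrightarrow\; \Delta^\star = -H^{-1} g,
\qquad \theta \leftarrow \theta - \eta\, H^{-1} g
$$

Here η plays the role of `lr`, and `solve(H, g)` computes H^{-1} g, so the update subtracts it.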

Here's a basic outline of what I'm trying to achieve:

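```python
import jax
import jax.numpy as jnp

def loss(params, shapes, input_data, labels):
    # params is a flat 1-D vector; loss_0 (defined elsewhere) unflattens it using `shapes`
    return loss_0(params, shapes, input_data, labels)

grad_f = jax.grad(loss)        # gradient w.r.t. params, shape (n,)
hessian_f = jax.hessian(loss)  # Hessian w.r.t. params, shape (n, n)
lr = 0.9

grad_params = grad_f(params, shapes, input_data, labels)
hessian_params = hessian_f(params, shapes, input_data, labels)
delta_x = jnp.linalg.solve(hessian_params, grad_params)  # H^{-1} g
params -= lr * delta_x  # Newton step moves against H^{-1} g
```

Could you please tell me whether Equinox supports this kind of operation, or whether there's a recommended approach within the broader JAX ecosystem to achieve this?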

Thanks in advance for your assistance.

BabaYara, Oct 08 '23 17:10