
Efficient NewtonCG Implementation

quattro opened this issue 7 months ago • 3 comments

Hi all, thanks for the phenomenal library. We're already using it in several statistical genetics methods in my group!

I've been porting some older code of mine over to optimistix, rather than hand-rolled inference procedures, and could use some advice. Currently, I am performing variational inference using a mix of closed-form updates for the variational parameters and gradient-based updates for some hyperparameters. It roughly works like:

import jax
import jax.numpy as jnp

# Transform once, outside the loop; _infer returns (value, var_params).
eval_f = jax.value_and_grad(_infer, has_aux=True)

while True:
  (value, var_params), gradient = eval_f(hyper_param, var_params, data)
  hyper_param = hyper_param + learning_rate * gradient  # ascent step on the objective
  if jnp.linalg.norm(gradient) < tol:  # e.g. a simple convergence check
    break

I'd like to retool the above so that it not only reports the current value, the aux values (i.e. the updated variational parameters), and the gradient wrt the hyperparameters, but also returns an hvp function that could be used in a Newton-CG-like step in Optimistix. I know of the new minimise function, but what isn't clear is how to set things up so that it reports gradients and also returns an hvp function internally, without taking two additional passes over the graph (i.e. one forward + backward pass for value and grad, then another forward + backward for each hvp).
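For concreteness, here's a rough sketch of the kind of helper I have in mind (value_grad_and_hvp is a hypothetical name, and this assumes a recent JAX in which jax.linearize accepts has_aux). The idea is that jax.linearize evaluates the gradient map once and returns a closure that replays only the cached tangent computation, so repeated hvp calls inside a CG loop don't re-trace the full graph:

import jax

def value_grad_and_hvp(f, x, *args):
  # f is assumed to behave like _infer above, returning (value, aux).
  def grad_with_aux(x_):
    (value, aux), g = jax.value_and_grad(f, has_aux=True)(x_, *args)
    return g, (value, aux)

  # One forward + backward pass builds the gradient and its linearization;
  # each subsequent hvp(v) replays only the tangent computation to get H @ v.
  gradient, hvp, (value, aux) = jax.linearize(grad_with_aux, x, has_aux=True)
  return value, aux, gradient, hvp

A Newton-CG-like step would then solve hvp(p) = -gradient with conjugate gradients (e.g. jax.scipy.sparse.linalg.cg accepts a function in place of a matrix) and update hyper_param with p.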

Is this doable? Apologies if this is somewhat nebulous--I'm happy to clarify.

quattro · Nov 15 '23 20:11