optimistix
Efficient NewtonCG Implementation
Hi all, thanks for the phenomenal library. We're already using it in several statistical genetics methods in my group!
I've been porting some older code of mine to use optimistix rather than hand-rolled inference procedures, and could use some advice. Currently, I'm performing variational inference using a mix of closed-form updates for the variational parameters and gradient-based updates for some hyperparameters. It works roughly like this:
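```python
import jax

# _infer(hyper_param, var_params, data) returns (objective, updated var_params);
# the closed-form variational updates happen inside it.
eval_f = jax.value_and_grad(_infer, has_aux=True)  # grad w.r.t. hyper_param only

while True:
    (value, var_params), gradient = eval_f(hyper_param, var_params, data)
    hyper_param = hyper_param + learning_rate * gradient  # gradient ascent step
    if converged:
        break
```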
I'd like to retool the above so that, in addition to reporting the current value, the aux values (i.e. the updated variational parameters), and the gradient with respect to the hyperparameters, it also returns a Hessian-vector product (HVP) function that could be used in a Newton-CG-like step in Optimistix. I know about the new `optimistix.minimise` function, but it isn't clear to me how to set things up so that I get both the gradient and an HVP function without taking two additional passes over the graph (i.e. one pass for the value and gradient, then another two, forward and backward, for the HVP).
Is this doable? Apologies if this is somewhat nebulous--I'm happy to clarify.