
Avoid state.aux=None in state returned by initial_state

Open mblondel opened this issue 4 years ago • 3 comments

Currently, when has_aux=True, state.aux is None in the state returned by init_state, but equals fun(params, *args, **kwargs)[1] in the state returned by update. This is problematic because the state changes structure between init_state and the first update, which can trigger a jit recompilation. One fix would be to have init_state set state.aux to dummy values of the correct type.

mblondel avatar Dec 06 '21 22:12 mblondel
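To make the recompilation concrete, here is a minimal sketch that counts how often JAX traces a jitted step function. ToyState, count_traces and step are hypothetical stand-ins written for this illustration; they are not jaxopt's actual state class or update method.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical stand-in for a solver state (not jaxopt's real class)."""
    iter_num: Any
    aux: Any


def count_traces(initial_aux, num_steps=3):
    """Run a jitted toy update num_steps times and count how often it is traced."""
    traces = []

    @jax.jit
    def step(params, state):
        # Python side effect: executed only while JAX traces the function.
        traces.append(1)
        new_params = params - 0.1 * params
        new_aux = jnp.sum(new_params)  # float32 scalar, plays the role of aux
        return new_params, ToyState(iter_num=state.iter_num + 1, aux=new_aux)

    params = jnp.ones(3)
    state = ToyState(iter_num=jnp.zeros((), dtype=jnp.int32), aux=initial_aux)
    for _ in range(num_steps):
        params, state = step(params, state)
    return len(traces)


# aux=None in the initial state: the state structure changes after the first step.
traces_with_none = count_traces(None)
# Dummy aux with the same shape and dtype as the real aux: structure stays fixed.
traces_with_dummy = count_traces(jnp.zeros(()))
print(traces_with_none, traces_with_dummy)
```

In this sketch the None case is traced twice, because the second call sees an array where the first saw None, while the dummy-value case is traced once.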

Is there a better way than calling self.fun_with_aux on the parameters passed to init_state just to retrieve the aux value? That would do the trick; the only drawback is paying the cost of calling fun_with_aux "for nothing".

Algue-Rythme avatar Dec 10 '21 11:12 Algue-Rythme

One idea would be to trace through fun to retrieve its return types (without actually evaluating fun), but I'm not sure whether that is feasible.

mblondel avatar Dec 10 '21 12:12 mblondel

https://jax.readthedocs.io/en/latest/jax.html#jax.eval_shape (thanks to @josipd for the tip)

mblondel avatar Dec 10 '21 14:12 mblondel
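As a rough sketch of how the jax.eval_shape suggestion could be used, the helper below builds zero-filled aux values with the right shapes and dtypes without running the computation in fun. The names fun and dummy_aux, and the least-squares objective, are assumptions made up for this example; this is not how jaxopt necessarily implements it.

```python
import jax
import jax.numpy as jnp


def fun(params, data):
    """Hypothetical objective with has_aux=True: returns (loss, aux)."""
    residuals = data["X"] @ params - data["y"]
    loss = 0.5 * jnp.mean(residuals ** 2)
    aux = {"residuals": residuals, "max_abs": jnp.max(jnp.abs(residuals))}
    return loss, aux


def dummy_aux(fun, params, *args, **kwargs):
    """Build zero-filled aux matching fun's aux output, without evaluating fun."""
    # eval_shape traces fun abstractly and returns ShapeDtypeStruct leaves.
    _, aux_struct = jax.eval_shape(fun, params, *args, **kwargs)
    return jax.tree_util.tree_map(
        lambda s: jnp.zeros(s.shape, s.dtype), aux_struct
    )


params = jnp.ones(3)
data = {"X": jnp.ones((5, 3)), "y": jnp.zeros(5)}

placeholder = dummy_aux(fun, params, data)   # what init_state could store
_, real_aux = fun(params, data)              # what update would store

same_structure = (
    jax.tree_util.tree_structure(placeholder)
    == jax.tree_util.tree_structure(real_aux)
)
print(same_structure, placeholder["residuals"].shape)
```

This sketch matches only the pytree structure, shapes and dtypes of aux; other properties that can affect the jit cache, such as weak typing, are not handled here.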