jaxopt
Avoid state.aux = None in the state returned by init_state
Currently, when has_aux=True, state.aux is None in the state returned by init_state, but equal to fun(params, *args, **kwargs)[1] in the state returned by update. This is problematic: the two states have different pytree structures, so passing both to the same jitted function can trigger a jit recompilation. One option would be for init_state to set state.aux to dummy values of the correct shape and type.
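To make the recompilation concrete, here is a minimal sketch. The `State` NamedTuple and the `step` function are hypothetical stand-ins for a solver state and a jitted loop body, not jaxopt's actual classes. A Python-side counter is incremented only while `jax.jit` traces, so it counts compilations.

```python
import jax
import jax.numpy as jnp
from typing import Any, NamedTuple

class State(NamedTuple):
    # Hypothetical minimal solver state, for illustration only.
    iter_num: jnp.ndarray
    aux: Any

trace_count = 0

@jax.jit
def step(state):
    global trace_count
    trace_count += 1  # Python side effect: runs only when jit (re)traces
    return state.iter_num + 1

# aux=None mimics what init_state returns today; an array mimics update.
state_init = State(iter_num=jnp.asarray(0), aux=None)
state_upd = State(iter_num=jnp.asarray(1), aux=jnp.zeros(3))

step(state_init)
step(state_upd)
print(trace_count)  # 2: the pytree structures differ, so jit traced twice
```

None is a pytree with no leaves, so the two states have different tree structures even though the array leaves they share match. Inside `jax.lax.while_loop` the same mismatch is not a silent recompilation but a type error, since the loop carry must keep an identical structure.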
Calling self.fun_with_aux with the arguments of init_state just to retrieve the aux value would do the trick; the only drawback is paying the cost of evaluating fun_with_aux "for nothing". Is there a better way?
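A minimal sketch of that brute-force option, assuming a hypothetical `fun_with_aux(params, x)` that returns `(value, aux)` and a hypothetical `State` NamedTuple; jaxopt's real `init_state` has a different signature. The objective is evaluated once and its value is thrown away.

```python
import jax.numpy as jnp
from typing import Any, NamedTuple

class State(NamedTuple):
    # Hypothetical minimal solver state, for illustration only.
    iter_num: jnp.ndarray
    aux: Any

def fun_with_aux(params, x):
    # Hypothetical objective returning (value, aux), as when has_aux=True.
    residual = x @ params
    return jnp.sum(residual ** 2), {"residual": residual}

def init_state(init_params, x):
    # Evaluate the objective once, discarding the value, only to obtain an
    # aux with the right structure, shapes and dtypes. This costs a full call.
    _, aux = fun_with_aux(init_params, x)
    return State(iter_num=jnp.asarray(0), aux=aux)
```

The resulting state has the same pytree structure as the one `update` would produce, at the price of one extra forward pass.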
One idea would be to trace through fun to retrieve its return types (shapes and dtypes) without actually evaluating it, but I'm not sure whether that's feasible.
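jax.eval_shape does exactly this: it computes the shapes and dtypes of a function's outputs without evaluating it. https://jax.readthedocs.io/en/latest/jax.html#jax.eval_shape (thanks to @josipd for the tip)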
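A possible sketch of the eval_shape approach: trace `fun_with_aux` abstractly, then fill each `jax.ShapeDtypeStruct` in the aux output with zeros. The helper name `dummy_aux` and the example objective are hypothetical, and this is not necessarily how jaxopt ended up resolving the issue.

```python
import jax
import jax.numpy as jnp

def fun_with_aux(params, x):
    # Hypothetical objective returning (value, aux), as when has_aux=True.
    residual = x @ params
    return jnp.sum(residual ** 2), {"residual": residual,
                                    "norm": jnp.linalg.norm(params)}

def dummy_aux(fun_with_aux, params, *args, **kwargs):
    # Abstract trace: only output shapes/dtypes are computed, no FLOPs.
    _, aux_shapes = jax.eval_shape(fun_with_aux, params, *args, **kwargs)
    # Replace each jax.ShapeDtypeStruct leaf by a zero array of that shape/dtype.
    return jax.tree_util.tree_map(lambda s: jnp.zeros(s.shape, s.dtype),
                                  aux_shapes)

params, x = jnp.ones(3), jnp.ones((5, 3))
aux = dummy_aux(fun_with_aux, params, x)
```

The zeros have the same pytree structure, shapes and dtypes as the real aux, so a state built with them in init_state matches the state returned by update without paying for a real evaluation of fun.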