equinox
Have `get_state`/`set_state` support setting values with compatible batch axes
At the moment, `get_state` and `set_state` demand that their arguments have precisely the same shape, dtype, and choice of batch axes.
It would make sense to allow some compatibility between different kinds of batch axes. For example, it may take a few fixed-point iterations for all of the `vmap`'d batch axes to flow through a model, so an intermediate value may not yet carry the same batch axes as the stored state.
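As a rough illustration of what "compatible" could mean, here is a minimal sketch in plain NumPy (`jax.numpy` also provides `broadcast_shapes` and `broadcast_to`). It is not Equinox's API: `compatible_batch_shape`, `set_state_compatible`, and the `core_ndim` parameter are hypothetical. The sketch assumes that the trailing `core_ndim` axes are the state's own shape and must match exactly, as must the dtype, while any leading axes are batch axes that only need to broadcast against each other.

```python
import numpy as np


def compatible_batch_shape(stored, new, core_ndim):
    """Return the common batch shape of `stored` and `new`, or raise.

    Hypothetical rule: dtype and the trailing `core_ndim` axes must match
    exactly; leading axes are batch axes and only need to broadcast.
    """
    if stored.dtype != new.dtype:
        raise TypeError(f"dtype mismatch: {stored.dtype} vs {new.dtype}")
    if stored.ndim < core_ndim or new.ndim < core_ndim:
        raise ValueError("array has fewer axes than the state's core shape")
    stored_batch, stored_core = (
        stored.shape[: stored.ndim - core_ndim],
        stored.shape[stored.ndim - core_ndim :],
    )
    new_batch, new_core = (
        new.shape[: new.ndim - core_ndim],
        new.shape[new.ndim - core_ndim :],
    )
    if stored_core != new_core:
        raise ValueError(f"core shape mismatch: {stored_core} vs {new_core}")
    # Raises ValueError if the batch axes cannot be broadcast together.
    return np.broadcast_shapes(stored_batch, new_batch)


def set_state_compatible(stored, new, core_ndim):
    """Hypothetical relaxed `set_state`: accept `new` if its batch axes are
    compatible with those of `stored`, and return it broadcast to the
    common batch shape."""
    batch = compatible_batch_shape(stored, new, core_ndim)
    core = new.shape[new.ndim - core_ndim :]
    # broadcast_to returns a read-only view; copy to get an ordinary array.
    return np.broadcast_to(new, batch + core).copy()


# Unbatched stored state with core shape (3,); new value has gained a
# leading batch axis of size 4 (e.g. after a vmap'd axis flowed through).
stored = np.zeros((3,), dtype=np.float32)
new = np.ones((4, 3), dtype=np.float32)
out = set_state_compatible(stored, new, core_ndim=1)
print(out.shape)  # (4, 3)
```

Under this rule, a mismatched core shape or dtype is still rejected, exactly as today; only the batch axes are relaxed. Whether broadcasting is the right notion of compatibility (as opposed to, say, only allowing an unbatched value to gain batch axes) is a design choice this sketch does not settle.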