
Have `get_state`/`set_state` support setting values with compatible batch axes

Open · patrick-kidger opened this issue 2 years ago · 0 comments

At the moment, `get_state` and `set_state` demand that their arguments have precisely the same shape, dtype, and choice of batch axes.

It would make sense to allow some compatibility between different kinds of batch axes. For example, it can take a few fixed-point iterations for all of the vmap'd batch axes to flow through a model, so a state value may pick up extra batch axes from one iteration to the next.
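A minimal sketch of the motivating situation in plain JAX, with no equinox calls. The model `step` and the helpers `strict_compatible` and `relaxed_compatible` are hypothetical. Batch axes are shown as explicit leading dimensions, as they appear outside `jax.vmap`; inside a vmap'd function they are hidden from the traced shape. The relaxed rule (same dtype, extra leading batch axes allowed) is only one possible notion of "compatible", not equinox's behaviour.

```python
import jax
import jax.numpy as jnp


def step(state, x):
    # Hypothetical model update: the new state depends on the batched input x.
    return 0.5 * state + x


def strict_compatible(old, new):
    # Mirrors the current requirement: identical shape and dtype.
    return old.shape == new.shape and old.dtype == new.dtype


def relaxed_compatible(old, new):
    # Hypothetical relaxed rule: dtypes must match, and `new` may carry extra
    # leading batch axes on top of `old`'s shape.
    if old.dtype != new.dtype or new.ndim < old.ndim:
        return False
    return new.shape[new.ndim - old.ndim:] == old.shape


xs = jnp.arange(3.0)    # a batch of 3 inputs
state0 = jnp.zeros(())  # the initial state has no batch axis

# First iteration: the state is unbatched, so it is broadcast across the batch.
state1 = jax.vmap(step, in_axes=(None, 0))(state0, xs)
# Second iteration: the state now carries the batch axis itself.
state2 = jax.vmap(step, in_axes=(0, 0))(state1, xs)

print(state0.shape, state1.shape, state2.shape)  # () (3,) (3,)
print(strict_compatible(state0, state1))         # False
print(relaxed_compatible(state0, state1))        # True
print(strict_compatible(state1, state2))         # True
```

Under the strict rule, writing `state1` back over `state0` is rejected even though the only difference is the batch axis that appeared after the first iteration; from the second iteration onwards the shapes are stable.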

patrick-kidger, Mar 29 '22 22:03