optax
Add axis and where arguments to loss functions.
#902
Thanks @carlosgmartin, could you add tests? Also you will need to wait for #916 to pass.
@carlosgmartin, there are some conflicts with main; do you mind updating the pull request? Thanks!
@vroulet @fabianp Let me know if you'd like me to make any other changes.
hey @carlosgmartin, thanks for the ping and apologies for the late reply; most of the team is on vacation 🏖️
The addition of the `axis` argument looks good to me, and an `axis` kwarg is fairly common in numpy-like functions. ✅
Regarding the `where` argument, however, I haven't seen it yet in other libraries. Do you know of any numpy/jax/flax/etc. functions that admit a `where` or a `mask` or similar kwarg? I just want to make sure that our API is as similar as possible to other libraries that have already implemented similar functionality.
Most reduction functions in numpy/jax take a `where` argument for masking. For example, `numpy.sum`, `numpy.mean`, `jax.numpy.sum`, and `jax.numpy.mean` all accept `where`.
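A minimal NumPy illustration of that masking (`jax.numpy.sum` and `jax.numpy.mean` accept the same `axis` and `where` keywords). The arrays are made up for the example; the boolean mask selects which entries contribute to each reduction.

```python
import numpy as np

x = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])
# Boolean mask: True marks the entries that take part in the reduction.
mask = np.array([[True, True, False],
                 [True, False, False]])

# Masked sum along the last axis: only entries where mask is True are added.
row_sums = np.sum(x, axis=-1, where=mask)    # [3., 4.]
# Masked mean divides by the number of selected entries, not the axis length.
row_means = np.mean(x, axis=-1, where=mask)  # [1.5, 4.]
```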
Conceptually, it makes sense that any reduction function should take both `axis` and `where` arguments.
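To make the point concrete, here is a minimal NumPy sketch, not optax's implementation, of a hypothetical `masked_softmax_cross_entropy` that forwards `axis` and `where` to every reduction it performs (the max, the log-sum-exp, and the final sum over classes). Treating masked entries as absent classes is an assumption made for the illustration.

```python
import numpy as np


def masked_softmax_cross_entropy(logits, labels, axis=-1, where=None):
    """Hypothetical sketch: softmax cross-entropy that forwards `axis` and `where`.

    Entries where `where` is False are treated as absent classes: they are
    excluded from the log-sum-exp normalizer and from the final sum.
    """
    if where is None:
        where = True  # NumPy's default mask: every entry takes part
    # Max over the unmasked logits only, for numerical stability.
    max_logit = np.max(logits, axis=axis, keepdims=True,
                       where=where, initial=-np.inf)
    shifted = logits - max_logit
    # Log-sum-exp restricted to the unmasked entries.
    log_normalizer = np.log(
        np.sum(np.exp(shifted), axis=axis, keepdims=True, where=where))
    log_probs = shifted - log_normalizer
    # Cross-entropy summed over the same axis, under the same mask.
    return -np.sum(labels * log_probs, axis=axis, where=where)
```

Passing `where` down to each reduction, rather than zeroing the masked logits, matters here: a zeroed logit would still add `exp(0 - max)` to the normalizer, whereas a masked one contributes nothing.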
excellent, thanks for the info!
one more thing: could you please add the tag `.. versionchanged:: 0.2.4` to the docstrings of the functions you've changed, explaining the change? See here for an example: https://github.com/google-deepmind/optax/blob/main/optax/schedules/_inject.py#L114
also, the new kwargs should be described in the function docstrings (in addition to the `versionchanged` tag).
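A sketch of how such a docstring might look, assuming Google-style `Args:` sections and the Sphinx `versionchanged` directive. The function `masked_mse` and the wording of its change note are hypothetical, chosen only to show where the kwarg descriptions and the directive go.

```python
import numpy as np


def masked_mse(predictions, targets, axis=-1, where=None):
    """Computes the mean squared error between predictions and targets.

    Args:
      predictions: Array of predicted values.
      targets: Array of target values, broadcastable to ``predictions``.
      axis: Axis or axes along which the mean is taken.
      where: Optional boolean mask; only entries where it is ``True``
        contribute to the mean.

    Returns:
      The mean squared error, reduced over ``axis``.

    .. versionchanged:: 0.2.4
      Added ``axis`` and ``where`` arguments.
    """
    if where is None:
        where = True  # NumPy's default: include every entry
    errors = (predictions - targets) ** 2
    return np.mean(errors, axis=axis, where=where)
```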
@fabianp Done.
thanks, we're almost there. Please add type annotations for these two new kwargs. These are likely `where: chex.Array | None = None` and `axis: int | tuple[int, ...] | None = -1`.
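A sketch of the suggested annotations on a hypothetical reduction, `sum_of_squares`, showing the three forms the `axis` annotation admits. `Array` here is a NumPy stand-in for `chex.Array`, an assumption so the example needs only NumPy.

```python
from __future__ import annotations  # allows the `X | Y` syntax on older Pythons

import numpy as np

# Stand-in for ``chex.Array`` (assumption: keeps the sketch NumPy-only).
Array = np.ndarray


def sum_of_squares(
    x: Array,
    axis: int | tuple[int, ...] | None = -1,
    where: Array | None = None,
) -> Array:
    """Hypothetical reduction showing the annotated ``axis``/``where`` kwargs."""
    if where is None:
        where = True  # no mask: every entry contributes
    return np.sum(x**2, axis=axis, where=where)


x = np.arange(6.0).reshape(2, 3)        # [[0, 1, 2], [3, 4, 5]]
per_row = sum_of_squares(x)             # int axis (-1): [5., 50.]
both = sum_of_squares(x, axis=(0, 1))   # tuple of axes: 55.
total = sum_of_squares(x, axis=None)    # None reduces over all axes: 55.
```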
@fabianp Done.