Add axis and where arguments to loss functions.

Open carlosgmartin opened this issue 10 months ago • 11 comments

#902

carlosgmartin avatar Apr 08 '24 02:04 carlosgmartin

Thanks @carlosgmartin, could you add tests? Also you will need to wait for #916 to pass.

vroulet avatar Apr 08 '24 22:04 vroulet

@carlosgmartin, there are some conflicts with main. Do you mind updating the pull request? Thanks.

fabianp avatar Jun 10 '24 07:06 fabianp

@vroulet @fabianp Let me know if you'd like me to make any other changes.

carlosgmartin avatar Aug 11 '24 20:08 carlosgmartin

hey @carlosgmartin , thanks for the ping and apologies for the late reply, most of the team is on vacation 🏖️

The addition of the axis argument looks good to me, and the use of an axis kwarg is fairly common in numpy-like functions. ✅

Regarding the "where" argument however, I haven't seen it yet in other libraries. Do you know of any numpy/jax/flax/etc. functions that admit a "where" or a "mask" or similar kwarg? I just want to make sure that our API is as similar as possible to other libraries that have already implemented similar functionality.

fabianp avatar Aug 19 '24 08:08 fabianp

Most reduction functions in numpy/jax take a where argument for masking. For example:
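An illustrative sketch (not the commenter's original snippet), assuming NumPy 1.20 or later: `np.sum` and `np.mean` both accept `axis` and `where`, and `jax.numpy.sum` / `jax.numpy.mean` take the same keyword arguments. The array values and the mask below are made up for the example.

```python
import numpy as np

# Toy data: two rows of three values each (values are arbitrary).
x = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])

# Boolean mask: only entries marked True take part in the reduction.
mask = np.array([[True, True, False],
                 [True, False, False]])

# Reduce over the last axis, skipping masked-out entries.
row_sums = np.sum(x, axis=-1, where=mask)
row_means = np.mean(x, axis=-1, where=mask)

print(row_sums)   # sums over the unmasked entries of each row
print(row_means)  # means over the unmasked entries of each row
```

Note that the masked mean divides by the number of unmasked entries in each row, not by the full row length.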

Conceptually, it makes sense that any reduction function should take both axis and where arguments.

carlosgmartin avatar Aug 19 '24 19:08 carlosgmartin

excellent, thanks for the info!

fabianp avatar Aug 20 '24 14:08 fabianp

one more thing: could you please add the tag .. versionchanged:: 0.2.4 to the docstrings of the functions you've changed, explaining the change? See here for an example: https://github.com/google-deepmind/optax/blob/main/optax/schedules/_inject.py#L114

fabianp avatar Aug 20 '24 14:08 fabianp

also, the new kwargs should be described in the function docstring (on top of adding the versionchanged tag)
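For illustration only, a docstring following both requests might look roughly like the sketch below. The function name `masked_l2_loss`, its body, and the docstring wording are hypothetical and not taken from the pull request; NumPy stands in for `jax.numpy` to keep the sketch dependency-light.

```python
import numpy as np


def masked_l2_loss(predictions, targets, axis=-1, where=None):
  """Hypothetical example: half the sum of squared errors.

  Args:
    predictions: array of predicted values.
    targets: array of target values, same shape as ``predictions``.
    axis: axis or axes along which to sum the squared errors.
    where: optional boolean mask; elements where it is False are
      excluded from the sum.

  Returns:
    The reduced loss.

  .. versionchanged:: 0.2.4
    Added ``axis`` and ``where`` arguments.
  """
  errors = 0.5 * (predictions - targets) ** 2
  # np.sum treats where=True as "include every element".
  return np.sum(errors, axis=axis, where=True if where is None else where)


loss = masked_l2_loss(
    np.array([1.0, 2.0, 3.0]),
    np.zeros(3),
    where=np.array([True, False, True]),
)
```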

fabianp avatar Aug 20 '24 14:08 fabianp

@fabianp Done.

carlosgmartin avatar Aug 20 '24 21:08 carlosgmartin

thanks, we're almost there. Please add type annotations for these two new kwargs. These are likely:

where: chex.Array | None = None,
axis: int | tuple[int, ...] | None = -1
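As a rough sketch of how these annotations could sit in a full signature: the function `masked_abs_error` below is hypothetical, and `np.ndarray` is used in place of `chex.Array` so the example runs with NumPy alone. It also shows what each form of `axis` (an int, a tuple, or None) does.

```python
from __future__ import annotations  # allows `X | None` on older Pythons

import numpy as np


def masked_abs_error(
    predictions: np.ndarray,
    targets: np.ndarray,
    axis: int | tuple[int, ...] | None = -1,
    where: np.ndarray | None = None,
) -> np.ndarray:
  """Hypothetical loss: mean absolute error over `axis`, optionally masked."""
  errors = np.abs(predictions - targets)
  if where is None:
    return np.mean(errors, axis=axis)
  return np.mean(errors, axis=axis, where=where)


preds = np.arange(6.0).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
zeros = np.zeros((2, 3))

per_row = masked_abs_error(preds, zeros)                # int axis (default -1)
overall = masked_abs_error(preds, zeros, axis=None)     # reduce over everything
both_axes = masked_abs_error(preds, zeros, axis=(0, 1)) # tuple of axes
```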

fabianp avatar Aug 21 '24 08:08 fabianp

@fabianp Done.

carlosgmartin avatar Aug 21 '24 18:08 carlosgmartin