Request: more permissive shape checking for regression losses (equal -> broadcastable)
Hi,
Despite what the documentation of optax.losses.squared_error says, i.e. "targets: a vector with shape broadcastable to that of predictions", I noticed that the squared_error loss in _regression.py, and all the losses derived from it, performs a strict chex.assert_equal_shape sanity check instead of chex.assert_is_broadcastable. The implementation, copied below, includes a comment implying that this is intentional.
```python
def squared_error(
    predictions: chex.Array,
    targets: Optional[chex.Array] = None,
) -> chex.Array:
  """..."""
  chex.assert_type([predictions], float)
  if targets is not None:
    # Avoid broadcasting logic for "-" operator.
    chex.assert_equal_shape((predictions, targets))
  errors = predictions - targets if targets is not None else predictions
  return errors**2
```
This behavior dates back to commit b517edd. Is there a specific reason to avoid broadcasting? Would it be possible to revert or relax the check to allow broadcastable arrays (for convenience)?
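To make the request concrete, here is a minimal numpy sketch of the relaxed variant I have in mind. `squared_error_relaxed` is a hypothetical name, not optax API; it uses `np.broadcast_shapes` to validate compatibility instead of requiring exactly equal shapes (the optax version would presumably use `chex.assert_is_broadcastable` instead):

```python
import numpy as np


def squared_error_relaxed(predictions, targets=None):
  # Hypothetical variant: accept any targets whose shape is
  # broadcastable against predictions, rather than requiring
  # shapes to be exactly equal.
  if targets is not None:
    # Raises ValueError if the shapes are not broadcastable.
    np.broadcast_shapes(np.shape(predictions), np.shape(targets))
  errors = predictions - targets if targets is not None else predictions
  return errors**2
```

With this check, e.g. per-example predictions of shape (batch, dim) against a shared target of shape (dim,) would be accepted, while genuinely incompatible shapes would still fail loudly.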
I'd like to help implement this fix. I can submit a PR with a backward-compatible solution using an allow_broadcasting parameter + comprehensive tests. Would the maintainers be open to this approach? @niamorg
Thanks for the report!
We're in the process of simplifying the library. Given this is a very simple loss (unlike e.g., huber), I'd prefer to leave this as is, with the check in there, as it might encourage simply using 0.5 * (x - y) ** 2 which is more explicit.
Thank you for your quick answers. When you say you are in the process of simplifying the library, do you mean removing simple losses such as squared_error? I understand you may not want to maintain such simple losses and would rather encourage explicitness.
That's correct, but we don't plan on removing anything people depend on. We'd just like to focus on gradient transformation algorithms first.
We can close this issue, @rdyro, right?