
Request: more permissive shape checking for regression losses (equal -> broadcastable)

Open niamorg opened this issue 7 months ago • 4 comments

Hi,

Despite what the documentation for optax.losses.squared_error says, i.e. "targets: a vector with shape broadcastable to that of predictions", I noticed that the squared_error loss in _regression.py, and all the losses derived from it, performs a strict chex.assert_equal_shape sanity check instead of chex.assert_is_broadcastable. The implementation, copied below, includes a comment implying that this is on purpose.

from typing import Optional

import chex


def squared_error(
    predictions: chex.Array,
    targets: Optional[chex.Array] = None,
) -> chex.Array:
  """..."""
  chex.assert_type([predictions], float)
  if targets is not None:
    # Avoid broadcasting logic for "-" operator.
    chex.assert_equal_shape((predictions, targets))
  errors = predictions - targets if targets is not None else predictions
  return errors**2

It comes from commit b517edd. Is there a specific reason to avoid broadcasting? Would it be possible to revert or relax the check so that broadcastable arrays are accepted (for convenience)?
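For illustration, here is a minimal sketch of the distinction (using NumPy in place of chex/JAX): the two shapes below are broadcastable but not equal, so a strict equal-shape check rejects a subtraction that would otherwise work fine.

```python
import numpy as np

preds = np.array([1.0, 2.0, 3.0])  # shape (3,)
targets = np.array([1.0])          # shape (1,), broadcastable to (3,)

# A strict equal-shape check (what assert_equal_shape enforces) fails here.
assert preds.shape != targets.shape

# A broadcastability check (what assert_is_broadcastable would allow) passes:
# broadcast_shapes raises ValueError only if the shapes are incompatible.
np.broadcast_shapes(preds.shape, targets.shape)  # -> (3,)

# The "-" operator itself broadcasts without complaint.
errors = preds - targets
print((errors ** 2).tolist())  # [0.0, 1.0, 4.0]
```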

niamorg avatar Aug 22 '25 23:08 niamorg

I'd like to help implement this fix. I can submit a PR with a backward-compatible solution using an allow_broadcasting parameter + comprehensive tests. Would the maintainers be open to this approach? @niamorg
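A rough sketch of what such a backward-compatible change might look like (allow_broadcasting is a hypothetical parameter, not an actual optax argument, and NumPy stands in for chex/JAX here):

```python
import numpy as np


def squared_error(predictions, targets=None, allow_broadcasting=False):
    """Hypothetical variant: strict shape check by default, opt-in broadcasting."""
    if targets is not None:
        if allow_broadcasting:
            # Raises ValueError if the shapes are not broadcast-compatible.
            np.broadcast_shapes(predictions.shape, targets.shape)
        else:
            # Current strict behavior: shapes must match exactly.
            assert predictions.shape == targets.shape, "shape mismatch"
    errors = predictions - targets if targets is not None else predictions
    return errors ** 2


# Opt-in broadcasting: (3,) vs (1,) is accepted.
print(squared_error(np.ones(3), np.zeros(1), allow_broadcasting=True).tolist())
# Default behavior is unchanged: the same call without the flag raises.
```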

Kushagra481 avatar Aug 25 '25 19:08 Kushagra481

Thanks for the report!

We're in the process of simplifying the library. Given this is a very simple loss (unlike e.g., huber), I'd prefer to leave this as is, with the check in there, as it might encourage simply using 0.5 * (x - y) ** 2 which is more explicit.
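For what it's worth, the explicit expression suggested above already broadcasts, since it defers directly to NumPy/JAX semantics (NumPy used here for illustration):

```python
import numpy as np

x = np.array([[1.0], [2.0]])  # shape (2, 1)
y = np.array([0.5, 1.5])      # shape (2,)

# Plain arithmetic broadcasts (2, 1) against (2,) to produce shape (2, 2).
loss = 0.5 * (x - y) ** 2
print(loss.shape)  # (2, 2)
```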

rdyro avatar Aug 25 '25 21:08 rdyro

Thank you for your quick answers. When you say you are in the process of simplifying the library, do you mean removing simple losses such as squared_error? I understand you may prefer not to maintain such simple losses and to encourage explicitness instead.

niamorg avatar Aug 27 '25 09:08 niamorg

That's correct, but we don't plan on removing anything people depend on. We'd just like to focus on gradient transformation algorithms first.

rdyro avatar Sep 01 '25 16:09 rdyro

We can close this issue, @rdyro, right?

haroon0x avatar Dec 03 '25 15:12 haroon0x