
Add segmentation-based losses

aymuos15 opened this issue 6 months ago • 4 comments

Motivation: Segmentation is a pretty common CV (computer vision) task, so I imagine adding these losses may increase adoption. I am trying to use Optax for segmentation myself, hence this issue.

The two most commonly used losses are Dice loss and Dice + CE (cross-entropy) loss (I am biased toward medical imaging, so I am not sure about general CV).
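
For reference, the soft Dice loss used below follows the usual definition (with a small smoothing term s to avoid division by zero), and the combined loss is simply a weighted mix of the two:

    dice_loss(p, y)    = 1 - (2 * sum(p * y) + s) / (sum(p) + sum(y) + s)
    dice_ce_loss(p, y) = w * dice_loss(p, y) + (1 - w) * cross_entropy(p, y)

where p are the predicted probabilities, y the (one-hot or binary) labels, and w the dice_weight.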

Would this be on the repo's agenda? If so, would it be okay to open a PR? I would be happy to work on it then.

Below is a rough implementation (code and tests) of Dice and Dice + CE loss, modelled on the existing optax losses.

Mock Code

from typing import Union

import chex
import jax
import jax.numpy as jnp
# The public optax cross-entropy losses are reused below; inside optax itself
# this would import from the internal classification losses module instead.
import optax


def dice_loss(
    logits: chex.Array,
    labels: chex.Array,
    smooth: float = 1e-6,
    axis: Union[int, tuple[int, ...], None] = -1,
    where: Union[chex.Array, None] = None,
) -> chex.Array:
    """Computes the Dice loss between logits and labels.

    Args:
      logits: Unnormalized log probabilities, with shape [..., num_classes].
      labels: One-hot encoded labels or binary labels, with shape
        [..., num_classes].
      smooth: Smoothing factor to avoid division by zero. Defaults to 1e-6.
      axis: Axis or axes along which to compute. Defaults to -1.
      where: Elements to include in the computation. Defaults to None.

    Returns:
      Dice loss between each prediction and the corresponding target
      distributions, with shape [...].

    Examples:
      >>> import optax
      >>> import jax.numpy as jnp
      >>> jnp.set_printoptions(precision=4)
      >>> # Binary classification
      >>> logits = jnp.array([1.0, -1.0, 1.0])
      >>> labels = jnp.array([1.0, 0.0, 1.0])
      >>> print(optax.dice_loss(logits, labels))
      0.0
      >>> # Multi-class classification
      >>> logits = jnp.array([[1.0, -1.0, 2.0], [1.0, 2.0, -1.0]])
      >>> labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
      >>> print(optax.dice_loss(logits, labels))
      0.5
    """
    chex.assert_type([logits], float)
    
    # Convert logits to probabilities. A trailing singleton class dimension is
    # taken to mean the binary case; otherwise softmax over the class axis.
    if logits.shape[-1] == 1:
        probs = jax.nn.sigmoid(logits)
    else:
        probs = jax.nn.softmax(logits, axis=-1)
    
    # Intersection and cardinality, reduced over `axis` (the class axis by
    # default).
    intersection = jnp.sum(probs * labels, axis=axis)
    cardinality = jnp.sum(probs + labels, axis=axis)

    # Soft Dice coefficient for each remaining element (e.g. per example).
    dice = (2. * intersection + smooth) / (cardinality + smooth)

    # If both predictions and labels are empty (cardinality == 0), define the
    # Dice coefficient as 1 so that the loss is 0.
    dice = jnp.where(cardinality > 0, dice, 1.0)

    # Masked-out elements get a Dice coefficient of 1, i.e. zero loss.
    if where is not None:
        dice = jnp.where(where, dice, 1.0)

    return 1.0 - dice


def dice_ce_loss(
    logits: chex.Array,
    labels: chex.Array,
    *,
    dice_weight: float = 0.5,
    smooth: float = 1e-6,
    axis: Union[int, tuple[int, ...], None] = -1,
    where: Union[chex.Array, None] = None,
) -> chex.Array:
    """Computes a combination of Dice and Cross-Entropy loss.

    Args:
      logits: Unnormalized log probabilities, with shape [..., num_classes].
      labels: One-hot encoded labels or binary labels, with shape
        [..., num_classes].
      dice_weight: Weight for the Dice loss component (between 0 and 1). The weight
        for CE loss will be (1 - dice_weight). Defaults to 0.5.
      smooth: Smoothing factor to avoid division by zero in Dice loss.
        Defaults to 1e-6.
      axis: Axis or axes along which to compute. Defaults to -1.
      where: Elements to include in the computation. Defaults to None.

    Returns:
      Combined loss between Dice and Cross-Entropy.

    Examples:
      >>> import optax
      >>> import jax.numpy as jnp
      >>> jnp.set_printoptions(precision=4)
      >>> # Binary classification
      >>> logits = jnp.array([1.0, -1.0, 1.0])
      >>> labels = jnp.array([1.0, 0.0, 1.0])
      >>> # Equal weight to both losses
      >>> print(optax.dice_ce_loss(logits, labels, dice_weight=0.5))
      0.3133
      >>> # More weight to Dice loss
      >>> print(optax.dice_ce_loss(logits, labels, dice_weight=0.7))
      0.2193
      >>> # Multi-class classification
      >>> logits = jnp.array([[1.0, -1.0, 2.0], [1.0, 2.0, -1.0]])
      >>> labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
      >>> print(optax.dice_ce_loss(logits, labels, dice_weight=0.5))
      0.75
    """
    chex.assert_type([logits], float)
    if not 0 <= dice_weight <= 1:
        raise ValueError(f"dice_weight must be between 0 and 1, got {dice_weight}")
    
    # Dice component.
    dice = dice_loss(logits, labels, smooth=smooth, axis=axis, where=where)

    # Cross-entropy component, reusing the existing optax losses. A trailing
    # singleton class dimension is taken to mean the binary case.
    if logits.shape[-1] == 1:  # Binary case
        ce = optax.sigmoid_binary_cross_entropy(logits, labels).mean(axis=-1)
    else:  # Multi-class case
        ce = optax.softmax_cross_entropy(logits, labels)

    # Masked-out elements contribute zero cross-entropy loss.
    if where is not None:
        ce = jnp.where(where, ce, 0.0)

    # Convex combination of the two losses.
    return dice_weight * dice + (1 - dice_weight) * ce
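
As a quick illustration of how this would be used (the shapes and variable names are only illustrative, and dice_ce_loss refers to the proposed function above, not to anything in the current optax API), a per-pixel loss for a small segmentation batch could look like:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
# Toy batch: 4 images of 8x8 pixels with 3 classes.
logits = jax.random.normal(k1, (4, 8, 8, 3))
labels = jax.nn.one_hot(jax.random.randint(k2, (4, 8, 8), 0, 3), 3)

per_pixel = dice_ce_loss(logits, labels, dice_weight=0.5)  # shape (4, 8, 8)
loss = per_pixel.mean()  # scalar, ready for jax.grad and optax updates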

Mock Tests

from absl.testing import absltest
from absl.testing import parameterized
import chex
import jax
import jax.numpy as jnp
import optax

from optax.losses import _segmentation


class SegmentationLossTest(chex.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.seed = 42
    self.rtol = 1e-6
    self.atol = 1e-6

  @chex.all_variants
  def test_dice_loss_perfect_prediction(self):
    """Test that dice loss is 0 for perfect predictions."""
    # Test binary case
    logits = jnp.array([10.0, -10.0])  # Perfect prediction for class 0
    labels = jnp.array([1.0, 0.0])
    loss = self.variant(_segmentation.dice_loss)(logits, labels)
    self.assertAlmostEqual(loss, 0.0, delta=1e-6)

    # Test multi-class case
    logits = jnp.array([[10.0, -10.0, -10.0], [-10.0, 10.0, -10.0]])
    labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
    loss = self.variant(_segmentation.dice_loss)(logits, labels)
    self.assertAlmostEqual(loss, 0.0, delta=1e-6)

  @chex.all_variants
  def test_dice_loss_worst_prediction(self):
    """Test that dice loss is 1 for worst possible predictions."""
    # Test binary case
    logits = jnp.array([-10.0, 10.0])  # Worst prediction for class 0
    labels = jnp.array([1.0, 0.0])
    loss = self.variant(_segmentation.dice_loss)(logits, labels)
    self.assertAlmostEqual(loss, 1.0, delta=1e-6)

  @chex.all_variants
  def test_dice_ce_loss_weights(self):
    """Test that dice_ce_loss respects the dice_weight parameter."""
    logits = jnp.array([1.0, -1.0])
    labels = jnp.array([1.0, 0.0])
    
    # Pure dice loss
    dice_only = self.variant(_segmentation.dice_ce_loss)(
        logits, labels, dice_weight=1.0)
    expected_dice = self.variant(_segmentation.dice_loss)(logits, labels)
    self.assertAlmostEqual(dice_only, expected_dice, delta=1e-6)
    
    # Pure CE loss. With a trailing dimension of size 2 the multi-class
    # (softmax) branch is used, so compare against softmax_cross_entropy.
    ce_only = self.variant(_segmentation.dice_ce_loss)(
        logits, labels, dice_weight=0.0)
    expected_ce = optax.softmax_cross_entropy(logits, labels)
    self.assertAlmostEqual(ce_only, expected_ce, delta=1e-6)
    
    # Mixed loss
    mixed = self.variant(_segmentation.dice_ce_loss)(
        logits, labels, dice_weight=0.5)
    expected_mixed = 0.5 * expected_dice + 0.5 * expected_ce
    self.assertAlmostEqual(mixed, expected_mixed, delta=1e-6)

  @chex.all_variants
  def test_dice_loss_with_mask(self):
    """Test that the where parameter correctly masks elements."""
    logits = jnp.array([[1.0, -1.0], [1.0, -1.0]])
    labels = jnp.array([[1.0, 0.0], [0.0, 1.0]])
    mask = jnp.array([True, False])
    
    # With mask
    loss = self.variant(_segmentation.dice_loss)(
        logits, labels, where=mask)
    
    # The unmasked element should match its standalone loss; the masked
    # element should contribute zero loss.
    expected = _segmentation.dice_loss(logits[0], labels[0])
    self.assertAlmostEqual(loss[0], expected, delta=1e-6)
    self.assertAlmostEqual(loss[1], 0.0, delta=1e-6)

  @chex.all_variants
  def test_dice_ce_loss_gradients(self):
    """Test that gradients exist and are finite."""
    logits = jnp.array([0.5, -0.5])
    labels = jnp.array([1.0, 0.0])
    
    def loss_fn(logits):
      return _segmentation.dice_ce_loss(logits, labels, dice_weight=0.5)
    
    grad_fn = jax.grad(loss_fn)
    grads = self.variant(grad_fn)(logits)
    
    # Check gradients are not NaN or Inf
    self.assertTrue(jnp.all(jnp.isfinite(grads)))
    self.assertNotAlmostEqual(jnp.sum(jnp.abs(grads)), 0.0, delta=1e-6)


if __name__ == "__main__":
  absltest.main()

aymuos15 · Jun 24 '25 09:06

I'd like to propose adding Dice Loss and Dice + Cross Entropy (CE) Loss functions to Optax. These are widely used in segmentation tasks, particularly in medical imaging and increasingly in general computer vision workflows.

🔹 Why it's useful:

- Dice-based losses are a staple in segmentation benchmarks.
- Many JAX users working with segmentation tasks implement these externally; having them in Optax improves consistency and adoption.

🔹 What's included:

- dice_loss and dice_ce_loss with:
  - Binary and multi-class support
  - Optional where masking
  - Gradient safety
- Unit tests covering:
  - Perfect/worst predictions
  - Masking behaviour
  - Dice/CE combination tuning
  - Gradient checks

I’ve prepared an initial implementation (code + tests) in the style of Optax and would be happy to polish it for final inclusion. Would it be okay to open a PR for this?

Looking forward to your thoughts!

Best, @YashSachdeva

YashSachdeva · Jun 26 '25 12:06

@aymuos15 yes, thanks! Feel free to start a PR and tag me as the reviewer!

rdyro · Jun 26 '25 18:06

@YashSachdeva I don't understand; @aymuos15 already proposed the extension.

rdyro · Jun 26 '25 18:06

Thanks a lot @rdyro. Will get on with it then.

aymuos15 · Jun 26 '25 18:06