Add segmentation-based losses
Motivation - Segmentation is a common computer vision (CV) task, so adding these losses may increase adoption; I am also trying to use them myself, hence this issue.
The two most commonly used losses are Dice loss and Dice + CE loss (I am biased towards medical imaging, so I am not sure about general CV).
Would this be in scope for the repo? If so, would a PR be welcome? I would like to work on it.
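For reference, the soft Dice formulation I have in mind (with epsilon corresponding to the smooth argument below and w to dice_weight) is the standard one:

$$\mathcal{L}_{\text{Dice}} = 1 - \frac{2\sum_i p_i\,y_i + \epsilon}{\sum_i p_i + \sum_i y_i + \epsilon}, \qquad \mathcal{L}_{\text{Dice+CE}} = w\,\mathcal{L}_{\text{Dice}} + (1 - w)\,\mathcal{L}_{\text{CE}},$$

where p are the predicted probabilities (sigmoid or softmax of the logits) and y the one-hot targets.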
Below is a rough implementation (code and tests) of Dice and Dice + CE loss, modelled on the existing Optax losses.
Mock Code
from typing import Union
import chex
import jax
import jax.numpy as jnp
def dice_loss(
logits: chex.Array,
labels: chex.Array,
smooth: float = 1e-6,
axis: Union[int, tuple[int, ...], None] = -1,
where: Union[chex.Array, None] = None,
) -> chex.Array:
"""Computes the Dice loss between logits and labels.
Args:
logits: Unnormalized log probabilities, with shape [..., num_classes].
labels: One-hot encoded labels or binary labels, with shape
[..., num_classes].
smooth: Smoothing factor to avoid division by zero. Defaults to 1e-6.
axis: Axis or axes along which to compute. Defaults to -1.
where: Elements to include in the computation. Defaults to None.
Returns:
Dice loss between each prediction and the corresponding target
distributions, with shape [...].
Examples:
>>> import optax
>>> import jax.numpy as jnp
>>> jnp.set_printoptions(precision=4)
>>> # Binary classification
>>> logits = jnp.array([1.0, -1.0, 1.0])
>>> labels = jnp.array([1.0, 0.0, 1.0])
>>> print(optax.dice_loss(logits, labels))
0.0
>>> # Multi-class classification
>>> logits = jnp.array([[1.0, -1.0, 2.0], [1.0, 2.0, -1.0]])
>>> labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
>>> print(optax.dice_loss(logits, labels))
0.5
"""
  chex.assert_type([logits], float)
  # Convert logits to probabilities: treat a trailing singleton dimension as the
  # binary case (sigmoid), otherwise apply softmax over the class axis.
  if logits.shape[-1] == 1:
    probs = jax.nn.sigmoid(logits)
  else:
    probs = jax.nn.softmax(logits, axis=-1)
  # Intersection and cardinality over the class axis (or the given axes).
  intersection = jnp.sum(probs * labels, axis=axis)
  cardinality = jnp.sum(probs + labels, axis=axis)
  # Soft Dice coefficient; `smooth` guards against division by zero.
  dice = (2.0 * intersection + smooth) / (cardinality + smooth)
  # If both predictions and labels are empty (cardinality == 0), define Dice as 1.
  dice = jnp.where(cardinality > 0, dice, 1.0)
  # Apply the mask if provided: masked-out elements get Dice 1, i.e. zero loss.
  if where is not None:
    dice = jnp.where(where, dice, 1.0)
  return 1.0 - dice
def dice_ce_loss(
logits: chex.Array,
labels: chex.Array,
*,
dice_weight: float = 0.5,
smooth: float = 1e-6,
axis: Union[int, tuple[int, ...], None] = -1,
where: Union[chex.Array, None] = None,
) -> chex.Array:
"""Computes a combination of Dice and Cross-Entropy loss.
Args:
logits: Unnormalized log probabilities, with shape [..., num_classes].
labels: One-hot encoded labels or binary labels, with shape
[..., num_classes].
dice_weight: Weight for the Dice loss component (between 0 and 1). The weight
for CE loss will be (1 - dice_weight). Defaults to 0.5.
smooth: Smoothing factor to avoid division by zero in Dice loss.
Defaults to 1e-6.
axis: Axis or axes along which to compute. Defaults to -1.
where: Elements to include in the computation. Defaults to None.
Returns:
Combined loss between Dice and Cross-Entropy.
Examples:
>>> import optax
>>> import jax.numpy as jnp
>>> jnp.set_printoptions(precision=4)
>>> # Binary classification
>>> logits = jnp.array([1.0, -1.0, 1.0])
>>> labels = jnp.array([1.0, 0.0, 1.0])
>>> # Equal weight to both losses
>>> print(optax.dice_ce_loss(logits, labels, dice_weight=0.5))
0.3133
>>> # More weight to Dice loss
>>> print(optax.dice_ce_loss(logits, labels, dice_weight=0.7))
0.2193
>>> # Multi-class classification
>>> logits = jnp.array([[1.0, -1.0, 2.0], [1.0, 2.0, -1.0]])
>>> labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
>>> print(optax.dice_ce_loss(logits, labels, dice_weight=0.5))
0.75
"""
  chex.assert_type([logits], float)
  if not 0 <= dice_weight <= 1:
    raise ValueError(f"dice_weight must be between 0 and 1, got {dice_weight}")
  # Compute the Dice loss.
  dice = dice_loss(logits, labels, smooth=smooth, axis=axis, where=where)
  # Compute the matching cross-entropy term. jax.nn does not provide TF-style
  # *_cross_entropy_with_logits helpers, so use the numerically stable
  # log-sigmoid / log-softmax formulations directly (optax's existing
  # sigmoid_binary_cross_entropy / softmax_cross_entropy would also work).
  if logits.shape[-1] == 1:  # Binary case.
    ce = -(labels * jax.nn.log_sigmoid(logits)
           + (1.0 - labels) * jax.nn.log_sigmoid(-logits))
  else:  # Multi-class case.
    ce = -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1)
  # Apply the mask if provided: masked-out elements contribute zero CE.
  if where is not None:
    ce = jnp.where(where, ce, 0.0)
  # Combine the two losses.
  return dice_weight * dice + (1.0 - dice_weight) * ce
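For illustration only (the shapes and values here are made up and not part of the proposal), the intended call pattern on a segmentation-shaped batch would be roughly:

import jax
import jax.numpy as jnp

# Hypothetical segmentation batch: [batch, height, width, num_classes].
logits = jax.random.normal(jax.random.PRNGKey(0), (2, 16, 16, 3))
targets = jax.random.randint(jax.random.PRNGKey(1), (2, 16, 16), 0, 3)
labels = jax.nn.one_hot(targets, 3)

# Per-element losses, reduced over the class axis by default (axis=-1).
per_element = dice_ce_loss(logits, labels, dice_weight=0.5)
loss = jnp.mean(per_element)  # scalar training loss

That is, the functions would return per-element losses and leave the reduction to a scalar to the caller, mirroring the other Optax losses.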
Mock Tests
from absl.testing import absltest
from absl.testing import parameterized
import chex
import jax
import jax.numpy as jnp
from optax.losses import _segmentation
class SegmentationLossTest(chex.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
self.seed = 42
self.rtol = 1e-6
self.atol = 1e-6
@chex.all_variants
def test_dice_loss_perfect_prediction(self):
"""Test that dice loss is 0 for perfect predictions."""
# Test binary case
logits = jnp.array([10.0, -10.0]) # Perfect prediction for class 0
labels = jnp.array([1.0, 0.0])
loss = self.variant(_segmentation.dice_loss)(logits, labels)
self.assertAlmostEqual(loss, 0.0, delta=1e-6)
# Test multi-class case
logits = jnp.array([[10.0, -10.0, -10.0], [-10.0, 10.0, -10.0]])
labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
loss = self.variant(_segmentation.dice_loss)(logits, labels)
self.assertAlmostEqual(loss, 0.0, delta=1e-6)
@chex.all_variants
def test_dice_loss_worst_prediction(self):
"""Test that dice loss is 1 for worst possible predictions."""
# Test binary case
logits = jnp.array([-10.0, 10.0]) # Worst prediction for class 0
labels = jnp.array([1.0, 0.0])
loss = self.variant(_segmentation.dice_loss)(logits, labels)
self.assertAlmostEqual(loss, 1.0, delta=1e-6)
@chex.all_variants
def test_dice_ce_loss_weights(self):
"""Test that dice_ce_loss respects the dice_weight parameter."""
logits = jnp.array([1.0, -1.0])
labels = jnp.array([1.0, 0.0])
# Pure dice loss
dice_only = self.variant(_segmentation.dice_ce_loss)(
logits, labels, dice_weight=1.0)
expected_dice = self.variant(_segmentation.dice_loss)(logits, labels)
self.assertAlmostEqual(dice_only, expected_dice, delta=1e-6)
# Pure CE loss
ce_only = self.variant(_segmentation.dice_ce_loss)(
logits, labels, dice_weight=0.0)
    # jax.nn has no sigmoid_cross_entropy_with_logits; use the stable formulation.
    expected_ce = jnp.mean(
        -(labels * jax.nn.log_sigmoid(logits)
          + (1 - labels) * jax.nn.log_sigmoid(-logits)))
self.assertAlmostEqual(ce_only, expected_ce, delta=1e-6)
# Mixed loss
mixed = self.variant(_segmentation.dice_ce_loss)(
logits, labels, dice_weight=0.5)
expected_mixed = 0.5 * expected_dice + 0.5 * expected_ce
self.assertAlmostEqual(mixed, expected_mixed, delta=1e-6)
@chex.all_variants
def test_dice_loss_with_mask(self):
"""Test that the where parameter correctly masks elements."""
logits = jnp.array([[1.0, -1.0], [1.0, -1.0]])
labels = jnp.array([[1.0, 0.0], [0.0, 1.0]])
mask = jnp.array([True, False])
# With mask
loss = self.variant(_segmentation.dice_loss)(
logits, labels, where=mask)
# Should be equal to just the first example
expected = _segmentation.dice_loss(logits[0], labels[0])
self.assertAlmostEqual(loss, expected, delta=1e-6)
@chex.all_variants
def test_dice_ce_loss_gradients(self):
"""Test that gradients exist and are finite."""
logits = jnp.array([0.5, -0.5])
labels = jnp.array([1.0, 0.0])
def loss_fn(logits):
return _segmentation.dice_ce_loss(logits, labels, dice_weight=0.5)
grad_fn = jax.grad(loss_fn)
grads = self.variant(grad_fn)(logits)
# Check gradients are not NaN or Inf
self.assertTrue(jnp.all(jnp.isfinite(grads)))
self.assertNotAlmostEqual(jnp.sum(jnp.abs(grads)), 0.0, delta=1e-6)
if __name__ == "__main__":
absltest.main()
I'd like to propose adding Dice Loss and Dice + Cross Entropy (CE) Loss functions to Optax. These are widely used in segmentation tasks, particularly in medical imaging and increasingly in general computer vision workflows.
🔹 Why it's useful:
Dice-based losses are a staple in segmentation benchmarks.
Many JAX users working with segmentation tasks implement these externally — having them in Optax improves consistency and adoption.
🔹 What's included:
dice_loss and dice_ce_loss with:
Binary and multi-class support
Optional where masking (see the short sketch after this list)
Gradient safety
Unit tests covering:
Perfect/worst predictions
Masking behaviour
Dice/CE combination tuning
Gradient checks
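As a quick sketch of the where masking (hypothetical shapes, reusing the mock dice_ce_loss above):

import jax
import jax.numpy as jnp

# Ignore padded pixels when computing the loss.
logits = jnp.zeros((2, 4, 4, 3))
labels = jax.nn.one_hot(jnp.zeros((2, 4, 4), dtype=jnp.int32), 3)
valid = jnp.ones((2, 4, 4), dtype=bool).at[:, -1].set(False)  # mask out the last row
loss = dice_ce_loss(logits, labels, dice_weight=0.5, where=valid).mean()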
I’ve prepared an initial implementation (code + tests) in the style of Optax and would be happy to polish it for final inclusion. Would it be okay to open a PR for this?
Looking forward to your thoughts!
Best, @YashSachdeva
@aymuos15 yes, thanks! Feel free to start a PR and tag me as the reviewer!
@YashSachdeva I don't understand; @aymuos15 already proposed the extension.
Thanks a lot @rdyro. Will get on with it then.