GeneralizedWassersteinDiceLoss

Official implementation of the Generalized Wasserstein Dice Loss in PyTorch

Generalized Wasserstein Dice Loss

The Generalized Wasserstein Dice Loss (GWDL) is a loss function for training deep neural networks for multi-class medical image segmentation.
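For reference, here is the loss as we read its definition in the paper listed under "How to cite". Treat the paper as the authoritative source. Notation: $p_i^l$ is the predicted probability of class $l$ at element $i$, $g_i$ is the ground-truth label of element $i$, $M$ is the matrix of distances between classes (entries in $[0, 1]$), and $b$ is the background class. Because the ground truth is one-hot, the Wasserstein distance at each element reduces to a weighted sum of predicted probabilities:

```latex
\begin{aligned}
W^M(p_i, g_i) &= \sum_{l} p_i^{l}\, M(l, g_i) \\
\alpha_l &= M(l, b) \\
TP^M &= \sum_{i} \alpha_{g_i}\left(\alpha_{g_i} - W^M(p_i, g_i)\right) \\
\mathcal{L}_{\mathrm{GWDL}} &= 1 - \frac{2\,TP^M}{2\,TP^M + \sum_{i} W^M(p_i, g_i)}
\end{aligned}
```

Since $\alpha_b = M(b, b) = 0$, background elements add nothing to the generalized true positives. As a sanity check, with only two classes and $M(0, 1) = 1$, the formula reduces to the usual soft Dice loss on the foreground class.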

The GWDL generalizes the Dice loss and the Generalized Dice loss. It can handle hierarchical classes and exploit known relationships between classes, which are encoded in a distance matrix between labels.
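One simple way to turn a class hierarchy into a distance matrix is to group classes under parent nodes. Classes that share a parent are placed closer together than classes that do not. The helper `two_level_distance_matrix` below is hypothetical and not part of this package. The within-group distance of 0.5 is an illustrative choice, not a recommended value.

```python
import numpy as np


def two_level_distance_matrix(groups, within_group_dist=0.5):
    """Hypothetical helper (not part of the package).

    groups[c] is the parent group of class c in a two-level hierarchy.
    Classes in the same group are `within_group_dist` apart, classes in
    different groups are 1 apart (the maximum), and each class is at
    distance 0 from itself.
    """
    groups = np.asarray(groups)
    # same_group[a, b] is True when classes a and b share a parent
    same_group = groups[:, None] == groups[None, :]
    dist = np.where(same_group, within_group_dist, 1.0)
    np.fill_diagonal(dist, 0.0)
    return dist


# Background (class 0) in its own group; classes 1 and 2 share a parent.
dist_mat = two_level_distance_matrix([0, 1, 1])
print(dist_mat)  # same 3x3 matrix as the one used in the Example section
```

The resulting NumPy array can be passed as `dist_matrix` to `GeneralizedWassersteinDiceLoss`, just like the hand-written matrix in the Example.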

Installation

pip install git+https://github.com/LucasFidon/GeneralizedWassersteinDiceLoss.git

Example

import torch
import numpy as np
from generalized_wasserstein_dice_loss.loss import GeneralizedWassersteinDiceLoss

# Example with 3 classes (including the background: label 0).
# The distance between the background (class 0) and the other classes is the maximum, equal to 1.
# The distance between class 1 and class 2 is 0.5.
dist_mat = np.array([
    [0., 1., 1.],
    [1., 0., 0.5],
    [1., 0.5, 0.]
])
wass_loss = GeneralizedWassersteinDiceLoss(dist_matrix=dist_mat)
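# Run on the GPU if one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 1D prediction; shape: (batch size, number of classes, number of elements)
pred = torch.tensor([[[1, 0], [0, 1], [0, 0]]], dtype=torch.float32, device=device)
# 1D ground truth; shape: (batch size, number of elements)
grnd = torch.tensor([[0, 2]], dtype=torch.int64, device=device)
loss = wass_loss(pred, grnd)
print(loss.item())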
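The sketch below shows what the Example computes. It re-implements the loss formula from the cited paper in plain PyTorch and evaluates it on the same toy input. This is not the package's code, and we have not checked that its output matches the package's output. The function name `gwdl_sketch`, the softmax applied to `pred` (treating it as raw scores), the `eps` smoothing term and the mean over the batch are all assumptions.

```python
import numpy as np
import torch
import torch.nn.functional as F


def gwdl_sketch(logits, target, dist_mat, eps=1e-5):
    """Illustrative GWDL following Fidon et al. (2017); not the package's code.

    logits:   (batch, n_classes, n_elements) raw scores; a softmax is applied here.
    target:   (batch, n_elements) integer labels; label 0 is the background.
    dist_mat: (n_classes, n_classes) float tensor of class distances in [0, 1].
    eps:      small smoothing term (an assumption) to avoid dividing by zero.
    """
    probs = F.softmax(logits, dim=1)                      # (B, C, N)
    # m_to_target[b, l, i] = M(l, target[b, i])
    m_to_target = dist_mat[:, target].permute(1, 0, 2)    # (B, C, N)
    # Wasserstein distance to a one-hot ground truth: sum_l p_i^l * M(l, g_i)
    wass = (probs * m_to_target).sum(dim=1)               # (B, N)
    # alpha_{g_i} = M(g_i, background); it is 0 for background elements
    alpha = dist_mat[target, 0]                           # (B, N)
    true_pos = (alpha * (alpha - wass)).sum(dim=1)        # (B,)
    all_error = wass.sum(dim=1)                           # (B,)
    dice = (2 * true_pos + eps) / (2 * true_pos + all_error + eps)
    return (1 - dice).mean()


# Same toy problem as in the Example section, run on the CPU.
dist_mat = np.array([
    [0., 1., 1.],
    [1., 0., 0.5],
    [1., 0.5, 0.],
])
dist = torch.tensor(dist_mat, dtype=torch.float32)
pred = torch.tensor([[[1, 0], [0, 1], [0, 0]]], dtype=torch.float32)
grnd = torch.tensor([[0, 2]], dtype=torch.int64)
loss = gwdl_sketch(pred, grnd, dist)
print(loss.item())  # about 0.48
```

The distance between classes 1 and 2 is 0.5, so at the second element (label 2) the probability placed on class 1 costs half as much as the probability placed on the background. This is how the distance matrix lets the loss penalize "near misses" less than confusions with the background.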

How to cite

If you use the Generalized Wasserstein Dice Loss in your work, please cite:

BibTeX:

@inproceedings{fidon2017generalised,
  title={Generalised {W}asserstein dice score for imbalanced multi-class segmentation using holistic convolutional networks},
  author={Fidon, Lucas and Li, Wenqi and Garcia-Peraza-Herrera, Luis C and Ekanayake, Jinendra and Kitchen, Neil and Ourselin, S{\'e}bastien and Vercauteren, Tom},
  booktitle={International MICCAI Brainlesion Workshop},
  pages={64--76},
  year={2017},
  organization={Springer}
}

Applications of the Generalized Wasserstein Dice loss

For more examples of applications of the generalized Wasserstein Dice loss and how to define the distance matrix, you can look at:

If you find more papers using the generalized Wasserstein Dice loss please let me know :)