
Add Earth Mover's Distance metric

Open zilto opened this issue 2 years ago • 7 comments

🚀 Feature

Add support for the "Earth Mover's Distance" (EMD) or "Wasserstein metric".

Motivation

The EMD is a distance metric between two probability distributions. It gained a lot of traction in GANs, but it is also used in semantic search and ordinal classification/regression.
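For context, in the one-dimensional discrete case (classes with unit spacing), the Wasserstein-1 distance reduces to the L1 distance between the cumulative distribution functions:

W_1(P, Q) = \sum_i \left| F_P(i) - F_Q(i) \right|

When one of the two distributions is a point mass at label y, this simplifies to \sum_k p_k \, |k - y|, which is the quantity the snippet below accumulates per example.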

Pitch and Design decisions

I have already implemented this metric with torchmetrics for my own ordinal regression problem.

(gist: https://gist.github.com/zilto/99a9b99db9345a6877fa4f72dc34207f)

from typing import Any, Literal, Optional

import torch
import torch.nn.functional as F
from torchmetrics import Metric


def ordinal_softmax(x):
    """Convert ordinal logits into per-class probabilities by differencing cumulative probabilities."""
    # Note: the device is hardcoded here; see discussion point 4 below.
    log_probs = F.logsigmoid(x).to("cuda")
    cum_probs = torch.cat(
        (
            torch.ones(x.shape[0], 1, dtype=torch.float32).to("cuda"),
            torch.exp(torch.cumsum(log_probs, dim=1)),
            torch.zeros(x.shape[0], 1, dtype=torch.float32).to("cuda"),
        ),
        dim=1,
    )
    return cum_probs[:, 0:-1] - cum_probs[:, 1:]


def _earth_movers_distance_update(
        logits: torch.Tensor,
        target: torch.Tensor,
        num_classes: int,
    ) -> tuple[torch.Tensor, int]:
    if len(target.shape) == 1:
        target = target.reshape((-1, 1))

    n_obs = logits.shape[0]

    # Absolute distance between each target label and every class index, shape (n_obs, num_classes).
    y_dist = torch.abs(target.repeat(1, num_classes) - torch.arange(0, num_classes).repeat(n_obs, 1).to("cuda"))
    cumulative_probs = ordinal_softmax(logits)

    # Sum over both axes; see the discussion below about keeping per-example values.
    earth_movers_dist = torch.sum(torch.mul(cumulative_probs, y_dist))
    return earth_movers_dist, n_obs


def _earth_movers_distance_reduce(
        earth_movers_dist: torch.Tensor,
        n_obs: torch.Tensor,
        reduction: Optional[Literal["mean", "sum"]] = "mean",
    ) -> torch.Tensor:
    allowed_reduction = ("mean", "sum")
    if reduction not in allowed_reduction:
        raise ValueError(f"Argument `reduction` needs to be one of the following: {allowed_reduction}")

    if reduction == "mean":
        earth_movers_dist /= n_obs
    elif reduction == "sum":
        pass

    return earth_movers_dist.float()

def _earth_movers_distance_compute(
        earth_movers_dist: torch.Tensor,
        n_obs: torch.Tensor,
        reduction: Optional[Literal["mean", "sum"]] = "mean",
    ) -> torch.Tensor:
    return _earth_movers_distance_reduce(earth_movers_dist, n_obs, reduction)


class EarthMoversDistance(Metric):
    higher_is_better = False
    plot_options: dict = {"lower_bound": 0.0, "upper_bound": 1.0, "legend_name": "Label"}

    def __init__(
            self,
            num_classes: int,
            reduction: Optional[Literal["mean", "sum"]],
            compute_on_step: bool = True,
            dist_sync_on_step: bool = False,
            process_group: Optional[Any] = None,
        ):
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
        )
        self.num_classes = num_classes
        self.reduction = reduction

        self.add_state("earth_movers_distance", default=torch.tensor(0, dtype=torch.float32), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0, dtype=torch.int32), dist_reduce_fx="sum")
        
    def update(self, logits: torch.Tensor, target: torch.Tensor):
        """Update state with predictions and targets."""
        earth_movers_distance, n_obs = _earth_movers_distance_update(logits, target, num_classes=self.num_classes)
        earth_movers_distance = _earth_movers_distance_reduce(earth_movers_distance, n_obs, reduction=self.reduction)

        self.earth_movers_distance += earth_movers_distance
        self.total += n_obs

    def compute(self) -> torch.Tensor:
        return _earth_movers_distance_compute(self.earth_movers_distance, self.total, reduction=self.reduction)
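
For illustration, here is a minimal usage sketch of the metric as defined above (batch size and number of classes are arbitrary; as written, the helpers hardcode "cuda", so this assumes a GPU is available):

# Illustrative usage only; shapes and values are made up.
# The helper functions above hardcode .to("cuda"), so the metric is moved to the GPU as well.
metric = EarthMoversDistance(num_classes=5, reduction="mean").to("cuda")

logits = torch.randn(8, 5, device="cuda")           # raw model outputs, shape (n_examples, num_classes)
target = torch.randint(0, 5, (8,), device="cuda")   # ordinal labels, shape (n_examples,)

metric.update(logits, target)
result = metric.compute()  # scalar tensor, reduced according to `reduction`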

I wanted to discuss a few points before working on a PR:

  • In my current implementation, the metric receives the true labels with shape (n_examples, ). However, the metric needs to manipulate a tensor of shape (n_examples, 1). Should the reshaping operation be done inside or outside the metric?
  • In my current implementation, the metric receives the logits from the last layer of the model with shape (n_examples, num_classes), and I apply a cumulative softmax activation on the logits. Should the activation be done inside or outside the metric?
  • As mentioned, the metric receives the logits from the last layer of the model with shape (n_examples, num_classes). I added an explicit num_classes argument for transparency, but it is not strictly necessary and could be inferred from the shape of the tensor. Should this argument be kept?
  • In _earth_movers_distance_update(), I create a new Tensor object, but I couldn't figure out how to write it such that it is automatically created on the right device (cpu, cuda, etc.) when using PyTorch Lightning.
  • The full metric result is a tensor of shape (n_examples, num_classes), which could be informative and useful in some use cases. I don't know the best approach for keeping the "reduction" of the metric flexible. My current implementation sums across both axes and returns n_examples so that the result can be reduced further with a mean in a distributed setting (sum over all nodes / sum of n_examples); see the formula below.
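
In other words, each process keeps a running sum of distances and a running count, and the global mean is recovered at the end:

\text{EMD}_{\text{mean}} = \frac{\sum_k \text{EMD}^{(k)}_{\text{sum}}}{\sum_k n^{(k)}}

where k indexes the distributed processes.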

zilto avatar Feb 21 '23 00:02 zilto

ps: Couldn't figure out how to add the gist

zilto avatar Feb 21 '23 00:02 zilto

Hi! thanks for your contribution!, great first issue!

github-actions[bot] avatar Feb 21 '23 00:02 github-actions[bot]

Hey @zilto, thanks for the issue and sorry for the slow reply.

The code already looks pretty good to me.

Regarding your questions:

  1. I'd do the reshape inside the metric, as this is then consistent with the behavior of PyTorch's cross-entropy.
  2. Same holds for the cumulative softmax. IMO this is not part of the model, but of the metric.
  3. I'd rather keep it explicit and add a check that the passed logits match it.
  4. You could probably just move it to target.device, as you directly need it for calculations with target. Note: ideally you already pass this to the torch.arange(...) call, i.e. torch.arange(..., device=target.device), as this creates the tensor directly on the correct device rather than having to move it later on (see the sketch below).
  5. Ideally the reduce would return a single value. With the class-based interface, we already sync the states before calling compute, so there should be no need to reduce again in a distributed-aware fashion afterwards.
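
For point 4, the relevant line in _earth_movers_distance_update could look roughly like this (a sketch of the suggestion above, not a tested patch):

# Sketch for point 4: create the class-index tensor directly on the same device as `target`
# instead of hardcoding .to("cuda").
y_dist = torch.abs(
    target.repeat(1, num_classes)
    - torch.arange(0, num_classes, device=target.device).repeat(n_obs, 1)
)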

For everything else: I propose you just create a draft PR so that we can discuss this on a more code-oriented basis. Feel free to directly ping me there :)

justusschock avatar Feb 27 '23 13:02 justusschock

Thanks for the feedback @justusschock! I made good progress on cleaning up the code, decoupling components, and generalizing the approach.

On another note, I found out about the scipy implementation of the metric yesterday! The source code is not too long and only depends on numpy. Would a direct port make sense?

I'll be happy to open a PR!

zilto avatar Mar 05 '23 00:03 zilto

@zilto The scipy implementation might be great as a reference metric for testing purposes. However, for the torchmetrics implementation, the paradigm is to have our own implementation unless a license forbids it.
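
For instance, a reference check in the test suite could look something like this (a sketch only; the helper name and the test values are just for illustration):

# Hypothetical reference check against scipy for the 1D case (names and values are illustrative).
import numpy as np
from scipy.stats import wasserstein_distance

def _scipy_reference_emd(p: np.ndarray, q: np.ndarray) -> float:
    """1D EMD between two discrete distributions defined over the same class indices."""
    support = np.arange(len(p))
    return wasserstein_distance(support, support, u_weights=p, v_weights=q)

p = np.array([0.1, 0.2, 0.7])
q = np.array([0.6, 0.3, 0.1])
expected = _scipy_reference_emd(p, q)  # compare against the torchmetrics result in the tests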

stancld avatar Mar 07 '23 09:03 stancld

Hi @zilto, just checking, do you have any updates with the PR? Or do you have any blockers I can help with? I think EMD is fairly complicated to implement from scratch, so having a dependency here might be an option as well.

stancld avatar Apr 09 '23 10:04 stancld

Hi @stancld! Sorry for the delay, life got in the way...

Regarding the implementation, I had the opportunity to work on it, but kept running into design questions. I will open a draft PR during the week to make the conversation more concrete.

My current observations:

  • EMD is conceptually most closely related to KL divergence, since both describe a distance between two distributions. I would create a non-opinionated API with inputs p and q, and add it under the Regression metrics (rough sketch after this list).
  • My initial use case was ordinal regression with an ordinal softmax, but that is not common across applications. It could make sense to handle the activation outside the metric, pass the activation module as an argument, or use an argument to specify the activation.
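
As a rough sketch of what that non-opinionated API could look like (the function name and signature are tentative, not a final design):

# Tentative functional sketch; the name and signature are placeholders.
import torch

def earth_movers_distance(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    """1D EMD between batches of discrete distributions of shape (n_examples, num_classes).

    Both p and q are assumed to already be valid probability distributions over the
    last dimension (i.e., any activation is applied before calling the metric).
    """
    # For one-dimensional distributions, EMD equals the L1 distance between the CDFs.
    cdf_p = torch.cumsum(p, dim=-1)
    cdf_q = torch.cumsum(q, dim=-1)
    return torch.sum(torch.abs(cdf_p - cdf_q), dim=-1)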

zilto avatar Apr 09 '23 14:04 zilto