Add Earth Mover's Distance metric
🚀 Feature
Add support for the "Earth Mover's Distance" (EMD) or "Wasserstein metric".
Motivation
The EMD is a distance metric between two probability distributions. It gained a lot of traction in GANs, but also in semantic search and ordinal classification/regression.
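For discrete distributions over the same ordered support (the classification case above), the EMD reduces to the L1 distance between the cumulative distribution functions. A minimal sketch of that identity; `emd_1d` is an illustrative name, not an existing torchmetrics function:

```python
import torch

def emd_1d(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    # For 1-D distributions on the same ordered support, the EMD equals
    # the L1 distance between the two CDFs.
    return torch.sum(torch.abs(torch.cumsum(p, dim=-1) - torch.cumsum(q, dim=-1)))

# Moving all probability mass by one bin costs exactly 1.
p = torch.tensor([1.0, 0.0, 0.0])
q = torch.tensor([0.0, 1.0, 0.0])
print(emd_1d(p, q))  # tensor(1.)
```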
Pitch and Design decisions
I already reimplemented this metric in torchmetrics for my own ordinal regression problem.
(gist: https://gist.github.com/zilto/99a9b99db9345a6877fa4f72dc34207f)
```python
from typing import Any, Literal, Optional

import torch
import torch.nn.functional as F
from torchmetrics import Metric


def ordinal_softmax(x):
    """Convert ordinal logits into per-class probabilities via cumulative sigmoids."""
    # NOTE: the device is hardcoded here; see question 4 below.
    log_probs = F.logsigmoid(x).to("cuda")
    cum_probs = torch.cat(
        (
            torch.ones(x.shape[0], 1, dtype=torch.float32).to("cuda"),
            torch.exp(torch.cumsum(log_probs, dim=1)),
            torch.zeros(x.shape[0], 1, dtype=torch.float32).to("cuda"),
        ),
        dim=1,
    )
    # Differences of consecutive cumulative probabilities give the class probabilities.
    return cum_probs[:, 0:-1] - cum_probs[:, 1:]


def _earth_movers_distance_update(
    logits: torch.Tensor,
    target: torch.Tensor,
    num_classes: int,
) -> tuple[torch.Tensor, int]:
    if len(target.shape) == 1:
        target = target.reshape((-1, 1))
    n_obs = logits.shape[0]
    # Distance from each class index to the true label, per example.
    y_dist = torch.abs(
        target.repeat(1, num_classes)
        - torch.arange(0, num_classes).repeat(n_obs, 1).to("cuda")
    )
    cumulative_probs = ordinal_softmax(logits)
    # Weighted sum over both axes yields a single scalar for the batch.
    earth_movers_dist = torch.sum(torch.mul(cumulative_probs, y_dist))
    return earth_movers_dist, n_obs


def _earth_movers_distance_reduce(
    earth_movers_dist: torch.Tensor,
    n_obs: torch.Tensor,
    reduction: Optional[Literal["mean", "sum"]] = "mean",
) -> torch.Tensor:
    allowed_reduction = ("mean", "sum")
    if reduction not in allowed_reduction:
        raise ValueError(f"Argument `reduction` needs to be one of the following: {allowed_reduction}")
    if reduction == "mean":
        earth_movers_dist /= n_obs
    elif reduction == "sum":
        pass  # the summed distance is returned as-is
    return earth_movers_dist.float()


def _earth_movers_distance_compute(
    earth_movers_dist: torch.Tensor,
    n_obs: torch.Tensor,
    reduction: Optional[Literal["mean", "sum"]] = "mean",
) -> torch.Tensor:
    return _earth_movers_distance_reduce(earth_movers_dist, n_obs, reduction)


class EarthMoversDistance(Metric):
    higher_is_better = False
    plot_options: dict = {"lower_bound": 0.0, "upper_bound": 1.0, "legend_name": "Label"}

    def __init__(
        self,
        num_classes: int,
        reduction: Optional[Literal["mean", "sum"]],
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
    ):
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
        )
        self.num_classes = num_classes
        self.reduction = reduction
        self.add_state("earth_movers_distance", default=torch.tensor(0, dtype=torch.float32), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0, dtype=torch.int32), dist_reduce_fx="sum")

    def update(self, logits: torch.Tensor, target: torch.Tensor):
        """Update state with predictions and targets."""
        earth_movers_distance, n_obs = _earth_movers_distance_update(logits, target, num_classes=self.num_classes)
        earth_movers_distance = _earth_movers_distance_reduce(earth_movers_distance, n_obs, reduction=self.reduction)
        self.earth_movers_distance += earth_movers_distance
        self.total += n_obs

    def compute(self) -> torch.Tensor:
        return _earth_movers_distance_compute(self.earth_movers_distance, self.total, reduction=self.reduction)
```
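For context, a hedged usage sketch of the class above. It only runs on a machine with CUDA available, because the functional code hardcodes `.to("cuda")` (see question 4 below); shapes and class count are arbitrary:

```python
# Move the metric to CUDA so its accumulator states live on the same
# device as the per-batch results produced by the functional code.
metric = EarthMoversDistance(num_classes=5, reduction="mean").to("cuda")

logits = torch.randn(8, 5, device="cuda")          # raw model outputs, shape (n_examples, num_classes)
target = torch.randint(0, 5, (8,), device="cuda")  # integer labels, shape (n_examples,)

metric.update(logits, target)
print(metric.compute())  # scalar tensor: mean EMD over the accumulated batches
```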
I wanted to discuss a few points before working on a PR:
- In my current implementation, the metric receives the true label with shape `(n_examples,)`. However, the metric needs to manipulate a tensor of shape `(n_examples, 1)`. Should the reshaping operation be done inside or outside the metric?
- In my current implementation, the metric receives the logits from the last layer of the model with shape `(n_examples, num_classes)`, and I apply a cumulative softmax activation on the logits. Should the activation be done inside or outside the metric?
- As mentioned, the metric receives the logits from the last layer of the model with shape `(n_examples, num_classes)`. I added an explicit `num_classes` argument for transparency, but it is not strictly necessary and could be inferred from the shape of the tensor. Should this argument be included?
- In `_earth_movers_distance_update()`, I create a new `Tensor` object, but I couldn't figure out how to write it such that it automatically gets created on the right device (cpu, cuda, etc.) when using PyTorch Lightning.
- The full metric result is an array of shape `(n_examples, num_classes)`, which could be informative and useful in some use cases. I don't know the best approach to keep flexibility in "reducing" the metric. My current implementation sums across both axes and returns `n_examples` to be able to reduce further using the mean in a distributed setting (sum over all nodes / sum of `n_examples`).
PS: Couldn't figure out how to embed the gist.
Hi! Thanks for your contribution, great first issue!
Hey @zilto, thanks for the issue and sorry for the delayed reply.
The code already looks pretty good to me.
Regarding your questions:
1.) I'd do the reshape inside the metric, as this is then consistent with the behavior of PyTorch's cross-entropy.
2.) The same holds for the cumulative softmax activation. IMO this is not part of the model, but of the metric.
3.) I'd rather have it explicit and add a check that the shape of the passed logits matches it.
4.) You could probably just move it to `target.device`, as you directly need it for calculations with `target`. Note: ideally, you add the device directly to the `torch.arange(...)` call, i.e. `torch.arange(..., device=target.device)`, as this creates the tensor on the correct device right away rather than having to move it later on (see the sketch below).
5.) Ideally, the `reduce` would return a single value. With the class-based interface, we already sync the states before calling `compute`, so there should be no need to reduce again in a distributed-aware fashion afterwards.
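A minimal sketch of the fix from point 4, applied to the `torch.arange(...)` call inside `_earth_movers_distance_update`:

```python
# Before: index tensor created on a hardcoded device
# y_dist = torch.abs(target.repeat(1, num_classes) - torch.arange(0, num_classes).repeat(n_obs, 1).to("cuda"))

# After: create the index tensor directly on target's device; no transfer needed
y_dist = torch.abs(
    target.repeat(1, num_classes)
    - torch.arange(0, num_classes, device=target.device).repeat(n_obs, 1)
)
```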
For everything else: I propose you just create a draft PR so that we can actually discuss this on a more code-oriented basis. Feel free to directly ping me there :)
Thanks for the feedback @justusschock! I made good progress on cleaning up the code, decoupling components, and generalizing the approach.
On another note, I found out about the scipy implementation of the metric yesterday! The source code is not too long and only depends on numpy. Would a direct port make sense?
I'll be happy to open a PR!
@zilto The `scipy` implementation might be great as a reference metric for testing purposes. However, for the `torchmetrics` implementation, the paradigm is to have our own implementation unless a license forbids it.
Hi @zilto, just checking: do you have any updates on the PR? Or do you have any blockers I can help with? I think EMD is fairly complicated to implement from scratch, so having a dependency here might be an option as well.
Hi @stancld! Sorry for the delay, life got in the way...
Regarding the implementation, I had the opportunity to work on it, but was hitting different design questions. I will open a draft PR during the week to make the conversation more concrete.
My current observations:
- EMD is conceptually most closely related to the KL divergence, because they both describe a distance between two distributions. I would create a non-opinionated API with inputs `p` and `q` (see the sketch after this list), and add it under the `Regression` metrics.
- My initial use case was ordinal regression with an ordinal softmax, but that is not a commonality across applications. It could make sense to handle the activation outside the metric, to pass the activation module as an argument, or to use an argument to specify the activation.
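To make the first point concrete, a hedged sketch of what such a non-opinionated functional interface could look like (name and signature are a proposal, not an existing torchmetrics API):

```python
import torch

def earth_movers_distance(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    """Proposed 1-D EMD between batches of distributions.

    `p` and `q` are expected to already be valid probability distributions
    of shape (n_examples, num_classes); any activation (softmax, ordinal
    softmax, ...) is applied by the caller.
    """
    # Per-example L1 distance between the two CDFs.
    return torch.sum(torch.abs(torch.cumsum(p, dim=-1) - torch.cumsum(q, dim=-1)), dim=-1)
```

The ordinal use case would then become a thin wrapper on top of this, e.g. `earth_movers_distance(ordinal_softmax(logits), F.one_hot(target, num_classes).float())`.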