torchmetrics

Add Kendall's Rank Correlation

Open amorehead opened this issue 3 years ago • 4 comments

🚀 Feature

A Metric that computes Kendall's Rank Correlation (https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient).

Motivation

This metric would be useful for many regression tasks, especially those where the ranking of the predictions matters more than their exact values.

Pitch

Add a new Metric that computes Kendall's Rank Correlation - specifically its tau values - possibly using SciPy's implementation as a reference (https://web.archive.org/web/20181008171919/https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.kendalltau.html).
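For reference, SciPy's `kendalltau` computes the tie-adjusted tau-b variant by default. Here n_c and n_d are the numbers of concordant and discordant pairs among n observations, and t_i (u_j) is the size of the i-th (j-th) group of tied values in the first (second) variable:

```latex
\tau_B = \frac{n_c - n_d}{\sqrt{(n_0 - n_1)(n_0 - n_2)}}, \qquad
n_0 = \frac{n(n-1)}{2}, \qquad
n_1 = \sum_i \frac{t_i(t_i - 1)}{2}, \qquad
n_2 = \sum_j \frac{u_j(u_j - 1)}{2}
```

Without ties, n_1 = n_2 = 0 and this reduces to tau-a = (n_c - n_d) / n_0.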

Alternatives

One can currently compute this quantity with SciPy. However, it would be desirable to compute it directly in PyTorch, without moving tensors to the CPU and converting them to NumPy arrays.
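A minimal pure-PyTorch sketch of tau-b, assuming 1-D inputs and that enumerating all O(n²) pairs is acceptable. `kendall_tau_b` is a hypothetical helper name, not a torchmetrics API. It lists every pair with `torch.triu_indices` and uses the sign of pairwise differences to classify each pair as concordant, discordant, or tied:

```python
import torch


def kendall_tau_b(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Kendall's tau-b for two 1-D tensors (hypothetical helper, O(n^2) pairs)."""
    if preds.dim() != 1 or preds.shape != target.shape:
        raise ValueError("preds and target must be 1-D tensors of the same length")
    preds = preds.to(torch.float64)
    target = target.to(torch.float64)
    n = preds.numel()
    # Indices (i, j) of every unordered pair with i < j.
    i, j = torch.triu_indices(n, n, offset=1, device=preds.device)
    # sign(x_i - x_j) is 0 for a tie, otherwise +1 or -1.
    dx = torch.sign(preds[i] - preds[j])
    dy = torch.sign(target[i] - target[j])
    # Product is +1 for concordant pairs, -1 for discordant, 0 if tied in either input.
    numerator = (dx * dy).sum()
    # (n0 - n1) and (n0 - n2): number of pairs not tied in preds / not tied in target.
    untied_preds = (dx != 0).sum()
    untied_target = (dy != 0).sum()
    # Yields NaN (0 / 0) when either input is constant; SciPy also returns NaN there.
    return numerator / torch.sqrt((untied_preds * untied_target).to(torch.float64))
```

The pairwise approach is simple and stays on the tensors' device, but its memory grows quadratically with the number of samples. SciPy's implementation is based on a sorting approach (O(n log n)), which a production version would likely want to follow.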

Additional context

Whoever implements this new Metric deserves at least a (virtual) high-five! ✋

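Note that tau ranges from -1 to 1: 1 means the two input variables rank the observations identically, -1 means one ranking is the exact reverse of the other, and 0 means no association. According to this source (https://www.statisticshowto.com/kendalls-tau/), the sign can be discarded when only the strength of the relationship matters. For a metric where higher is better, however, a negative tau means the predictions rank the targets in reverse order, so the sign should be kept.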
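To make the sign convention concrete, here is a small dependency-free sketch of tau-a (the variant without the tie adjustment). `kendall_tau_a` is a hypothetical helper written only for illustration; it assumes equal-length sequences with at least two elements:

```python
from itertools import combinations


def kendall_tau_a(x, y):
    """Kendall's tau-a: (concordant - discordant) / number of pairs."""
    pairs = list(combinations(range(len(x)), 2))
    score = 0
    for i, j in pairs:
        product = (x[i] - x[j]) * (y[i] - y[j])
        # Positive product: concordant pair; negative: discordant; zero: tie (counts 0).
        if product > 0:
            score += 1
        elif product < 0:
            score -= 1
    return score / len(pairs)


truth = [1, 2, 3, 4, 5]
print(kendall_tau_a(truth, [10, 20, 30, 40, 50]))  # 1.0: identical ordering
print(kendall_tau_a(truth, [50, 40, 30, 20, 10]))  # -1.0: reversed ordering
print(abs(kendall_tau_a(truth, [50, 40, 30, 20, 10])))  # 1.0 once the sign is dropped
```

The last line shows why dropping the sign is risky for model evaluation: a perfectly reversed ranking would score as well as a perfect one.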

amorehead, Sep 06 '22 18:09

For those interested in a quick solution until someone implements a native PyTorch approach, the following custom Metric may suffice.

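```python
from typing import Optional

import scipy.stats
import torch
from torchmetrics import Metric


class KendallTau(Metric):
    is_differentiable: Optional[bool] = False
    higher_is_better: Optional[bool] = True
    full_state_update: Optional[bool] = True

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Running sum of per-batch tau values and the number of batches seen.
        self.add_state("total_tau", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        assert preds.shape == target.shape
        # SciPy needs NumPy inputs: detach from the graph, move to CPU, flatten to 1-D.
        # A batch with constant preds or target yields NaN.
        tau, _ = scipy.stats.kendalltau(
            preds.detach().cpu().flatten().numpy(),
            target.detach().cpu().flatten().numpy(),
        )
        self.total_tau += float(tau)
        # Count batches (not elements) so compute() returns the mean per-batch tau.
        self.total += 1

    def compute(self):
        return self.total_tau / self.total
```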

amorehead, Sep 16 '22 20:09

@amorehead It would be a really great addition to torchmetrics. Would you want to try adding it in a PR, or do you not have enough time right now? 🐰

stancld, Oct 05 '22 10:10

Hi, @stancld. My availability is quite limited at the moment. Are you referring to a pure-PyTorch implementation of this metric, or to one that moves the label and prediction tensors to the CPU and then computes the metric with SciPy?

amorehead, Oct 09 '22 17:10

I meant a full PyTorch implementation; I would definitely refrain from using the SciPy implementation. Maybe I'll have some spare time to look at it over the weekend, but I can't promise anything right now :]

stancld, Oct 12 '22 16:10

@amorehead Just FYI, we've prepared a PR adding the metric. It's almost done; we just need to figure out one failing mypy test, and then we'll be ready to merge. We'd be happy if you gave it a try then 🤗

stancld, Oct 27 '22 18:10

@stancld, will do! I greatly appreciate your time and help here, and I'm sure others will as well!

amorehead, Nov 04 '22 16:11