easy-few-shot-learning

cosine_distance_to_prototypes() and l2_distance_to_prototypes() are falsely named

Open • Y-T-G opened this issue 11 months ago • 2 comments

If the model is returning cosine distance:

https://github.com/sicara/easy-few-shot-learning/blob/8422b97155f6edd506e99fd5b83362ee36865f1e/easyfsl/methods/simple_shot.py#L29

Does that mean lower values are better?

Y-T-G, Mar 19 '24 14:03

From the code and docstring of the cosine_distance_to_prototypes() method:

    def cosine_distance_to_prototypes(self, samples) -> Tensor:
        """
        Compute prediction logits from their cosine distance to support set prototypes.
        Args:
            samples: features of the items to classify of shape (n_samples, feature_dimension)
        Returns:
            prediction logits of shape (n_samples, n_classes)
        """
        return (
            nn.functional.normalize(samples, dim=1)
            @ nn.functional.normalize(self.prototypes, dim=1).T
        )

The method doesn't actually return cosine distances: it returns prediction logits equal to the cosine similarity between each sample and each prototype, so higher values are better.
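To make the sign convention concrete, here is a small plain-Python sketch (no PyTorch, so it runs anywhere) that mirrors the normalize-then-matmul computation above on hypothetical 2-D features. It assumes cosine distance is defined as 1 - cosine similarity; under that definition, the class with the highest logit is exactly the class with the lowest distance.

```python
import math

def normalize(v):
    # Scale a vector to unit L2 norm; mirrors nn.functional.normalize along dim=1
    # (without its small eps guard against zero-norm vectors).
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def cosine_logits(samples, prototypes):
    # Dot products of unit vectors are cosine similarities, shape (n_samples, n_classes).
    unit_samples = [normalize(s) for s in samples]
    unit_protos = [normalize(p) for p in prototypes]
    return [[sum(a * b for a, b in zip(s, p)) for p in unit_protos] for s in unit_samples]

def argmax(row):
    return max(range(len(row)), key=row.__getitem__)

def argmin(row):
    return min(range(len(row)), key=row.__getitem__)

# Hypothetical features: one prototype per axis, one sample near each prototype.
prototypes = [[1.0, 0.0], [0.0, 1.0]]
samples = [[2.0, 0.1], [0.2, 3.0]]

logits = cosine_logits(samples, prototypes)
# Assumed definition: cosine distance = 1 - cosine similarity, which reverses the ordering.
distances = [[1.0 - sim for sim in row] for row in logits]

predicted = [argmax(row) for row in logits]
print(predicted)  # [0, 1]
```

The predicted class is picked with `argmax` over the logits, not `argmin`, which is why "higher is better" for this method.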

The same applies to the other available "distance" method, which returns logits equal to the negative Euclidean distance:

    def l2_distance_to_prototypes(self, samples: Tensor) -> Tensor:
        """
        Compute prediction logits from their euclidean distance to support set prototypes.
        Args:
            samples: features of the items to classify of shape (n_samples, feature_dimension)
        Returns:
            prediction logits of shape (n_samples, n_classes)
        """
        return -torch.cdist(samples, self.prototypes)
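A similar plain-Python sketch for the L2 case, again with hypothetical features and without PyTorch: negating the Euclidean distance (as `-torch.cdist(...)` does) turns "smaller distance" into "larger logit". The softmax step below is only an illustration of how such logits are typically consumed; the quoted method itself stops at the negated distances.

```python
import math

def neg_l2_logits(samples, prototypes):
    # Mirrors -torch.cdist(samples, prototypes): negated pairwise Euclidean distances.
    return [[-math.dist(s, p) for p in prototypes] for s in samples]

def softmax(row):
    # Numerically stable softmax: subtract the max logit before exponentiating.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical features: the sample is 1 unit from prototype 0 and 3 units from prototype 1.
prototypes = [[0.0, 0.0], [4.0, 0.0]]
samples = [[1.0, 0.0]]

logits = neg_l2_logits(samples, prototypes)
probs = softmax(logits[0])
print(logits)  # [[-1.0, -3.0]]
print(probs)   # roughly [0.88, 0.12]
```

The nearest prototype gets the largest (least negative) logit and therefore the highest probability.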

Naming these methods cosine_distance_to_prototypes() and l2_distance_to_prototypes() is misleading. I'm labeling this issue as a much-needed enhancement to the library.

ebennequin, Mar 19 '24 16:03

I see. That makes it clear. Thanks.

Y-T-G, Mar 19 '24 17:03