easy-few-shot-learning
cosine_distance_to_prototypes() and l2_distance_to_prototypes() are falsely named
If the model returns cosine distances (see https://github.com/sicara/easy-few-shot-learning/blob/8422b97155f6edd506e99fd5b83362ee36865f1e/easyfsl/methods/simple_shot.py#L29), does that mean lower is better?
From the code and docstring of the cosine_distance_to_prototypes() method:

```python
def cosine_distance_to_prototypes(self, samples) -> Tensor:
    """
    Compute prediction logits from their cosine distance to support set prototypes.
    Args:
        samples: features of the items to classify of shape (n_samples, feature_dimension)
    Returns:
        prediction logits of shape (n_samples, n_classes)
    """
    return (
        nn.functional.normalize(samples, dim=1)
        @ nn.functional.normalize(self.prototypes, dim=1).T
    )
```
The method doesn't actually return cosine distances. It returns prediction logits equal to the cosine similarity, so higher is better.
The same applies to the other available "distance", which actually returns logits equal to the negative Euclidean distance:
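To make the "higher is better" point concrete, here is a minimal sketch in plain Python (no torch, so it runs anywhere) that mirrors what the normalize-then-matmul code computes. The helper names (`normalize`, `cosine_logits`) and the toy prototypes and samples are hypothetical, chosen only for illustration. It shows that the returned values are cosine similarities, that the matching cosine distance would be `1 - similarity`, and that taking the argmax of the similarity picks the same class as taking the argmin of the distance.

```python
import math


def normalize(v):
    """Scale one vector to unit L2 norm (like nn.functional.normalize on a single row)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def cosine_logits(samples, prototypes):
    """Hypothetical stand-in: cosine similarity of each sample to each prototype,
    shape (n_samples, n_classes), i.e. normalize(samples) @ normalize(prototypes).T."""
    unit_protos = [normalize(p) for p in prototypes]
    return [
        [sum(a * b for a, b in zip(normalize(s), p)) for p in unit_protos]
        for s in samples
    ]


def argmax(row):
    return max(range(len(row)), key=lambda i: row[i])


def argmin(row):
    return min(range(len(row)), key=lambda i: row[i])


prototypes = [[1.0, 0.0], [0.0, 1.0]]  # toy prototypes, one per class
samples = [[0.9, 0.1], [0.2, 0.8]]     # toy query features

logits = cosine_logits(samples, prototypes)
# A true cosine *distance* would be 1 - similarity, which ranks classes in reverse.
distances = [[1.0 - x for x in row] for row in logits]

predictions = [argmax(row) for row in logits]  # highest similarity wins
print(predictions)  # prints [0, 1]
```

So with the values these methods return, the predicted class is the one with the largest logit, not the smallest.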
```python
def l2_distance_to_prototypes(self, samples: Tensor) -> Tensor:
    """
    Compute prediction logits from their euclidean distance to support set prototypes.
    Args:
        samples: features of the items to classify of shape (n_samples, feature_dimension)
    Returns:
        prediction logits of shape (n_samples, n_classes)
    """
    return -torch.cdist(samples, self.prototypes)
```
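The same reasoning can be sketched for the Euclidean case. This is a plain-Python illustration, assuming toy data and a hypothetical helper `neg_l2_logits` that mirrors `-torch.cdist(samples, self.prototypes)`. Because of the minus sign, the closest prototype gets the largest logit, so argmax (or a softmax over the row) again selects the nearest class.

```python
import math


def neg_l2_logits(samples, prototypes):
    """Hypothetical stand-in for -torch.cdist(samples, prototypes):
    negative Euclidean distance of each sample to each prototype."""
    return [[-math.dist(s, p) for p in prototypes] for s in samples]


def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


prototypes = [[0.0, 0.0], [3.0, 4.0]]  # toy prototypes, one per class
samples = [[0.5, 0.5], [3.0, 3.0]]     # toy query features

logits = neg_l2_logits(samples, prototypes)
probs = [softmax(row) for row in logits]
# Largest logit = smallest distance = predicted class.
predictions = [max(range(len(r)), key=r.__getitem__) for r in logits]
print(predictions)  # prints [0, 1]
```

Negating the distance is what turns "lower distance is better" into "higher logit is better", which is why these values behave as logits rather than distances.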
Naming these methods cosine_distance_to_prototypes() and l2_distance_to_prototypes() is misleading. I am marking this as a much-needed enhancement to the library.
I see. That makes it clear. Thanks.