MatchingNetworks

Cosine distance calculation problem

Open ouc-lq opened this issue 4 years ago • 2 comments

In the source code, the author computes the cosine similarity as follows:

        sum_support = torch.sum(torch.pow(support_image, 2), 1)
        support_manitude = sum_support.clamp(eps, float("inf")).rsqrt()  # 1 / ||support||
        dot_product = input_image.unsqueeze(1).bmm(support_image.unsqueeze(2)).squeeze()
        cosine_similarity = dot_product * support_manitude  # input norm is never divided out
        similarities.append(cosine_similarity)
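For reference, write $\mathbf{u}$ for the input embedding and $\mathbf{v}$ for a support embedding (symbols introduced here, not in the repo). Since `rsqrt` of the summed squares is $1/\lVert\mathbf{v}\rVert$, multiplying the dot product by the support magnitude alone yields a cosine scaled by the input norm, not the cosine itself:

$$
\cos(\mathbf{u}, \mathbf{v}) = \frac{\mathbf{u}\cdot\mathbf{v}}{\lVert\mathbf{u}\rVert\,\lVert\mathbf{v}\rVert},
\qquad
\frac{\mathbf{u}\cdot\mathbf{v}}{\lVert\mathbf{v}\rVert} = \lVert\mathbf{u}\rVert\,\cos(\mathbf{u}, \mathbf{v}).
$$

The second quantity is not bounded to $[-1, 1]$ and grows with the length of the input embedding.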

But in my opinion, the cosine similarity should be calculated as follows:

        sum_support = torch.sum(torch.pow(support_image, 2), 1)
        support_manitude = sum_support.clamp(eps, float("inf")).rsqrt()
        # also normalize the input embedding: 1 / ||input||
        sum_input = torch.sum(torch.pow(input_image, 2), 1)
        input_manitude = sum_input.clamp(eps, float("inf")).rsqrt()
        dot_product = input_image.unsqueeze(1).bmm(support_image.unsqueeze(2)).squeeze()
        cosine_similarity = dot_product * support_manitude * input_manitude
        similarities.append(cosine_similarity)
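A self-contained sketch comparing the two variants, assuming `input_image` and each `support_image` are `(batch, dim)` tensors. The helper names `support_only` and `fully_normalized` and the `EPS` value are hypothetical. Scaling the input by a constant changes the support-only score but leaves the fully normalized one unchanged, which is how a cosine similarity should behave.

```python
import torch

EPS = 1e-10  # clamp floor playing the role of `eps` above (value assumed)


def support_only(input_image, support_image, eps=EPS):
    # Normalizes only the support embedding: returns ||u|| * cos(u, v).
    support_magnitude = support_image.pow(2).sum(1).clamp(eps, float("inf")).rsqrt()
    # (batch, 1, 1) -> (batch,); flatten() stays safe even when batch == 1
    dot_product = input_image.unsqueeze(1).bmm(support_image.unsqueeze(2)).flatten()
    return dot_product * support_magnitude


def fully_normalized(input_image, support_image, eps=EPS):
    # Normalizes both embeddings: returns cos(u, v), which lies in [-1, 1].
    support_magnitude = support_image.pow(2).sum(1).clamp(eps, float("inf")).rsqrt()
    input_magnitude = input_image.pow(2).sum(1).clamp(eps, float("inf")).rsqrt()
    dot_product = input_image.unsqueeze(1).bmm(support_image.unsqueeze(2)).flatten()
    return dot_product * support_magnitude * input_magnitude


torch.manual_seed(0)
u = torch.randn(4, 8)  # batch of 4 input embeddings
v = torch.randn(4, 8)  # batch of 4 support embeddings

# Doubling the input doubles the support-only score ...
print(torch.allclose(support_only(2 * u, v), 2 * support_only(u, v)))  # True
# ... but leaves the true cosine similarity unchanged.
print(torch.allclose(fully_normalized(2 * u, v), fully_normalized(u, v)))  # True
```

Using `flatten()` instead of a bare `squeeze()` is a small design choice: `squeeze()` would also drop the batch dimension when the batch size is 1.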

Am I right? If not, where is my mistake?

ouc-lq · Apr 11 '21 15:04

@gitabcworld @jacklanchantin

ouc-lq · Apr 11 '21 15:04

We could compute the cosine similarity more succinctly with `torch.nn.functional.cosine_similarity`, which normalizes both arguments along `dim=1` by default:

        import torch.nn.functional as F
        cosine_similarity = F.cosine_similarity(support_image, target_set)
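A short check, assuming `support_set` is a list of `(batch, dim)` tensors iterated over as in the loop above, and treating `input_image` as the tensor the comment calls `target_set`. The shapes and seed are illustrative. It builds one row of similarities per support image with `F.cosine_similarity` and compares the result against the dot product divided by both L2 norms.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, dim, n_support = 4, 8, 5
input_image = torch.randn(batch, dim)
support_set = [torch.randn(batch, dim) for _ in range(n_support)]

# One row of similarities per support image, like the loop in the issue.
similarities = torch.stack(
    [F.cosine_similarity(support_image, input_image, dim=1) for support_image in support_set]
)  # shape: (n_support, batch)

# Manual version: dot product divided by both L2 norms (broadcast over the support set).
support = torch.stack(support_set)  # (n_support, batch, dim)
manual = (support * input_image).sum(-1) / (support.norm(dim=-1) * input_image.norm(dim=-1))

print(similarities.shape)  # torch.Size([5, 4])
print(torch.allclose(similarities, manual, atol=1e-6))  # True
```

The two agree except for how tiny norms are guarded: `F.cosine_similarity` uses its own `eps` (default `1e-8`), while the manual version above has no guard at all.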

502dxceit · Jun 27 '23 08:06