MatchingNetworks
Cosine distance calculation problem
In the source code, the author calculates the cosine distance as follows.
sum_support = torch.sum(torch.pow(support_image, 2), 1)
support_manitude = sum_support.clamp(eps, float("inf")).rsqrt()
dot_product = input_image.unsqueeze(1).bmm(support_image.unsqueeze(2)).squeeze()
cosine_similarity = dot_product * support_manitude
similarities.append(cosine_similarity)
But in my opinion, the correct cosine distance should also divide by the magnitude of the input image, as follows.
sum_support = torch.sum(torch.pow(support_image, 2), 1)
support_manitude = sum_support.clamp(eps, float("inf")).rsqrt()
sum_input = torch.sum(torch.pow(input_image, 2), 1)
input_manitude = sum_input.clamp(eps, float("inf")).rsqrt()
dot_product = input_image.unsqueeze(1).bmm(support_image.unsqueeze(2)).squeeze()
cosine_similarity = dot_product * support_manitude * input_manitude
similarities.append(cosine_similarity)
Am I right? If not, what is my mistake?
@gitabcworld @jacklanchantin
We could calculate the cosine similarity with the following succinct code: cosine_similarity = F.cosine_similarity(support_image, target_set), where F is imported as torch.nn.functional.
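To check this, here is a minimal sketch comparing the manual calculation (with both magnitudes) against F.cosine_similarity. The tensor shapes, the random data, and the eps value are assumptions for illustration only and are not taken from the repository; the variable names follow the snippets above, with the spelling "magnitude".

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
eps = 1e-10  # assumed value, used only to avoid division by zero

# Hypothetical shapes: a batch of 4 embeddings with 64 features each.
input_image = torch.randn(4, 64)
support_image = torch.randn(4, 64)

# Manual version that normalizes by both magnitudes.
sum_support = torch.sum(torch.pow(support_image, 2), 1)
support_magnitude = sum_support.clamp(eps, float("inf")).rsqrt()
sum_input = torch.sum(torch.pow(input_image, 2), 1)
input_magnitude = sum_input.clamp(eps, float("inf")).rsqrt()
dot_product = input_image.unsqueeze(1).bmm(support_image.unsqueeze(2)).squeeze()
manual = dot_product * support_magnitude * input_magnitude

# Built-in version; dim=1 reduces over the feature dimension.
builtin = F.cosine_similarity(input_image, support_image, dim=1)

# Normalizing by the support magnitude only leaves the input norm in the result.
support_only = dot_product * support_magnitude

print(torch.allclose(manual, builtin, atol=1e-5))
print(torch.allclose(support_only, builtin, atol=1e-5))
```

If the input embeddings are not unit length, the support-only version is the cosine similarity scaled by the norm of each input vector. That scaling is the same for every support image compared against a given input, so it changes the values but not their ranking across the support set.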