Prototypical-Networks-for-Few-shot-Learning-PyTorch

Error in accuracy computation

Open · hahmyg opened this issue 5 years ago · 0 comments

Dear author,

Thank you for your carefully written code. I reused some of it and found an error.

Please check line 84 in `prototypical_loss.py`. I think `y_hat` should be squeezed with `squeeze()` before it is compared with `target_inds.squeeze()`.

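Here `y_hat` and `target_inds.squeeze()` look like this:

```python
import torch

y_hat = torch.tensor([[0], [1], [2], [0], [4]])        # shape (5, 1)
target_inds = torch.tensor([[0], [1], [2], [3], [4]])  # example value; target_inds.squeeze() is tensor([0, 1, 2, 3, 4]), shape (5,)
```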

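In this case, `y_hat` (shape (5, 1)) is broadcast against `target_inds.squeeze()` (shape (5,)), so `eq` compares every prediction with every target and returns a 5x5 matrix instead of 5 element-wise results:

```python
y_hat.eq(target_inds.squeeze()).float()
# tensor([[1., 0., 0., 0., 0.],
#         [0., 1., 0., 0., 0.],
#         [0., 0., 1., 0., 0.],
#         [1., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 1.]])
```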

Averaging this matrix gives 5 matches out of 25 entries, so the reported accuracy is 0.2.
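To confirm that broadcasting is what turns the comparison into a matrix, the result shapes can be checked directly. A minimal sketch, assuming PyTorch 1.8 or later (where `torch.broadcast_shapes` is available); the shapes are the ones from the example in this issue:

```python
import torch

# y_hat is (5, 1) and target_inds.squeeze() is (5,): trailing dims align,
# the size-1 dim is stretched to 5, and a leading dim is added to the target
buggy_shape = torch.broadcast_shapes((5, 1), (5,))
print(buggy_shape)  # torch.Size([5, 5])

# After y_hat.squeeze(), both operands are (5,) and the comparison is element-wise
fixed_shape = torch.broadcast_shapes((5,), (5,))
print(fixed_shape)  # torch.Size([5])
```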

The comparison should instead produce `tensor([1., 1., 1., 0., 1.])`, which gives an accuracy of 4/5 = 0.8.
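A minimal sketch of the suggested fix, using the example tensors from this issue. The helper names `accuracy_buggy`, `accuracy_fixed` and `accuracy_flat` are hypothetical and not functions from the repository, and the `target_inds` value is an assumption chosen so that `target_inds.squeeze()` equals `tensor([0, 1, 2, 3, 4])`. `accuracy_flat` shows an alternative to `squeeze()`: flattening both operands with `reshape(-1)` makes the comparison element-wise regardless of which dimensions happen to have size 1.

```python
import torch


def accuracy_buggy(y_hat: torch.Tensor, target_inds: torch.Tensor) -> float:
    # (5, 1) compared with (5,) broadcasts to a (5, 5) matrix
    return y_hat.eq(target_inds.squeeze()).float().mean().item()


def accuracy_fixed(y_hat: torch.Tensor, target_inds: torch.Tensor) -> float:
    # Squeeze y_hat as suggested, so both operands have shape (5,)
    return y_hat.squeeze().eq(target_inds.squeeze()).float().mean().item()


def accuracy_flat(y_hat: torch.Tensor, target_inds: torch.Tensor) -> float:
    # Alternative: flatten both operands to 1-D before comparing
    return y_hat.reshape(-1).eq(target_inds.reshape(-1)).float().mean().item()


# Example values from this issue (target_inds is an assumed value, see above)
y_hat = torch.tensor([[0], [1], [2], [0], [4]])
target_inds = torch.tensor([[0], [1], [2], [3], [4]])

print(accuracy_buggy(y_hat, target_inds))  # approximately 0.2 (float32)
print(accuracy_fixed(y_hat, target_inds))  # approximately 0.8 (float32)
print(accuracy_flat(y_hat, target_inds))   # approximately 0.8 (float32)
```

Because the means are computed in float32, the printed values differ from 0.2 and 0.8 only in the last few digits.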

hahmyg · Jan 02 '20 09:01