Prototypical-Networks-for-Few-shot-Learning-PyTorch
Error in accuracy computation
Dear author,
Thank you for your carefully written code. I reused some of it and found an error.
Please check line 84 in prototypical_loss.py.
I think y_hat should be squeezed with squeeze() before it is compared with target_inds.squeeze().
For example, y_hat has shape (5, 1) while target_inds.squeeze() has shape (5,):
y_hat = torch.tensor([[0],[1],[2],[0],[4]])
target_inds.squeeze() = torch.tensor([0, 1, 2, 3, 4])
Because of broadcasting, comparing them produces a 5-by-5 matrix instead of five element-wise results:
y_hat.eq(target_inds.squeeze()).float()
tensor([[1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1.]])
Taking the mean of this matrix gives an accuracy of 0.2 (5 matches out of 25 entries).
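A minimal reproduction of the broadcast, using only the example values from this issue (a sketch; it needs nothing but PyTorch, and `targets` simply stands in for the already-squeezed `target_inds`):

```python
import torch

# Example values from this issue: y_hat has shape (5, 1),
# the squeezed targets have shape (5,)
y_hat = torch.tensor([[0], [1], [2], [0], [4]])
targets = torch.tensor([0, 1, 2, 3, 4])

# Broadcasting (5, 1) against (5,) compares every prediction with every
# target, giving a (5, 5) matrix instead of 5 element-wise results
matches = y_hat.eq(targets).float()
buggy_acc = matches.mean().item()

print(matches.shape)         # torch.Size([5, 5])
print(round(buggy_acc, 4))   # 0.2
```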
The element-wise comparison should instead be tensor([1., 1., 1., 0., 1.]),
which gives an accuracy of 0.8 (4 of 5 predictions correct).
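A sketch of the suggested fix on the same example values: squeezing y_hat makes both operands shape (5,), so eq() compares element by element. The flattened variant at the end is an alternative I am assuming would also work here, not something taken from the repository.

```python
import torch

# Same example values as above
y_hat = torch.tensor([[0], [1], [2], [0], [4]])
targets = torch.tensor([0, 1, 2, 3, 4])

# Suggested fix: squeeze y_hat so both sides have shape (5,)
fixed = y_hat.squeeze().eq(targets).float()
acc = fixed.mean().item()

# Alternative (assumption): flatten both operands so the comparison
# does not depend on which dimensions happen to have size 1
acc_flat = y_hat.view(-1).eq(targets.view(-1)).float().mean().item()

print(fixed)           # tensor([1., 1., 1., 0., 1.])
print(round(acc, 4))   # 0.8
```

Flattening with view(-1) assumes both tensors hold the same number of elements in the same order; under that assumption it gives the same result as squeezing y_hat.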