Bug in CombinedMarginLoss implementation
Hi @anxiangsir,
Thanks for sharing your work.
I have a question about the forward pass in CombinedMarginLoss, taking sop_vit_b_16.sh as an example. In this case, self.m1 = 1.0, self.m2 = 0.25, and self.m3 = 0.0. But I think that with torch.no_grad(), the gradients won't be propagated correctly, right?
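To make the question concrete, here is a minimal sketch that mimics the m1 = 1.0, m2 = 0.25, m3 = 0.0 branch in two ways: with the margin applied in place under torch.no_grad(), and as a fully differentiable variant without torch.no_grad(). The helper names, the random cosine matrix, and the scale s = 64.0 are hypothetical choices for illustration, not code from the repo; it also assumes the cosines are already strictly inside (-1, 1).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
S, M2 = 64.0, 0.25  # assumed scale s; ArcFace margin m2 (m1 = 1.0, m3 = 0.0)


def margin_under_no_grad(cos, labels):
    # Hypothetical helper that mimics the torch.no_grad() branch.
    logits = cos.clone()  # non-leaf tensor, like the output of the linear layer
    rows = torch.arange(labels.numel())
    with torch.no_grad():
        # These in-place ops change the values but are not recorded by autograd.
        logits.arccos_()            # cos(theta) -> theta
        logits[rows, labels] += M2  # theta_y -> theta_y + m2
        logits.cos_()               # theta -> cos(theta)
    return logits * S


def margin_with_grad(cos, labels):
    # Hypothetical helper: the same forward math, fully differentiable.
    margin = M2 * F.one_hot(labels, num_classes=cos.size(1)).to(cos.dtype)
    return torch.cos(torch.acos(cos) + margin) * S


def logits_and_grad(fn, cos, labels):
    # Returns the scaled logits and d(cross-entropy)/d(cos) for one forward/backward pass.
    cos = cos.detach().clone().requires_grad_()
    logits = fn(cos, labels)
    F.cross_entropy(logits, labels).backward()
    return logits.detach(), cos.grad


cos = torch.empty(4, 10).uniform_(-0.9, 0.9)  # cosines kept away from +/-1
labels = torch.tensor([1, 3, 5, 7])

out_ng, grad_ng = logits_and_grad(margin_under_no_grad, cos, labels)
out_wg, grad_wg = logits_and_grad(margin_with_grad, cos, labels)

print(torch.allclose(out_ng, out_wg, atol=1e-4))    # True: identical forward values
print(torch.allclose(grad_ng, grad_wg, atol=1e-6))  # False: target-class gradients differ
```

In this sketch the forward values and the non-target gradients agree; only the target-class gradient differs, because the no_grad branch back-propagates through cos(theta_y + m2) as if its derivative with respect to cos(theta_y) were 1.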
It also seems that the implementation of CombinedMarginLoss is adapted from the insightface repo, and its previous version (without torch.no_grad()) makes more sense here: https://github.com/deepinsight/insightface/commit/657ae30e41fc53641a50a68694009d0530d9f6b3
Several issues have been raised about the same question: https://github.com/deepinsight/insightface/issues/2218, https://github.com/deepinsight/insightface/issues/2255, https://github.com/deepinsight/insightface/issues/2309
Why do we need torch.no_grad() here?
Here we mainly adopted the implementation from opensphere; we found that it makes ArcFace training more stable with ViT.
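Assuming the opensphere-style approach amounts to adding the margin as a correction computed without gradient, the same effect can be written without in-place ops: the forward pass sees cos(theta_y + m2) on the target class, while the backward pass sees only the plain cosine. This is a sketch of that idea, not the opensphere source; `arcface_logits_detached`, the clamp bounds, and s = 64.0 are hypothetical.

```python
import torch
import torch.nn.functional as F


def arcface_logits_detached(cos, labels, s=64.0, m2=0.25):
    # Hypothetical helper: forward value cos(theta_y + m2), backward of plain cos(theta).
    with torch.no_grad():
        theta = torch.acos(cos.clamp(-1.0 + 1e-5, 1.0 - 1e-5))
        theta = theta + m2 * F.one_hot(labels, num_classes=cos.size(1)).to(cos.dtype)
        delta = torch.cos(theta) - cos  # margin correction; a constant for autograd
    return s * (cos + delta)


torch.manual_seed(0)
cos = torch.empty(4, 10).uniform_(-0.9, 0.9).requires_grad_()
labels = torch.tensor([1, 3, 5, 7])

logits = arcface_logits_detached(cos, labels)
F.cross_entropy(logits, labels).backward()

# With the correction detached, d(loss)/d(cos) is s * (softmax - one_hot) / batch_size.
one_hot = F.one_hot(labels, num_classes=10).float()
expected = 64.0 * (torch.softmax(logits.detach(), dim=1) - one_hot) / labels.numel()
print(torch.allclose(cos.grad, expected, atol=1e-4))  # True
```

Because nothing is modified in place, this variant sidesteps autograd's in-place version checks, and its gradient should match that of the in-place torch.no_grad() branch.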
@anxiangsir Thanks for getting back to me.
But it is not technically correct, right? Gradients are not propagated back through the operations inside torch.no_grad() (e.g., logits.arccos_()).
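The mechanism can be seen in isolation: an in-place op run under torch.no_grad() changes a tensor's values, but autograd does not record it, so in the backward pass it acts like the identity. A minimal sketch with a made-up three-element tensor and margin 0.25:

```python
import torch

x = torch.tensor([0.2, 0.5, 0.8], requires_grad=True)

# Tracked: y = cos(arccos(x) + 0.25), so autograd computes the true derivative.
y_tracked = torch.cos(torch.arccos(x) + 0.25)
(g_tracked,) = torch.autograd.grad(y_tracked.sum(), x)

# Untracked: the same values, computed in place under torch.no_grad().
z = x.clone()
with torch.no_grad():
    z.arccos_()
    z.add_(0.25)
    z.cos_()
(g_untracked,) = torch.autograd.grad(z.sum(), x)

print(torch.allclose(y_tracked, z))  # True: identical forward values
print(g_untracked)                   # tensor([1., 1., 1.]): arccos_/add_/cos_ were skipped
print(g_tracked)                     # the actual derivative, not all ones
```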
I also ran a comparison experiment (w/ torch.no_grad() vs. w/o torch.no_grad()) on the SOP dataset using an A100 GPU, and the version without torch.no_grad() actually performed better.
Is there any theory or math that supports adding torch.no_grad()? This has confused me for a while. Thanks.
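For reference, the two versions differ in a single derivative. Writing x = cos(theta_y) for the target cosine and m for m2, the fully differentiable version back-propagates through f(x) = cos(arccos x + m), whereas the torch.no_grad() version uses 1 in its place. This is offered as one possible reading of the stability observation, not as a rationale given in the thread:

```latex
x = \cos\theta_y, \quad \theta_y \in (0, \pi), \qquad f(x) = \cos(\arccos x + m)

\frac{df}{dx} = \frac{\sin(\theta_y + m)}{\sin\theta_y} = \cos m + \sin m \, \frac{x}{\sqrt{1 - x^2}}
```

For m > 0 this factor is unbounded as x approaches 1 or -1 (theta_y approaching 0 or pi), while the torch.no_grad() version keeps it at exactly 1. The latter is bounded, but it is not the true gradient of the ArcFace logit.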