Deep_Hierarchical_Classification
Gradient unable to backprop if we use argmax or torch.where
Hi,
Correct me if I'm wrong, but in the code snippet that calculates D_l for the dependency loss
(https://github.com/Ugenteraan/Deep_Hierarchical_Classification/blob/e4f20ae51a2daabfc1c01f6fdab778ef31cc7617/model/hierarchical_loss.py#L65), argmax is non-differentiable, so the gradient of dloss cannot be propagated back to the predictions variables, and consequently not to the parameters of the neural net. That means the model cannot learn anything from the dloss penalty. I ran this loss on my NLP project and the parameters were updated in exactly the same way regardless of the value of beta, which led me to this theory. Could you help me check this?
Also, in the check_hierarchy function a new FloatTensor is constructed, so I don't think it will be registered in the computational graph for backprop either (see the snippet below).
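For reference, here is a minimal sketch illustrating both points. The tensor names here are made up for the example and are not taken from the repository's code:

```python
import torch

# Toy 'predictions' coming out of some layer, so they are part of a graph.
predictions = torch.randn(4, 5, requires_grad=True)
probs = torch.softmax(predictions, dim=1)

# argmax returns integer indices with no grad_fn: the graph stops here,
# so any loss built on top of it sends no gradient back to `predictions`.
hard_pred = torch.argmax(probs, dim=1)
print(hard_pred.requires_grad, hard_pred.grad_fn)  # False, None

# Likewise, a freshly constructed FloatTensor is a new leaf tensor that is
# not connected to `predictions`, even if its values are derived from it.
detached_copy = torch.FloatTensor(hard_pred.size(0)).zero_()
print(detached_copy.requires_grad, detached_copy.grad_fn)  # False, None
```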
+1, I am facing the same issue! I tried to figure out what was happening, but clearly we're getting wrong argmax predictions for prev_lvl_pred.
Hey @anhquan0412 and @VipanchiRKatthula, apologies for the late response. After having another look, I think you are right. Thanks for pointing it out. I can't work on fixing it right now as I'm swamped with other work; perhaps sometime soon. In the meantime, maybe you could look into this:
https://discuss.pytorch.org/t/differentiable-argmax/33020
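One possible direction from that thread is a "soft argmax": a softmax with a low temperature that approximates a one-hot selection while keeping the graph intact. The function name, temperature value, and usage below are assumptions for illustration, not code from this repository:

```python
import torch
import torch.nn.functional as F

def soft_argmax(logits, temperature=0.1):
    """Differentiable stand-in for argmax: a low temperature pushes the
    softmax towards a one-hot distribution while keeping gradient flow."""
    weights = F.softmax(logits / temperature, dim=-1)
    # Expected index under the (near one-hot) distribution.
    indices = torch.arange(logits.size(-1), dtype=logits.dtype, device=logits.device)
    return (weights * indices).sum(dim=-1)

logits = torch.randn(4, 5, requires_grad=True)
soft_pred = soft_argmax(logits)
soft_pred.sum().backward()       # gradient reaches `logits`
print(logits.grad is not None)   # True
```

The trade-off is that the result is a real-valued approximation of the index rather than an exact integer, so the dependency-loss comparison would have to be rewritten in terms of these soft predictions (or a straight-through estimator) rather than hard labels.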