Face_Pytorch

How are you using ArcMarginProduct.py as a loss?

Open jaideep11061982 opened this issue 6 years ago • 5 comments

Hi, how are you using ArcFace as the loss? I could only see cross-entropy loss in your implementation.

jaideep11061982 avatar Mar 09 '19 14:03 jaideep11061982

ArcFace is implemented in margin/ArcMarginProduct.py.

wujiyang avatar Mar 10 '19 02:03 wujiyang

Thanks. What is the thought process behind passing the output of the above to CrossEntropy()?

output = margin(raw_logits, label)
total_loss = criterion(output, label)

jaideep11061982 avatar Mar 10 '19 17:03 jaideep11061982

Well, I think you haven't fully understood the principle of margin-based algorithms (SphereFace, CosFace, ArcFace). They all use cross-entropy as the loss function. The difference between them lies in the softmax operation: each algorithm applies a different margin to the final classification vector before it is passed to the loss.
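To make that concrete, here is a minimal sketch of an additive angular margin followed by plain cross-entropy. It is not the code from margin/ArcMarginProduct.py: the function name arcface_logits, the tensor sizes, and the defaults s=32.0 and m=0.5 are assumptions chosen for illustration, and the sketch omits the guard that fuller implementations add for angles where theta + m exceeds pi.

```python
import math

import torch
import torch.nn.functional as F


def arcface_logits(features, weight, labels, s=32.0, m=0.5):
    """Illustrative margin step: returns logits, not a loss."""
    # Cosine of the angle between each L2-normalised feature and class weight.
    cosine = F.linear(F.normalize(features), F.normalize(weight))
    sine = torch.sqrt((1.0 - cosine.pow(2)).clamp(min=1e-12))
    # cos(theta + m) via the angle-addition identity.
    phi = cosine * math.cos(m) - sine * math.sin(m)
    # Apply the margin only to the ground-truth class of each sample.
    target = F.one_hot(labels, num_classes=weight.size(0)).bool()
    return s * torch.where(target, phi, cosine)


torch.manual_seed(0)
features = torch.randn(4, 8)        # hypothetical batch of 4 embeddings
weight = torch.randn(5, 8)          # hypothetical 5-class weight matrix
labels = torch.randint(0, 5, (4,))

output = arcface_logits(features, weight, labels)
loss = F.cross_entropy(output, labels)  # the loss itself is ordinary cross-entropy
print(output.shape, loss.item())
```

The margin only changes the logits that go into the softmax, which is why the training loop still calls a standard cross-entropy criterion on the result.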

wujiyang avatar Mar 11 '19 00:03 wujiyang

In the ArcFace margin you randomly initialise self.weights. How would self.weights get updated on each batch, given that they are not part of your net (ResNet, MobileNet, etc.)? As far as I can tell, every iteration only uses the random weights you got the first time, during init.

for data in trainloader:
    img, label = data[0].to(device), data[1].to(device)
    optimizer_ft.zero_grad()

    raw_logits = net(img)
    output = margin(raw_logits, label)
    total_loss = criterion(output, label)
    total_loss.backward()
  1. Secondly, shouldn't BatchNorm remove the need to normalise the in_features for the last layer?
  2. How do you decide on hyperparameters such as s, th, m, etc.?
  3. Does it take many iterations to converge? I started with a loss of 15 and it is converging very slowly.

jaideep11061982 avatar Mar 12 '19 14:03 jaideep11061982

If you want the ArcFace weights to be updated after each backpropagation step, you have to pass them to your optimizer when you initialise it. Something like optimizer = optim.Adam(list(model.parameters()) + list(arcface.parameters()), lr=your_learning_rate)
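A small self-contained sketch of that point, assuming a toy setup rather than the repository's actual training script: MarginHead, train_one_step, the layer sizes, and the learning rate are all hypothetical names and values. It compares an optimizer built from the backbone parameters only with one built from both parameter lists.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class MarginHead(nn.Module):
    """Hypothetical stand-in for an ArcFace-style head with a learnable weight."""

    def __init__(self, in_features, num_classes, s=32.0, m=0.5):
        super().__init__()
        self.s, self.m = s, m
        self.weight = nn.Parameter(torch.empty(num_classes, in_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x, label):
        cosine = F.linear(F.normalize(x), F.normalize(self.weight))
        sine = torch.sqrt((1.0 - cosine.pow(2)).clamp(min=1e-12))
        phi = cosine * math.cos(self.m) - sine * math.sin(self.m)
        target = F.one_hot(label, cosine.size(1)).bool()
        return self.s * torch.where(target, phi, cosine)


def train_one_step(optimizer, backbone, head, x, y):
    optimizer.zero_grad()
    loss = F.cross_entropy(head(backbone(x), y), y)
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
x = torch.randn(8, 16)
y = torch.randint(0, 5, (8,))
backbone = nn.Linear(16, 8)   # toy replacement for ResNet / MobileNet
head = MarginHead(8, 5)
before = head.weight.detach().clone()

# Optimizer that only knows about the backbone: the head never moves.
opt_backbone_only = optim.Adam(backbone.parameters(), lr=1e-3)
train_one_step(opt_backbone_only, backbone, head, x, y)
frozen = torch.equal(before, head.weight)

# Optimizer built from both parameter lists: the head is trained too.
opt_both = optim.Adam(
    list(backbone.parameters()) + list(head.parameters()), lr=1e-3
)
train_one_step(opt_both, backbone, head, x, y)
updated = not torch.equal(before, head.weight)

print("head unchanged with backbone-only optimizer:", frozen)
print("head changed with combined optimizer:", updated)
```

In this sketch the head weight should stay at its initial values under the first optimizer and move under the second, since an optimizer only updates the parameters it was given.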

wfcb-85 avatar Jun 26 '19 21:06 wfcb-85