Focal-Loss-Pytorch
Focal-Loss-Pytorch copied to clipboard
代码bug
self.alpha = self.alpha.gather(0,labels.view(-1))
因为上一阶段self.alpha被赋值后self.alpha数值变了。gather后的结果就会有问题。 另外如果这样写,样本维度变化,如果后一批样本比前一批样本维度大,会报错 :Invalid index in gather 建议改成 alpha = self.alpha.gather(0,labels.view(-1)) ... loss = torch.mul(alpha, loss.t())