Focal-Loss-Pytorch icon indicating copy to clipboard operation
Focal-Loss-Pytorch copied to clipboard

代码bug

Open susoooon opened this issue 5 years ago • 0 comments

self.alpha = self.alpha.gather(0,labels.view(-1))

因为上一阶段self.alpha被赋值后self.alpha数值变了。gather后的结果就会有问题。 另外如果这样写,样本维度变化,如果后一批样本比前一批样本维度大,会报错 :Invalid index in gather 建议改成 alpha = self.alpha.gather(0,labels.view(-1)) ... loss = torch.mul(alpha, loss.t())

susoooon avatar Aug 26 '20 02:08 susoooon