mpl.pytorch
Loss goes to 0
Hi,
First of all, thanks for publishing the source code. I want to try this new method on a segmentation problem I have, to avoid having to weight the classes by hand.
When I try it, the loss approaches 0 very fast, going down to less than 2e-7 in a single epoch. I use it as follows:
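import torch.nn.functional as F
import mpl  # mpl.py from the mpl.pytorch repository (import path may differ)

max_pooling_loss = mpl.MaxPoolingLoss(ratio=0.3, p=1.7, reduce=True)

def mask_loss(pred, targ):
    # Unreduced per-pixel cross-entropy, then the max-pooling loss over the pixel losses
    return max_pooling_loss(F.cross_entropy(pred, targ, reduce=False))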
When I use plain cross-entropy, the model trains fine. Any ideas on what might be causing this?
Cheers, Johannes
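A loss that drops below 2e-7 within one epoch can be checked by looking at a single batch: the range of the per-pixel loss and the size of the gradients it produces. A minimal sketch, not part of mpl.pytorch; the helper `loss_and_grad_stats`, the 1x1-conv toy model and the tensor sizes are hypothetical, and it uses `reduction='none'` (newer PyTorch) where this thread uses `reduce=False`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def loss_and_grad_stats(model, per_pixel_loss_fn, images, targets, reduce=torch.mean):
    """Hypothetical helper: one forward/backward pass, then loss range and gradient norm.

    per_pixel_loss_fn must return an unreduced loss map of shape (N, H, W).
    """
    model.zero_grad()
    loss_map = per_pixel_loss_fn(model(images), targets)
    total = reduce(loss_map)
    total.backward()
    # Global L2 norm over all parameter gradients; near zero means almost no learning signal
    norms = [p.grad.norm() for p in model.parameters() if p.grad is not None]
    grad_norm = torch.norm(torch.stack(norms)) if norms else torch.tensor(0.0)
    return {
        "loss_min": loss_map.min().item(),
        "loss_max": loss_map.max().item(),
        "loss_total": total.item(),
        "grad_norm": grad_norm.item(),
    }


def ce_map(logits, t):
    # Per-pixel cross-entropy on raw logits, no reduction
    return F.cross_entropy(logits, t, reduction='none')


# Toy usage: a 1x1 conv standing in for a 5-class segmentation network
torch.manual_seed(0)
model = nn.Conv2d(3, 5, kernel_size=1)
images = torch.randn(2, 3, 8, 8)
targets = torch.randint(0, 5, (2, 8, 8))
stats = loss_and_grad_stats(model, ce_map, images, targets)
print(stats)
```

Running it once with plain cross-entropy and once with the max-pooling variant gives two sets of numbers to compare. Assuming `MaxPoolingLoss(reduce=False)` returns a per-pixel map (as the visualization snippet in this thread suggests), that variant could be passed as `per_pixel_loss_fn` with `reduce=torch.sum`.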
Hi,
It works for me with the latest version of PyTorch. Which version of PyTorch do you use?
Best regards, BeS
Hi,
I use PyTorch 0.4 and Python 3.6. On PyTorch 0.3.1 the code doesn't work because of the breaking Variable/Tensor changes in 0.4.
Is it correct to first apply a per-pixel loss function (without reduction) and then apply the max-pooling loss?
Cheers, Johannes
This code works fine for me:
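import torch.nn as nn
import mpl  # mpl.py from the mpl.pytorch repository (import path may differ)

class Loss:
    def __init__(self):
        # NLLLoss expects log-probabilities, i.e. log_softmax outputs
        self.criterion = nn.NLLLoss(reduce=False)
        self.mpl = mpl.MaxPoolingLoss(ratio=0.3, p=2.0, reduce=True)

    def __call__(self, pred, gt):
        loss = self.criterion(pred, gt)  # per-pixel loss map
        loss = self.mpl(loss)
        return loss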
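The class above feeds `nn.NLLLoss`, which expects log-probabilities, while the earlier `mask_loss` feeds raw scores to `F.cross_entropy`. A minimal sketch with illustrative toy sizes, checking that both give the same per-pixel loss map once `log_softmax` is applied; it uses `reduction='none'`, the newer spelling of the `reduce=False` keyword used in this thread.

```python
import torch
import torch.nn.functional as F

# Toy segmentation batch: N=2 images, C=5 classes, 4x4 pixels (sizes are illustrative)
torch.manual_seed(0)
logits = torch.randn(2, 5, 4, 4)          # raw network output, no softmax applied
target = torch.randint(0, 5, (2, 4, 4))   # per-pixel class indices

# Per-pixel cross-entropy on raw logits
ce = F.cross_entropy(logits, target, reduction='none')

# NLL loss expects log-probabilities, so apply log_softmax over the class dimension first
nll = F.nll_loss(F.log_softmax(logits, dim=1), target, reduction='none')

print(ce.shape)                 # torch.Size([2, 4, 4]): one value per pixel
print(torch.allclose(ce, nll))  # True
```

If raw scores are passed to `nn.NLLLoss` without `log_softmax`, each per-pixel value is just the negated score of the target class, which is not bounded below.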
Hmm, still can't make it work...
I tried a few more values for p and ratio, but it didn't change anything.
Any ideas on how I could try to debug this? Intermediate values or gradients or something?
I only have 5 classes at the moment, and the background dominates by far (about 86% of pixels, averaged over the training set). Could that be why it doesn't work for me?
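To put a number on that imbalance, the per-class pixel fractions can be computed directly from the target masks. A small sketch; `class_pixel_fractions` is a hypothetical helper, and labels are assumed to be non-negative class indices (an ignore label such as 255 would have to be filtered out first).

```python
import torch


def class_pixel_fractions(targets, num_classes):
    """Hypothetical helper: fraction of pixels per class over a batch (or dataset) of masks."""
    counts = torch.bincount(targets.flatten(), minlength=num_classes).float()
    return counts / counts.sum()


# Toy example: 5 classes in one 10x10 mask, background (class 0) dominating
masks = torch.zeros(1, 10, 10, dtype=torch.long)
masks[0, :2, :] = 1   # 20 pixels of class 1
masks[0, 2, :5] = 4   # 5 pixels of class 4
fractions = class_pixel_fractions(masks, num_classes=5)
print(fractions)      # background 0.75, class 1 0.20, class 4 0.05
```

Accumulating the `bincount` results over every mask in the training set gives the dataset-wide figure rather than a per-batch one.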
You can try visualizing the loss before and after mpl to check that it works correctly.
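import cv2
import torch.nn as nn
import mpl  # mpl.py from the mpl.pytorch repository (import path may differ)

class Loss:
    def __init__(self):
        self.criterion = nn.NLLLoss(reduce=False)
        # reduce=False keeps the per-pixel map so it can be displayed
        self.mpl = mpl.MaxPoolingLoss(ratio=0.3, p=2.0, reduce=False)

    def __call__(self, pred, gt):
        loss = self.criterion(pred, gt)
        # Per-pixel loss of the first image, scaled to [0, 1] for display
        cs_vis = loss[0].detach().cpu().numpy()
        _, max_v, _, _ = cv2.minMaxLoc(cs_vis)
        cv2.imshow("cs_vis", cs_vis / max_v)
        loss = self.mpl(loss)
        # Same image after the max-pooling loss
        mpl_vis = loss[0].detach().cpu().numpy()
        _, max_v, _, _ = cv2.minMaxLoc(mpl_vis)
        cv2.imshow("mpl_vis", mpl_vis / max_v)
        cv2.waitKey(0)  # blocks until a key is pressed
        return loss.sum()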
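`cv2.imshow` needs a display, and `cv2.waitKey(0)` pauses on every call. On a headless training machine, the same before/after comparison can be done numerically, or by writing the maps to disk. A minimal sketch; `loss_map_summary` and `to_uint8_image` are hypothetical helpers that take one (H, W) map such as `loss[0]` from the class above.

```python
import numpy as np
import torch


def loss_map_summary(loss_map):
    """Hypothetical helper: numeric summary of one per-pixel loss map of shape (H, W)."""
    m = loss_map.detach().cpu().numpy().astype(np.float64)
    return {
        "min": float(m.min()),
        "max": float(m.max()),
        "mean": float(m.mean()),
        # Share of pixels that still contribute a non-zero loss
        "nonzero_fraction": float((m > 0).mean()),
    }


def to_uint8_image(loss_map):
    """Hypothetical helper: scale a loss map to 0..255 (e.g. for cv2.imwrite)."""
    m = loss_map.detach().cpu().numpy().astype(np.float64)
    max_v = m.max()
    if max_v > 0:
        m = m / max_v
    return (m * 255).round().astype(np.uint8)


# Toy 2x2 loss map
toy = torch.tensor([[0.0, 1.0], [2.0, 4.0]])
summary = loss_map_summary(toy)
img = to_uint8_image(toy)
print(summary)
print(img)
```

Calling `loss_map_summary` on the map before and after `self.mpl(loss)` shows how the max-pooling loss changes the range and the share of contributing pixels; `cv2.imwrite("mpl_vis.png", to_uint8_image(loss[0]))` saves the map for later inspection.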