pytorch-unet

Can Dice and BCE loss work on a multi-class task?

Status: Open. Dr-Cube opened this issue 5 years ago • 5 comments

Thanks for the great implementation code.

I am confused about the loss function. As far as I can see, Dice and BCE are both used for binary tasks. Can they work well on a multi-class task? From your code I can see that the losses work OK, but what about a bigger dataset?

I tried F.cross_entropy(), but it raises this: RuntimeError: 1only batches of spatial targets supported (non-empty 3D tensors) but got targets of size: : [36, 4, 224, 224]. Could you please tell me what's wrong? Thanks!

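import torch
import torch.nn.functional as F

# dice_loss(pred, target) is the soft Dice loss defined elsewhere in this repo.

def calc_loss(pred, target, metrics, bce_weight=0.5):
    # pred: raw logits [N, C, H, W]; target: one-hot float mask [N, C, H, W]
    bce = F.binary_cross_entropy_with_logits(pred, target)

    # F.cross_entropy wants class indices [N, H, W], not a one-hot mask;
    # argmax over the channel dim gives int64 indices on target's device.
    ce = F.cross_entropy(pred, target.argmax(dim=1))  # needs raw logits

    pred = torch.sigmoid(pred)  # Dice is computed on probabilities
    dice = dice_loss(pred, target)

    # For single-label multi-class, use ce instead of bce (and softmax instead of sigmoid).
    loss = bce * bce_weight + dice * (1 - bce_weight)

    # .item() converts each 0-dim loss tensor to a Python float
    metrics['bce'] += bce.item() * target.size(0)
    metrics['dice'] += dice.item() * target.size(0)
    metrics['loss'] += loss.item() * target.size(0)

    return loss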
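The error above comes from passing the 4D one-hot mask to F.cross_entropy, which (for integer targets) expects one class index per pixel. A minimal sketch of the conversion, assuming the target is a one-hot mask of shape [N, C, H, W] as the error message suggests; the helper name `one_hot_to_index` and the toy sizes are hypothetical:

```python
import torch
import torch.nn.functional as F


def one_hot_to_index(target: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: one-hot mask [N, C, H, W] -> class indices [N, H, W]."""
    # argmax over the channel dim returns an int64 tensor on target's device
    return target.argmax(dim=1)


torch.manual_seed(0)
N, C, H, W = 2, 4, 8, 8
logits = torch.randn(N, C, H, W)          # raw network output, no sigmoid
labels = torch.randint(0, C, (N, H, W))   # ground-truth class of each pixel
# Build a one-hot float mask shaped like the target in the issue
target = F.one_hot(labels, num_classes=C).permute(0, 3, 1, 2).float()

target_idx = one_hot_to_index(target)
ce = F.cross_entropy(logits, target_idx)  # [N, C, H, W] logits + [N, H, W] indices
print(target_idx.shape, target_idx.dtype)  # torch.Size([2, 8, 8]) torch.int64
```

Because each pixel of a one-hot mask has exactly one 1, argmax recovers the original labels without ties.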

Dr-Cube, Dec 30 '19

You're right. The current code uses BCE, so each pixel can belong to several classes at once, i.e. it is a multi-label setup.
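A small sketch of why BCE gives multi-label behaviour: the sigmoid scores every channel independently, while the softmax used inside cross-entropy turns each pixel's channels into a single distribution. The toy tensor sizes are arbitrary:

```python
import torch

torch.manual_seed(0)
logits = torch.randn(1, 4, 2, 2)  # [N, C, H, W]: 4 classes, 2x2 image

# BCE path: sigmoid treats each channel independently, so a pixel's class
# scores need not sum to 1 (several classes can be "on" at once)
sig = torch.sigmoid(logits)
# CE path: softmax normalises across the channel dim, one distribution per pixel
soft = torch.softmax(logits, dim=1)

print(sig.sum(dim=1))   # arbitrary values between 0 and 4
print(soft.sum(dim=1))  # all ones, up to float rounding
```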

To assign exactly one class to each pixel, i.e. multi-class, you can use CE instead. I think you need to reshape/view the tensors to 2D.
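As far as I know, F.cross_entropy also accepts spatial inputs directly ([N, C, H, W] logits with [N, H, W] index targets), so the reshape to 2D is optional. This sketch, with arbitrary toy sizes, checks that both forms give the same mean loss:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C, H, W = 2, 3, 4, 5
logits = torch.randn(N, C, H, W)         # raw scores, no sigmoid/softmax applied
target = torch.randint(0, C, (N, H, W))  # one class index per pixel

# Spatial form: [N, C, H, W] logits vs [N, H, W] targets
ce_4d = F.cross_entropy(logits, target)

# "Reshape to 2D" form: move classes last, then flatten pixels into rows
logits_2d = logits.permute(0, 2, 3, 1).reshape(-1, C)  # [N*H*W, C]
target_1d = target.reshape(-1)                          # [N*H*W]
ce_2d = F.cross_entropy(logits_2d, target_1d)
```

Both versions average over the same N*H*W pixels, so they agree up to floating-point rounding.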

usuyama, May 19 '20

Why are the metrics multiplied by the batch size, added cumulatively and then divided by the total number of samples during printing?

It seems like this will print a scaled version of the average metric value (from print_metrics), depending on the batch size. Please correct me if I'm wrong.

ckolluru, May 24 '20

Hi, I am having a problem dealing with a multi-class task where dimensions are like these:

MASK TARGET: torch.Size([4, 1, 600, 900, 3])
OUTPUT: torch.Size([4, 5, 600, 900])

@ckolluru Have you already written a loss function for the multi-class case?
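The mask above looks like a colour-coded RGB image (last dim 3) with an extra singleton dim, while F.cross_entropy with 5-channel output wants an index mask of shape [4, 600, 900]. A minimal sketch of the conversion; the palette colours and the helper name `rgb_mask_to_index` are hypothetical placeholders, since the comment does not say which colour encodes which class:

```python
import torch
import torch.nn.functional as F

# HYPOTHETICAL palette: replace with the real colour of each of the 5 classes
PALETTE = torch.tensor([
    [0, 0, 0],      # class 0
    [255, 0, 0],    # class 1
    [0, 255, 0],    # class 2
    [0, 0, 255],    # class 3
    [255, 255, 0],  # class 4
], dtype=torch.uint8)


def rgb_mask_to_index(mask: torch.Tensor, palette: torch.Tensor = PALETTE) -> torch.Tensor:
    """Map an RGB mask [N, 1, H, W, 3] to class indices [N, H, W] (int64).

    Pixels whose colour is not in the palette get -100, which is
    F.cross_entropy's default ignore_index, so they do not affect the loss.
    """
    mask = mask.squeeze(1)  # [N, H, W, 3]
    index = torch.full(mask.shape[:-1], -100, dtype=torch.long, device=mask.device)
    for cls, colour in enumerate(palette):
        # a pixel matches when all three channels equal the palette colour
        match = (mask == colour.to(mask.device)).all(dim=-1)
        index[match] = cls
    return index


# Small stand-in for the issue's [4, 1, 600, 900, 3] mask
mask = torch.zeros(4, 1, 6, 9, 3, dtype=torch.uint8)
mask[..., 0] = 255          # every pixel red, i.e. class 1 in the placeholder palette
mask[0, 0, 0, 0, 2] = 7     # one pixel becomes [255, 0, 7], not in the palette
idx = rgb_mask_to_index(mask)
logits = torch.randn(4, 5, 6, 9)      # stand-in for the [4, 5, 600, 900] output
loss = F.cross_entropy(logits, idx)   # the -100 pixel is ignored
```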

sarmientoj24, Jun 12 '20

Quoting @ckolluru: "Why are the metrics multiplied by the batch size, added cumulatively and then divided by the total number of samples during printing?"

The reason is that some batches (typically the last one) may contain fewer examples than the others. Summing metric * batch_size over all batches and dividing by total_samples gives the per-sample average of the metric over the complete epoch, whereas averaging the per-batch values would give a small final batch as much weight as a full one. The PyTorch tutorial "Training a Classifier" uses a similar running-statistics strategy; section "4. Train the network" is the relevant one, where the statistics are printed.
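A worked example with made-up numbers, assuming each batch value is itself a mean over that batch's samples (as with the default reduction='mean'): weighting by batch size recovers the plain per-sample average, while averaging batch means over-weights the small last batch.

```python
# Hypothetical per-batch mean losses: 3 full batches of 4 samples,
# then a final, smaller batch of 1 sample.
batch_means = [0.50, 0.40, 0.30, 0.90]
batch_sizes = [4, 4, 4, 1]

# Strategy used in the repo: weight each batch mean by its size,
# accumulate, then divide by the total number of samples.
weighted = sum(m * n for m, n in zip(batch_means, batch_sizes)) / sum(batch_sizes)

# Naive alternative: average the batch means, giving the 1-sample
# batch the same weight as a 4-sample one.
naive = sum(batch_means) / len(batch_means)

print(round(weighted, 4))  # 0.4385
print(round(naive, 4))     # 0.525
```

The weighted value is 5.7 / 13, exactly the mean over all 13 samples, so it is not scaled by the batch size.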

And: "It seems like this will print a scaled version of the average metric value (from print_metrics), depending on the batch size. Please correct me if I'm wrong."

That's true! You can get more insight into this topic by reading "How is the loss for an epoch reported?"

shahzad-ali, Sep 29 '20

If you want to use cross-entropy, make sure you are not applying a sigmoid beforehand: F.cross_entropy expects raw logits and applies log-softmax internally.
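A short sketch of the point above, with arbitrary toy sizes: F.cross_entropy on raw logits matches nll_loss on log_softmax, while feeding sigmoid outputs squashes every score into (0, 1) before the internal softmax, so the model can never produce a confident prediction.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 3, 4, 4)         # [N, C, H, W] raw network output
target = torch.randint(0, 3, (2, 4, 4))  # [N, H, W] class indices

# F.cross_entropy applies log_softmax itself, so it should get raw logits
ce = F.cross_entropy(logits, target)
manual = F.nll_loss(F.log_softmax(logits, dim=1), target)

# Wrong: sigmoid first. With 3 classes and scores in (0, 1), no pixel's
# predicted probability can exceed e / (e + 2) ~ 0.58, so this loss
# stays above about 0.55 no matter how well the network is trained.
ce_after_sigmoid = F.cross_entropy(torch.sigmoid(logits), target)
```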

AlphonsG, Mar 24 '23