
Adding weights to loss function

Open Num13er-XIII opened this issue 2 years ago • 3 comments

Dear Fabian,

I was wondering whether I can modify the class weights in the loss functions. I am training with 2 labels, and one of them is the important one.

In DC_and_CE_loss I have changed weight_dice and weight_ce to [0.95, 1], and I did the same for DC_and_BCE_loss and DC_topk_loss.

Is this the right approach?

Best regards Mahdi

Num13er-XIII avatar Nov 27 '23 10:11 Num13er-XIII

Maybe that's not the right approach.

CloudRobot avatar Dec 12 '23 06:12 CloudRobot

Any comments on this topic?

Num13er-XIII avatar Jan 17 '24 19:01 Num13er-XIII

Dear Mahdi, sorry for the late reply. Somehow the issue evaded my attention. weight_dice and weight_ce just weight the dice and CE loss in relation to each other.
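To illustrate the point about relative weighting: weight_dice and weight_ce only scale the two loss terms against each other in a linear combination, and neither of them touches individual classes. A minimal sketch (the function name dc_and_ce_sketch is made up for illustration; the real combination lives in nnU-Net's DC_and_CE_loss):

```python
def dc_and_ce_sketch(dice_loss, ce_loss, weight_dice=1.0, weight_ce=1.0):
    # Scales the dice and CE terms relative to each other.
    # This does NOT weight individual classes within either term.
    return weight_dice * dice_loss + weight_ce * ce_loss

# Setting weight_ce=0.95 only shrinks the CE term's share of the total:
total = dc_and_ce_sketch(0.3, 0.4, weight_dice=1.0, weight_ce=0.95)
```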

The solution is unfortunately a bit more involved. You need to write your own trainer (see https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/extending_nnunet.md) which, for the CE loss, passes ce_kwargs with the parameter "weight" set to the desired values (see https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html). Unfortunately, this option is not available for the dice loss, so you would have to extend the SoftDiceLoss class (https://github.com/MIC-DKFZ/nnUNet/blob/1b5a17daedb819b6d0be571598a1384a8a9befc5/nnunetv2/training/loss/dice.py#L8).
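For the CE side, torch.nn.CrossEntropyLoss already accepts a per-class weight tensor, which is what ce_kwargs would carry. For the dice side, the sketch below shows one way such an extension of SoftDiceLoss could look; it is a self-contained plain-PyTorch illustration, not nnU-Net's actual implementation (the class name ClassWeightedSoftDiceLoss is hypothetical, and details like nnU-Net's nonlinearity handling and batch dice are omitted):

```python
import torch
import torch.nn as nn

class ClassWeightedSoftDiceLoss(nn.Module):
    """Hypothetical sketch: soft dice with per-class weights."""
    def __init__(self, class_weights, smooth=1e-5):
        super().__init__()
        self.register_buffer(
            "class_weights",
            torch.as_tensor(class_weights, dtype=torch.float32),
        )
        self.smooth = smooth

    def forward(self, logits, target_onehot):
        # logits, target_onehot: (B, C, *spatial)
        probs = torch.softmax(logits, dim=1)
        spatial_axes = tuple(range(2, logits.ndim))
        intersect = (probs * target_onehot).sum(spatial_axes)
        denom = probs.sum(spatial_axes) + target_onehot.sum(spatial_axes)
        dice = (2 * intersect + self.smooth) / (denom + self.smooth)  # (B, C)
        # Normalize the weights so the loss stays in [0, 1]
        w = self.class_weights / self.class_weights.sum()
        return 1 - (dice * w).sum(1).mean()

# The CE term needs no custom class: per-class weights are built in.
weighted_ce = nn.CrossEntropyLoss(weight=torch.tensor([0.05, 0.95]))
```

A custom trainer (as described in extending_nnunet.md) would then construct the compound loss from these two weighted parts instead of the defaults.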

cheers Ole

dojoh avatar Feb 28 '24 10:02 dojoh