
Pseudo dice during training and validation

Body123 opened this issue 10 months ago • 8 comments

While training, I found log lines like the following:

2024-04-23 15:22:40.711169: Epoch 195
2024-04-23 15:22:40.724555: Current learning rate: 0.00036
2024-04-23 15:23:41.803760: train_loss -0.7526
2024-04-23 15:23:41.832845: val_loss -0.771
2024-04-23 15:23:41.869548: Pseudo dice [0.8399]
2024-04-23 15:23:41.895059: Epoch time: 61.1 s
2024-04-23 15:23:41.911069: Yayy! New best EMA pseudo Dice: 0.8196

But when the validation is computed after the last epoch of this fold, the Dice is lower than the numbers above:

2024-04-23 15:29:06.880623: Mean Validation Dice: 0.6980713938291729

So I want to ask what this number represents:

2024-04-23 15:23:41.869548: Pseudo dice [0.8399]

and what this represents:

2024-04-23 15:23:41.911069: Yayy! New best EMA pseudo Dice: 0.8196

Body123 avatar Apr 24 '24 09:04 Body123

Here is my understanding: nnUNet is trained on patches, so the actual input fed into the network consists of patches sampled from the original images, and the Dice metric during training is also calculated on these patches. At inference time, nnUNet performs sliding-window inference by traversing the whole image. The patches produced by this sliding-window tiling can differ from those drawn by the training-time sampling strategy, which can lead to a lower Dice score.
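To make the difference concrete, here is a toy sketch of the two sampling schemes (illustrative only, not nnUNet's actual code; the shapes, function names, and the 50% overlap are assumptions):

```python
import numpy as np

def random_patch_starts(image_shape, patch_size, n_patches, rng=None):
    """Training-style sampling: patch origins are drawn at random."""
    rng = np.random.default_rng() if rng is None else rng
    return [
        tuple(rng.integers(0, s - p + 1) for s, p in zip(image_shape, patch_size))
        for _ in range(n_patches)
    ]

def sliding_window_starts(image_shape, patch_size, step_fraction=0.5):
    """Inference-style tiling: patch origins on a regular grid with overlap."""
    starts = []
    steps = [max(1, int(p * step_fraction)) for p in patch_size]
    for axis_len, p, step in zip(image_shape, patch_size, steps):
        axis_starts = list(range(0, max(axis_len - p, 0) + 1, step))
        if axis_starts[-1] != axis_len - p:  # make sure the image border is covered
            axis_starts.append(axis_len - p)
        starts.append(axis_starts)
    # Cartesian product over the axes gives every tile origin
    grid = np.stack(np.meshgrid(*starts, indexing="ij"), axis=-1).reshape(-1, len(patch_size))
    return [tuple(g) for g in grid]

print(random_patch_starts((160, 192, 160), (128, 128, 128), n_patches=3))
print(len(sliding_window_starts((160, 192, 160), (128, 128, 128))))
```

The random origins rarely coincide with the grid origins, so the two schemes evaluate the network on different crops of the same image.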

Chasel-Chen avatar Apr 25 '24 05:04 Chasel-Chen

Exactly as @Chasel-Chen explains it! The pseudo dice is just there to show the training progress and to determine whether overfitting is a problem (green line goes down = bad!). It is not directly related to the actual validation Dice! How strongly these two agree depends on the dataset.

FabianIsensee avatar Apr 25 '24 06:04 FabianIsensee

> Exactly as @Chasel-Chen explains it! The pseudo dice is just there to show the training progress and to determine whether overfitting is a problem (green line goes down = bad!). It is not directly related to the actual validation Dice! How strongly these two agree depends on the dataset.

Hi Fabian, I have a similar confusion. In certain segmentation tasks, the patch-based metrics look excellent during training. However, when performing inference on the original images, there is a significant drop in the Dice scores, even when evaluating on the training set. I have tried modifying the sampling strategy (such as shifting the generated bounding boxes), but without much success. Are there any suggestions for further experimentation and improvement, such as adjusting the patch size, the oversampling rate, or other methods? Thank you very much.

Chasel-Chen avatar Apr 25 '24 07:04 Chasel-Chen

There is no way of 'fixing' this, please just take it as it is! It is not supposed to be the same as the validation Dice. Running actual validations during training would take too long.

FabianIsensee avatar Apr 25 '24 07:04 FabianIsensee

But what is the difference between these two lines, which appear in the same epoch?

2024-04-23 15:23:41.869548: Pseudo dice [0.8399]
2024-04-23 15:23:41.911069: Yayy! New best EMA pseudo Dice: 0.8196

Which of them is related to training, and which to validation?

Body123 avatar Apr 26 '24 09:04 Body123

EMA stands for exponential moving average. It tracks the pseudo Dice with the update rule updated = old * alpha + (1 - alpha) * new. This is the smoothed green line in the progress.png plot.
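A minimal sketch of that smoothing, assuming alpha = 0.9 and made-up per-epoch pseudo Dice values (the actual alpha used by nnUNet may differ):

```python
from typing import Optional

def update_ema(old: Optional[float], new: float, alpha: float = 0.9) -> float:
    # updated = old * alpha + (1 - alpha) * new; the first value initializes the EMA
    return new if old is None else old * alpha + (1.0 - alpha) * new

ema = None
for pseudo_dice in [0.62, 0.71, 0.78, 0.84]:  # per-epoch pseudo Dice values (made up)
    ema = update_ema(ema, pseudo_dice)
    print(f"pseudo dice {pseudo_dice:.2f} -> EMA {ema:.4f}")
```

Because the EMA lags behind the raw value, the "New best EMA pseudo Dice" (0.8196) can be lower than the pseudo Dice of the same epoch (0.8399).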

FabianIsensee avatar Apr 26 '24 11:04 FabianIsensee

Okay, thanks a lot. My last question, and sorry for the many questions: are those numbers based on random patches taken from the training set during the epoch (not the whole image), or do they come from the validation set during the epoch?

2024-04-23 15:23:41.869548: Pseudo dice [0.8399]
2024-04-23 15:23:41.911069: Yayy! New best EMA pseudo Dice: 0.8196

Body123 avatar Apr 26 '24 14:04 Body123

They are random patches from the validation set, sampled according to the same rules as the training patches. The Dice is computed over all sampled patches at once (pretending all patches together are one image), not separately for each patch.
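A toy illustration of the difference between averaging per-patch Dice and computing one Dice over all pooled patches, the latter being what "pretending all patches together are one image" refers to (random binary masks, not real nnUNet data):

```python
import numpy as np

def dice(pred, gt, eps=1e-8):
    # Binary Dice: 2*TP / (|pred| + |gt|)
    tp = np.sum(pred * gt)
    return (2 * tp + eps) / (pred.sum() + gt.sum() + eps)

rng = np.random.default_rng(0)
patches_pred = [rng.integers(0, 2, size=(4, 4)) for _ in range(3)]
patches_gt   = [rng.integers(0, 2, size=(4, 4)) for _ in range(3)]

# Per-patch Dice, then averaged
per_patch = np.mean([dice(p, g) for p, g in zip(patches_pred, patches_gt)])

# Pooled Dice: treat all patches as one big image and compute Dice once
pooled = dice(np.concatenate([p.ravel() for p in patches_pred]),
              np.concatenate([g.ravel() for g in patches_gt]))

print(f"mean per-patch Dice: {per_patch:.4f}  vs  pooled Dice: {pooled:.4f}")
```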

FabianIsensee avatar Apr 26 '24 14:04 FabianIsensee

[training progress screenshot] Can anyone tell me what is wrong with my training? Any help is much appreciated. @Body123 @Chasel-Chen @FabianIsensee

mehnaz1985 avatar May 07 '24 03:05 mehnaz1985

Seems like

  1. your epoch time is really slow, see https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/benchmarking.md
  2. you seem to suffer from some overfitting
  3. one class is not learned properly

What could concretely be the problem, I can't say just from what you shared.

FabianIsensee avatar May 22 '24 07:05 FabianIsensee

Thank you for your response, @FabianIsensee. Let me explain my task and dataset. I am attempting to segment recurrent and non-recurrent areas in the peritumoral region. There are three labels, including background, and I have seven channels (t1, t2, t1c, flair, dti_fa, dti_ra, and dti_md). FYI, the recurrent area is tiny compared to the non-recurrent area, so this may be a class imbalance problem. I was thinking of giving more weight to the minority class in the loss function, but I am not sure where to make the change. There are several losses! Do I also need to hook the new loss into nnUNetTrainer.py? Please enlighten me on this matter.
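For reference, a hedged sketch of what a class-weighted loss in a custom trainer could look like. The import paths, the _build_loss hook, and the DC_and_CE_loss arguments are assumptions based on nnU-Net v2 and need to be checked against the installed version; the weights and label order are placeholders:

```python
# Sketch only: verify module paths, the _build_loss hook, and loss arguments
# against your nnU-Net v2 installation before using this.
import torch
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.training.loss.compound_losses import DC_and_CE_loss


class nnUNetTrainerWeightedCE(nnUNetTrainer):  # unique name, inherits from nnUNetTrainer
    def _build_loss(self):
        # Give the minority (recurrent) class a larger CE weight.
        # Assumed label order: background, non-recurrent, recurrent -- adjust to yours.
        class_weights = torch.tensor([1.0, 1.0, 5.0], device=self.device)
        loss = DC_and_CE_loss(
            soft_dice_kwargs={'batch_dice': self.configuration_manager.batch_dice,
                              'smooth': 1e-5, 'do_bg': False, 'ddp': self.is_ddp},
            ce_kwargs={'weight': class_weights},
            weight_ce=1, weight_dice=1,
        )
        # Note: the default trainer additionally wraps the loss for deep supervision;
        # if deep supervision is enabled, that wrapping would need to be replicated here.
        return loss
```

Such a trainer could then be selected via the -tr flag of nnUNetv2_train, provided the file lives somewhere nnU-Net scans for trainer classes (e.g. under nnunetv2/training/nnUNetTrainer/). Whether weighting actually helps here is a separate question; see Fabian's reply below.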

mehnaz1985 avatar May 22 '24 15:05 mehnaz1985

Now another problem arose: [error screenshot]

mehnaz1985 avatar May 23 '24 12:05 mehnaz1985

It looks like you created a class that has the same name as the nnUNetTrainer but does not inherit from it? When making changes, please remember to create trainer classes with unique names and make them inherit from nnUNetTrainer.

It seems like one of your classes is very difficult to learn or was annotated too inconsistently. Giving it more weight will not solve the problem. Looking at your training loss, it seems like the model can learn the class on the training data but fails to generalize to the validation set.

FabianIsensee avatar May 28 '24 09:05 FabianIsensee