Pseudo dice while training and validation
While training, I found the following in the logs (as an example):

```
2024-04-23 15:22:40.711169: Epoch 195
2024-04-23 15:22:40.724555: Current learning rate: 0.00036
2024-04-23 15:23:41.803760: train_loss -0.7526
2024-04-23 15:23:41.832845: val_loss -0.771
2024-04-23 15:23:41.869548: Pseudo dice [0.8399]
2024-04-23 15:23:41.895059: Epoch time: 61.1 s
2024-04-23 15:23:41.911069: Yayy! New best EMA pseudo Dice: 0.8196
```

But when the validation is computed after the last epoch of this fold, the dice is lower than the numbers above:

```
2024-04-23 15:29:06.880623: Mean Validation Dice: 0.6980713938291729
```

So I want to ask: what does `Pseudo dice [0.8399]` represent, and what does `Yayy! New best EMA pseudo Dice: 0.8196` represent?
Here is my understanding. The nnUNet is trained based on patches, so the actual input data fed into the network are patches sampled from the original image. The dice metric during training is also calculated based on these patches. After training, nnUNet performs sliding window inference by traversing the image. The patches obtained from this sliding window sampling may differ from the sampling strategy used during training, resulting in a potential decrease in the Dice score.
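For intuition, here is a toy 2D sketch (plain NumPy, not nnU-Net's actual implementation, which uses Gaussian-weighted 3D tiles and foreground oversampling) contrasting the two sampling schemes:

```python
import numpy as np

def sliding_window_patches(image, patch_size, step):
    """Yield overlapping patches that cover the whole image (assumes the
    image is at least patch_size in each dimension)."""
    h, w = image.shape
    ys = list(range(0, h - patch_size + 1, step))
    xs = list(range(0, w - patch_size + 1, step))
    # make sure the last patch touches the image border
    if ys[-1] != h - patch_size:
        ys.append(h - patch_size)
    if xs[-1] != w - patch_size:
        xs.append(w - patch_size)
    for y in ys:
        for x in xs:
            yield image[y:y + patch_size, x:x + patch_size]

def random_training_patch(image, patch_size, rng=None):
    """Training instead draws patches at random positions (nnU-Net additionally
    oversamples foreground regions, which this sketch omits)."""
    rng = rng or np.random.default_rng()
    y = int(rng.integers(0, image.shape[0] - patch_size + 1))
    x = int(rng.integers(0, image.shape[1] - patch_size + 1))
    return image[y:y + patch_size, x:x + patch_size]
```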
Exactly as @Chasel-Chen explains it! The pseudo dice is just to see the training progress and to determine whether overfitting is a problem (green line goes down = bad!). It is not directly related to the actual validation dice! How strongly these two agree depends on the dataset
Hi Fabian, I have a similar confusion. In certain segmentation tasks, patch-based metrics perform exceptionally well during training. However, when performing inference on the original images, there is a significant difference in Dice scores, even when evaluating on the training set. I have tried modifying the sampling strategy (such as shifting the generated bounding boxes), but without much success. Are there any suggestions for further experimentation and improvement, such as adjusting the patch size, the oversample rate, or any other methods? Thank you very much.
There is no way of 'fixing' this; please just take it as it is! It is not supposed to be the same as the validation dice. Running actual validations would take too long.
But what is the difference between these two?

```
2024-04-23 15:23:41.869548: Pseudo dice [0.8399]
2024-04-23 15:23:41.911069: Yayy! New best EMA pseudo Dice: 0.8196
```

Which of them relates to training and which to validation, given that both appear in the same epoch?
EMA stands for exponential moving average. It tracks the pseudo dice following the update rule `updated = old * alpha + (1 - alpha) * new`. This is the smoothed green line in the progress.png plot.
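In code, the update is simply the following (the alpha value here is illustrative; see nnU-Net's logger for the value actually used):

```python
def update_ema_pseudo_dice(ema: float, current: float, alpha: float = 0.9) -> float:
    """Smooth the per-epoch pseudo dice with an exponential moving average.
    alpha=0.9 is an assumed smoothing factor for illustration."""
    return ema * alpha + (1 - alpha) * current
```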
Okay, thanks a lot. My last question, and sorry for the many questions: are these numbers based on random patches taken from the training set during the epoch (not the whole image), or do they come from the validation set?

```
2024-04-23 15:23:41.869548: Pseudo dice [0.8399]
2024-04-23 15:23:41.911069: Yayy! New best EMA pseudo Dice: 0.8196
```
They are random patches from the validation set, sampled according to the same rules as the training patches. Dice is computed over all sampled patches jointly (pretending all patches together are one image), not for each patch individually.
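A simplified sketch of what "pretending all patches together are one image" means (not nnU-Net's exact code): TP/FP/FN are summed over all patches before computing dice, instead of averaging per-patch dice values:

```python
import numpy as np

def global_dice(pred_patches, gt_patches, label):
    """Dice from TP/FP/FN summed over *all* validation patches, as if they
    formed one big image, instead of averaging per-patch dice scores."""
    tp = fp = fn = 0
    for pred, gt in zip(pred_patches, gt_patches):
        p, g = (pred == label), (gt == label)
        tp += np.sum(p & g)
        fp += np.sum(p & ~g)
        fn += np.sum(~p & g)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom > 0 else float('nan')
```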
Can anyone tell me what is wrong with my training? Any help is much appreciated. @Body123 @Chasel-Chen @FabianIsensee
Seems like
- your epoch time is really slow, see https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/benchmarking.md
- you seem to suffer from some overfitting
- one class is not learned properly
What the concrete problem could be, I can't say just from what you shared.
Thank you for your response, @FabianIsensee. Let me explain my task and the dataset. I am attempting to segment recurrent and non-recurrent areas in the peritumoral region. There are three labels, including background, and I have seven channels (t1, t2, t1c, flair, dti_fa, dti_ra, and dti_md). FYI, the recurrent area is tiny compared to the non-recurrent area, so this is possibly a class-imbalance problem. I was thinking of giving more weight to the minority class in the loss function, but I am confused about where to make the change, since there are several losses. Do I also need to call that loss in nnUNetTrainer.py? Please enlighten me on this matter.
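For reference, class weighting in plain PyTorch looks like this; this is a generic sketch, not nnU-Net's loss setup, and the weight values are hypothetical:

```python
import torch
import torch.nn as nn

# Hypothetical weights for background, non-recurrent, recurrent
# (the minority class gets a larger weight)
class_weights = torch.tensor([1.0, 1.0, 5.0])
weighted_ce = nn.CrossEntropyLoss(weight=class_weights)

logits = torch.randn(2, 3, 64, 64)           # (batch, classes, H, W)
target = torch.randint(0, 3, (2, 64, 64))    # integer label map
loss = weighted_ce(logits, target)
```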
Now another problem arose:
It looks like you created a class that has the same name as the nnUNetTrainer but does not inherit from it? When making changes, please remember to create trainer classes with unique names and make them inherit from nnUNetTrainer.
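A minimal sketch of such a trainer, assuming nnU-Net v2's module layout:

```python
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer

class nnUNetTrainerCustom(nnUNetTrainer):
    """Unique class name + inheritance from nnUNetTrainer."""
    pass  # override methods (e.g. the loss construction) here
```

It can then be selected for training via the `-tr nnUNetTrainerCustom` flag.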
It seems like one of your classes is very difficult to learn or was annotated too inconsistently. Giving it more weight will not solve the problem. Looking at your training loss, it seems like the model can learn the class on the training data but fails to generalize to the validation set.