High fluctuation in loss and Dice during 5-fold cross-validation
Hello,
I trained nnU-Net v2 (2D configuration) on a dataset of cerebral DSA images for automatic segmentation of non-perfused brain tissue. I performed a 5-fold cross-validation using the default nnU-Net pipeline. My dataset contained 120 images: 100 for training/validation and 20 for testing.
For each fold, the training completed successfully, but the training and validation loss curves fluctuate heavily.
When I trained the fold_all model, convergence was very stable and the Dice quickly reached ≈ 1.0.
Is it normal to observe this kind of noisy behaviour in the cross-validation folds? Could it be due to small dataset size, data imbalance, or DSA image variability?
Thank you for your time and for developing nnU-Net!
I'm not an expert, but I think I can answer part of your question. Hopefully this helps you.
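Fold "all" trains on ALL the training data and then uses that same data as validation data. This usually leads to very high validation scores, because the model is evaluated on images it has already seen and is effectively overfit.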
We normally want a split between the training and validation data. One example is 80/20 split, with 80% of the dataset in training and 20% in validation.
K-fold cross-validation is one way to handle this split. Since you don't know in advance which images would work best as your validation set, we train 5 folds, each with different images in the validation set. You can see these sets in the splits_final.json in the preprocessed directory. (This file gets generated once you start training a numbered fold.)
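To check what ended up in each fold, something like the following could be used. It is a minimal sketch that assumes splits_final.json holds a list with one {"train": [...], "val": [...]} entry per fold; the path and the helper names are hypothetical, so adjust them to your own setup.

```python
import json
from pathlib import Path


def load_splits(path):
    """Read a splits file: a list with one {"train": [...], "val": [...]} dict per fold."""
    with open(path) as f:
        return json.load(f)


def summarize_splits(splits):
    """Return (fold index, number of training cases, number of validation cases) per fold."""
    return [(i, len(s["train"]), len(s["val"])) for i, s in enumerate(splits)]


def overlapping_cases(splits):
    """Return, per fold, the cases listed in both train and val (each list should be empty)."""
    return [sorted(set(s["train"]) & set(s["val"])) for s in splits]


if __name__ == "__main__":
    # Hypothetical location; point this at your own preprocessed dataset folder.
    path = Path("nnUNet_preprocessed/Dataset001_Example/splits_final.json")
    if path.exists():
        splits = load_splits(path)
        for fold, n_train, n_val in summarize_splits(splits):
            print(f"fold {fold}: {n_train} train / {n_val} val")
        print("train/val overlap per fold:", overlapping_cases(splits))
```

If one fold has a noticeably different mix of images in its validation set than the others, that fold is a good first place to look when its curves are noisier.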
You would train all 5 folds, and then perform a cross validation to get your final score/prediction for an image.
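As a rough illustration of the 5-fold idea, here is a sketch of how 100 cases can be divided so that every case is used for validation exactly once. This is not nnU-Net's own split code; the function name, the seed, and the case names are made up for the example.

```python
import random


def k_fold_splits(case_ids, n_folds=5, seed=0):
    """Shuffle the cases once, then let each fold hold out a different slice for validation."""
    cases = sorted(case_ids)
    random.Random(seed).shuffle(cases)
    splits = []
    for fold in range(n_folds):
        # Every n_folds-th case, starting at offset `fold`, goes into this fold's validation set.
        val = cases[fold::n_folds]
        val_set = set(val)
        train = [c for c in cases if c not in val_set]
        splits.append({"train": sorted(train), "val": sorted(val)})
    return splits


if __name__ == "__main__":
    # Hypothetical case names standing in for the 100 training/validation images.
    cases = [f"case_{i:03d}" for i in range(100)]
    for fold, split in enumerate(k_fold_splits(cases)):
        print(f"fold {fold}: {len(split['train'])} train / {len(split['val'])} val")
```

With 100 images this gives 80 training and 20 validation images per fold, so each fold's validation score rests on only 20 images.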
In my experience, the k-fold models have always scored slightly worse than the fold all model, which makes sense given that fold all is validated on the same data it was trained on.
Example: Depending on your dataset, the autogenerated splits may not perform well. I have run into this issue myself.
I have a dataset with 15 images, where 5 images contain Label 1, 5 images contain Label 2, and 5 images contain Label 3. The autogenerated splits for this type of dataset can leave a fold with few or no images for Label 2, for example. That fold then cannot learn the label and ends up performing poorly.
I needed to customize my splits so that each validation set contained an equal number of images with Label 1, 2, and 3. I have found this to be one reason why the graphs can look so chaotic and why a fold may not learn correctly.
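A custom split along these lines could be built roughly as follows. This is only a sketch: it assumes every image belongs to exactly one group (here, the label it contains), and the function name, case names, and seed are hypothetical. The result uses the same list-of-{"train", "val"} layout and could be saved as a custom splits_final.json.

```python
import json
import random
from collections import defaultdict


def stratified_splits(case_to_group, n_folds=5, seed=0):
    """Deal the cases of each group round-robin over the folds' validation sets."""
    rng = random.Random(seed)
    by_group = defaultdict(list)
    for case, group in sorted(case_to_group.items()):
        by_group[group].append(case)

    val_sets = [[] for _ in range(n_folds)]
    for group in sorted(by_group):
        members = by_group[group]
        rng.shuffle(members)
        for i, case in enumerate(members):
            val_sets[i % n_folds].append(case)

    all_cases = sorted(case_to_group)
    splits = []
    for val in val_sets:
        val_set = set(val)
        train = [c for c in all_cases if c not in val_set]
        splits.append({"train": train, "val": sorted(val)})
    return splits


if __name__ == "__main__":
    # Hypothetical dataset: 15 images, 5 per label, as in the example above.
    mapping = {f"img_{i:02d}": 1 + i // 5 for i in range(15)}
    print(json.dumps(stratified_splits(mapping), indent=2))
```

With 15 images and 5 folds, each validation set gets one image per label and each training set keeps four images per label.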
This really comes down to your dataset and how imbalanced it is.
Are you able to provide any additional information on your dataset? How many labels? Do all images contain all labels?
Thanks for the explanation! In my case, I only have one label, the non-perfused brain tissue, so label imbalance isn’t an issue here. When you say “perform a cross validation to get your final score/prediction for an image,” do you mean averaging the predictions from all 5 folds or ensembling them?
It's good that you don't have a label imbalance. I have too many datasets with label imbalances...
For cross-validation, the developers set up a nifty command that finds your best configuration and gives you the correct predict and postprocessing commands.
After you have trained all 5 folds, you can run the nnUNetv2_find_best_configuration command.
Example:
(caml) [vmiller@gluskap caml]$ CUDA_VISIBLE_DEVICES=1 nnUNetv2_find_best_configuration 100 -c 3d_fullres -p nnUNetResEncUNetMPlans
***All results:***
nnUNetTrainer__nnUNetResEncUNetMPlans__3d_fullres: 0.977060910523444
*Best*: nnUNetTrainer__nnUNetResEncUNetMPlans__3d_fullres: 0.977060910523444
***Determining postprocessing for best model/ensemble***
Removing all but the largest foreground region did not improve results!
Removing all but the largest component for 1 did not improve results! Dice before: 0.97716 after: 0.34112
Removing all but the largest component for 2 did not improve results! Dice before: 0.978 after: 0.60801
Removing all but the largest component for 3 did not improve results! Dice before: 0.97602 after: 0.20067
***Run inference like this:***
nnUNetv2_predict -d Dataset100_BrightSpeed_CalBlocks+P3P4P5_TH -i INPUT_FOLDER -o OUTPUT_FOLDER -f 0 1 2 3 4 -tr nnUNetTrainer -c 3d_fullres -p nnUNetResEncUNetMPlans
***Once inference is completed, run postprocessing like this:***
nnUNetv2_apply_postprocessing -i OUTPUT_FOLDER -o OUTPUT_FOLDER_PP -pp_pkl_file /opt/datasets/FCT/results/Dataset100_BrightSpeed_CalBlocks+P3P4P5_TH/nnUNetTrainer__nnUNetResEncUNetMPlans__3d_fullres/crossval_results_folds_0_1_2_3_4/postprocessing.pkl -np 8 -plans_json /opt/datasets/FCT/results/Dataset100_BrightSpeed_CalBlocks+P3P4P5_TH/nnUNetTrainer__nnUNetResEncUNetMPlans__3d_fullres/crossval_results_folds_0_1_2_3_4/plans.json
This will also generate a JSON file and a txt file containing this information in your /path/to/dataset/results/DATASET_NAME/ directory.
They even take it a step further: if you trained 2d, 3d_fullres, and 3d_cascade_fullres models, there is an ensemble command as well.
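On the earlier averaging-versus-ensembling question: listing all five folds in the predict command (-f 0 1 2 3 4, as in the output above) makes nnU-Net use the five fold models together as an ensemble. The sketch below is only a generic illustration of that idea for a single foreground label, averaging per-pixel probabilities and then thresholding. It is not nnU-Net's code, the function names and numbers are made up, and the exact quantity nnU-Net combines internally may differ.

```python
def ensemble_foreground_probability(fold_probabilities):
    """Average per-pixel foreground probabilities over the folds."""
    n_folds = len(fold_probabilities)
    n_pixels = len(fold_probabilities[0])
    return [sum(p[i] for p in fold_probabilities) / n_folds for i in range(n_pixels)]


def to_mask(probabilities, threshold=0.5):
    """Turn averaged probabilities into a binary mask (1 = foreground)."""
    return [1 if p > threshold else 0 for p in probabilities]


if __name__ == "__main__":
    # Hypothetical foreground probabilities for 4 pixels from 5 fold models.
    folds = [
        [0.9, 0.2, 0.6, 0.4],
        [0.8, 0.1, 0.4, 0.6],
        [0.7, 0.3, 0.7, 0.4],
        [0.9, 0.2, 0.3, 0.45],
        [0.6, 0.2, 0.8, 0.4],
    ]
    averaged = ensemble_foreground_probability(folds)
    print(to_mask(averaged))
```

Individual folds can disagree on a pixel (the third and fourth pixels here), and the averaged prediction smooths that out, which is one reason the ensemble is often steadier than any single fold.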
If you look at PR #2834, you can find a CLI catalog I made about a year ago.
It includes a markdown file that is mostly accurate. I believe there have been some updates since then, so a couple of parameters are missing for some of the commands, and I haven't gotten around to updating it yet. I use it to keep track of all the different commands nnU-Net has.