nnUNet
Reproducibility of prediction result
Hello, Fabian. Thanks for sharing your great code.
I have an issue with the reproducibility of prediction results. I ran 3d_fullres inference with the same weights and without ensembling, but the results were slightly different between runs. In my opinion, TTA may cause this problem. Is there something I can do? (turn off do_tta or fix the seed?)
Regards, Keewon Shin.
Can you please give some details? Dataset? Configuration? What EXACTLY did you run/try to do?
I ran inference on my own dataset, and I was wondering if anyone else had the same problem with inference results. I'll try to figure out the problem more clearly and get back to you soon. Thanks.
Your issue really makes little sense to me. You have not given any details on what the problem is. Does the segmentation not work? Are the metrics not what you would expect?
As mentioned above, I trained a 3d_fullres nnU-Net on my own dataset and expected to get the same results when running inference with the same weights.
The picture above shows the result of repeated inference with the same model weights; the voxel counts were checked with ITK-SNAP. The voxel counts of the left and right cases are slightly different (left 35836, right 35837). Not all cases were like that, but it happened with a frequency of about 1 in 10 for me.
OK, so you run inference twice and get two different results? That is strange and should not happen, but I also do not have an explanation for it. There is no random component in inference, so it's certainly not something nnU-Net does. It seems like this is not critical, however; it's just one pixel.
Hi @kevinkwshin , did this resolve your issue?
Just a quick note to say that I've hit the same issue with reproducibility. As with @kevinkwshin, it is only a small number of individual voxels that are different. Trying to find the differences manually is almost impossible because they are so few. It also seems to be hardware dependent.
I saved the input image and the segmentation output to nifti files, where the timestamp is part of the filename, and run the segmentation in a loop. Then I calculate the md5sum of the files. The input image always gives the same md5sum. The output image has a few different md5sums, with multiple duplicates of each.
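A minimal sketch of that kind of check in Python (the folder path and the seg_*.nii.gz naming pattern here are just placeholders, not the actual script):

# Hash every saved segmentation output and group identical files.
# Assumes the outputs live in one folder as .nii.gz files (hypothetical path).
import hashlib
from collections import defaultdict
from pathlib import Path

def md5_of(path):
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()

groups = defaultdict(list)
for seg in sorted(Path("output_folder").glob("seg_*.nii.gz")):
    groups[md5_of(seg)].append(seg.name)

for digest, files in groups.items():
    print(digest, files)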
Running the exact same code on a different machine gives completely repeatable output - the same md5sum every time.
I suspect it is due to underlying torch/CUDA libraries. Torch has a list of operations which are potentially non-deterministic here: https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
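For anyone who wants to rule this out, a minimal sketch of asking PyTorch for deterministic kernels before inference. Note these are generic PyTorch settings, not options nnU-Net exposes, so they would have to go into your own driver script; the environment variable must be set before CUDA is initialised, and deterministic mode can be slower or raise errors for unsupported ops:

# Generic PyTorch determinism settings (not an nnU-Net option).
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # required by some CUDA ops in deterministic mode

import torch
torch.use_deterministic_algorithms(True)  # error out on ops that have no deterministic implementation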
As far as I am aware, the non-deterministic nature of these libraries mostly (only?) affects the backward pass, so inference should always be the same. We do make changes to our inference code from time to time, though. Recently we had to change the scaling of the Gaussian importance weighting because it led to NaNs on some datasets. Changes like these will affect the predictions in small ways such as you are describing in your post. So if you really need 100% consistent results, I'd recommend always installing exactly the same version of nnU-Net.
This was the exact same version of nnU-Net. I was running the code in a loop.
$ md5sum *seg*
925d53930eb1ba88c506f904d27ee625 seg_13:40.nii.gz
925d53930eb1ba88c506f904d27ee625 seg_13:44.nii.gz
ad2eb9a7f1074bc05642434a75b6ba04 seg_13:47.nii.gz
ad2eb9a7f1074bc05642434a75b6ba04 seg_13:50.nii.gz
925d53930eb1ba88c506f904d27ee625 seg_13:54.nii.gz
ad2eb9a7f1074bc05642434a75b6ba04 seg_13:58.nii.gz
ad2eb9a7f1074bc05642434a75b6ba04 seg_14:01.nii.gz
925d53930eb1ba88c506f904d27ee625 seg_14:05.nii.gz
It might be hard to detect unless you're really looking for it, because:
- it's hardware specific (observed on an RTX 3050, not observed on a T4 or GTX 1080)
- it's only a tiny fraction of individual voxels
Other people have observed non-determinism during inference: https://stackoverflow.com/questions/72979303/why-is-pytorch-inference-non-deterministic-even-when-setting-model-eval
The PyTorch page on non-deterministic algorithms has an example that affects the forward pass, and it mentions convolutions as something that can be non-deterministic. https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
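If convolutions are indeed the culprit, the older cuDNN-specific flags are another thing one could try (again generic PyTorch, not an nnU-Net setting, and possibly at some cost in inference speed):

# Restrict cuDNN to deterministic convolution algorithms.
import torch
torch.backends.cudnn.deterministic = True  # only use deterministic conv algorithms
torch.backends.cudnn.benchmark = False     # don't auto-tune; tuning may pick non-deterministic kernels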
Anyway, I don't think this is a big deal, probably not worth re-opening the ticket. I just wanted to let Kevin (and anybody else who comes across this) know he wasn't imagining things.
Thanks for the experiments and clarification!