Memory shortage issue while using nnUNetV2 to predict a large CT scan.
Dear nnU-Net Team,
I encountered a memory shortage issue while using nnUNetV2 to predict a large CT scan.
The content of the `plans.json` file is as follows (excerpt):

```json
"plans_name": "nnUNetPlans",
"original_median_spacing_after_transp": [1.5, 1.5, 1.5],
"original_median_shape_after_transp": [227, 227, 240],
"image_reader_writer": "NibabelIOWithReorient",
"transpose_forward": [0, 1, 2],
"transpose_backward": [0, 1, 2],
"3d_lowres": {
    "data_identifier": "nnUNetPlans_3d_lowres",
    "preprocessor_name": "DefaultPreprocessor",
    "batch_size": 2,
    "patch_size": [128, 128, 128],
    "median_image_size_in_voxels": [196, 196, 206],
    "spacing": [1.7389111114500002, 1.7389111114500002, 1.7389111114500002],
    "normalization_schemes": ["CTNormalization"],
    "use_mask_for_norm": [false],
    "UNet_class_name": "PlainConvUNet",
    "UNet_base_num_features": 32,
    "n_conv_per_stage_encoder": [2, 2, 2, 2, 2, 2],
    "n_conv_per_stage_decoder": [2, 2, 2, 2, 2],
```
When I predict a CT scan with 268 slices and a Z-spacing of 5 mm, the program first allocates a `predicted_logits` array of shape [25, 893, 465, 465] to store the resampled predicted logits (this model has 25 labels).
My machine only has 32 GB of RAM, and because `predicted_logits` occupies so much memory, an OOM can occur. We tried changing `predicted_logits` to float16 and using the latest PyTorch for prediction, but this only partially alleviated the situation.
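For reference, the footprint of that array can be estimated with a quick back-of-the-envelope calculation (the shape and label count are the numbers quoted above):

```python
# Rough memory footprint of a [25, 893, 465, 465] logits array.
num_labels, z, y, x = 25, 893, 465, 465
voxels = num_labels * z * y * x

gib_float32 = voxels * 4 / 1024**3  # 4 bytes per float32 voxel
gib_float16 = voxels * 2 / 1024**3  # 2 bytes per float16 voxel

print(f"float32: {gib_float32:.1f} GiB")  # ~18.0 GiB
print(f"float16: {gib_float16:.1f} GiB")  # ~9.0 GiB
```

Even in float16 this single array approaches a third of a 32 GB machine, before counting the input image and any resampling workspace.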
I would like to make a suggestion: could we allocate just one `predicted_logits` array of shape [len(labels), z, y, x] in advance, resample each patch immediately after it is predicted (e.g. right after `prediction = self._internal_maybe_mirror_and_predict(...)`), and write the result into the corresponding location of `predicted_logits`? That way, the total memory occupied over the whole prediction would be lower.
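A minimal sketch of that idea follows. All names and the tiling scheme here are illustrative, not nnU-Net's actual API; `scipy.ndimage.zoom` stands in for nnU-Net's resampling, and the non-overlapping tiling omits the Gaussian-weighted overlapping patches that real nnU-Net inference uses:

```python
# Illustrative sketch only: allocate ONE target-resolution logits array and
# write each patch's resampled logits straight into it. Patch borders would
# need special handling to avoid interpolation artifacts in a real version.
import numpy as np
from scipy.ndimage import zoom

def predict_with_incremental_resampling(image, patch_size, num_labels,
                                        scale, predict_patch):
    # Target shape after resampling the whole volume by `scale` (z, y, x).
    target = tuple(int(round(s * f)) for s, f in zip(image.shape, scale))
    predicted_logits = np.zeros((num_labels, *target), dtype=np.float16)

    pz, py, px = patch_size
    for z in range(0, image.shape[0], pz):
        for y in range(0, image.shape[1], py):
            for x in range(0, image.shape[2], px):
                patch = image[z:z + pz, y:y + py, x:x + px]
                logits = predict_patch(patch)  # (num_labels, *patch.shape)
                # Resample this patch's logits immediately, then discard them.
                up = zoom(logits, (1, *scale), order=1)
                z0 = int(round(z * scale[0]))
                y0 = int(round(y * scale[1]))
                x0 = int(round(x * scale[2]))
                dz, dy, dx = up.shape[1:]
                predicted_logits[:, z0:z0 + dz, y0:y0 + dy, x0:x0 + dx] = up
    return predicted_logits
```

With this scheme, only one patch's logits exist at full precision at any moment, so the peak extra memory is one patch rather than the whole resampled volume.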
If modifying this method is too laborious, an additional check could be added instead: if the size after resampling is too large, the volume could be split along the Z-axis and predicted in 2 or 3 passes. Alternatively, a quick script could be provided to enlarge the system's swap space, so that when physical memory runs out, swap keeps the program from OOMing or getting stuck outright.
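The Z-axis splitting fallback could look roughly like this (a hypothetical helper, not part of nnU-Net; the overlap lets chunk borders be blended or cropped after prediction):

```python
def z_chunks(num_slices, max_chunk, overlap=8):
    """Yield (start, stop) Z-ranges covering [0, num_slices), with `overlap`
    slices shared between neighbouring chunks."""
    assert max_chunk > overlap, "chunk must be larger than the overlap"
    step = max_chunk - overlap
    start = 0
    while True:
        stop = min(start + max_chunk, num_slices)
        yield start, stop
        if stop == num_slices:
            return
        start += step

# e.g. the 268-slice scan from above, split so each chunk stays small:
# list(z_chunks(268, 100)) -> [(0, 100), (92, 192), (184, 268)]
```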
We would greatly appreciate it if this issue could be addressed. It would also enable nnU-Net to run on lower-spec devices.
Sincerely,
Allen Wang
Hi @goodsave,
I hope you are well. I suggested a solution to reduce memory during inference (both validation and prediction) in #2881. I know your image is quite big and you have limited memory, but maybe you could give this a try and see how it behaves.
Hey @goodsave, Unfortunately, your proposal of "just resampling a crop" and moving it is not trivial, since voxels at the patch boundaries would suffer from boundary artifacts. Hence, one would have to take this into account when implementing such a resampling. If you are willing to invest the time, feel free to open a Pull Request that implements the resampling you proposed. If you could include a quick test case to verify the resampling works as intended, we would be happy to integrate this into the official repository.
A bit off-topic/FYI: nnU-Net already tries to minimize the memory footprint by resampling only one class at a time, as you can see here.
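Schematically, per-class resampling keeps only one full-resolution channel in memory at a time (an illustrative sketch using scipy, not nnU-Net's actual implementation):

```python
import numpy as np
from scipy.ndimage import zoom

def resample_logits_per_class(logits, scale):
    """Resample a (C, z, y, x) logits array one channel at a time, so the
    full-precision temporary buffer holds a single class rather than all C."""
    out = None
    for c in range(logits.shape[0]):
        up = zoom(logits[c], scale, order=1)  # one class in memory at once
        if out is None:
            out = np.empty((logits.shape[0], *up.shape), dtype=np.float16)
        out[c] = up  # downcast to float16 on assignment
    return out
```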
Regarding the "include a script to increase swap": this is clearly outside the scope of what nnU-Net is developed for and can be very OS-dependent, so we will not provide an automated script that meddles with a user's OS.
Also, did you find a solution yourself, or did you try the fix that Shrajan proposed?
Lastly, since this issue has been stale for 3 weeks, let me know if it still persists; otherwise, I will close it next week.
Best, Tassilo