Implement TTA Batch Processing to Improve Inference Speed
Summary: This proposes integrating batch processing into Test Time Augmentation (TTA) in nnUNet, so that the mirrored copies of each patch are predicted together in batches rather than one at a time. This improves inference efficiency, most noticeably on larger 3D datasets. Benchmarks on the AMOS2022 dataset show speed improvements of 5%-8%.
Implementation Details:
- Python Version: 3.11
- PyTorch Version: 2.2.2+cu121
- Model: patch size [64, 160, 192] on an NVIDIA RTX 3090 (24 GB VRAM).
Results:
Mirror Axes (0, 1, 2):

| ID        | TTA No Batch (s) | TTA Batch Size=2 (s) | TTA Batch Size=4 (s) | TTA Batch Size=8 (s) |
|-----------|------------------|----------------------|----------------------|----------------------|
| amos_0247 | 51               | 49                   | 48                   | 49                   |
| amos_0111 | 80               | 76                   | 74                   | 75                   |
| amos_0173 | 129              | 122                  | 119                  | 120                  |
Mirror Axes (0, 1):

| ID        | TTA No Batch (s) | TTA Batch Size=2 (s) | TTA Batch Size=4 (s) |
|-----------|------------------|----------------------|----------------------|
| amos_0247 | 26               | 24                   | 25                   |
| amos_0111 | 40               | 38                   | 38                   |
| amos_0173 | 64               | 61                   | 60                   |
Mirror Axis (0):

| ID        | TTA No Batch (s) | TTA Batch Size=2 (s) |
|-----------|------------------|----------------------|
| amos_0247 | 13               | 12                   |
| amos_0111 | 20               | 19                   |
| amos_0173 | 32               | 31                   |
VRAM Usage:
Detailed VRAM consumption by TTA batch size:

| Batch Size | VRAM (GB) - Axes (0, 1, 2) | VRAM (GB) - Axes (0, 1) | VRAM (GB) - Axis (0) |
|------------|----------------------------|-------------------------|----------------------|
| 1          | 7.57                       | 7.57                    | 7.57                 |
| 2          | 9.24                       | 9.24                    | 9.24                 |
| 4          | 12.37                      | 12.37                   | -                    |
| 8          | 17.04                      | -                       | -                    |
Recommendations:
- Use TTA Batch Size=4 for configurations with three Mirror Axes (0, 1, 2).
- Use TTA Batch Size=2 for configurations with fewer Mirror Axes.
The TTA batch processing approach has been thoroughly tested on the AMOS2022 dataset, and the segmentation results are consistent with those of the original (non-batched) setup.
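The underlying idea: instead of running one forward pass per mirror combination, the flipped copies of each patch are stacked along the batch dimension and predicted together. A minimal sketch of this idea (illustrative only, not the actual nnUNet code; `network` is assumed to be any model that accepts a (B, C, X, Y, Z) tensor):

```python
import itertools
import torch

def batched_mirror_tta(network, x, mirror_axes=(0, 1, 2), tta_batch_size=4):
    """Mirror-based TTA that batches flipped copies of a patch.

    x: input patch of shape (1, C, X, Y, Z).
    mirror_axes: spatial axes to mirror (0-based, relative to the spatial dims).
    tta_batch_size: how many mirrored copies to push through the network at once.
    """
    # All 2^len(mirror_axes) flip combinations, including the identity (empty tuple).
    # Spatial dims start at index 2 of the (B, C, X, Y, Z) tensor, hence ax + 2.
    combos = [tuple(ax + 2 for ax in c)
              for n in range(len(mirror_axes) + 1)
              for c in itertools.combinations(mirror_axes, n)]

    prediction = None
    for i in range(0, len(combos), tta_batch_size):
        chunk = combos[i:i + tta_batch_size]
        # Stack the flipped copies along the batch dimension and predict them together.
        batch = torch.cat([torch.flip(x, dims=c) if c else x for c in chunk], dim=0)
        out = network(batch)
        # Undo each flip and accumulate the predictions.
        for j, c in enumerate(chunk):
            o = out[j:j + 1]
            if c:
                o = torch.flip(o, dims=c)
            prediction = o if prediction is None else prediction + o
    return prediction / len(combos)
```

With three mirror axes there are 8 combinations, which is why the benchmarks above top out at TTA batch size 8; with two axes there are 4, and with one axis there are 2.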
Hi, thanks for the contribution + extensive benchmarking! That helps a lot in seeing the value! If you would like us to include this, please make it an optional parameter people can set when calling nnUNetv2_predict. This should also (like all the other parameters) be set in the init of the nnUNetPredictor class. The reason I want this to be optional is twofold:
- sometimes we just don't have the VRAM to justify doing that
- in case of limited VRAM, there are other VRAM-hungry features (perform_everything_on_device) that are more impactful for inference throughput and should be prioritized over batching TTA (see the sketch below)
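For reference, a minimal sketch of that prioritization (assuming the proposed use_batch_tta flag from this PR is available; the exact nnUNetPredictor init arguments vary between nnU-Net versions):

```python
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

# On a VRAM-limited GPU: keep all inference steps on the device first
# (the bigger throughput win), and leave batched TTA off until memory allows.
predictor = nnUNetPredictor(
    perform_everything_on_device=True,  # prioritize this feature
    use_batch_tta=False,                # proposed flag from this PR
)
```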
Best, Fabian
Hi Fabian,
Thanks for your feedback! Based on your suggestions, I have made `use_batch_tta` an optional parameter in the init of the nnUNetPredictor class, which can be controlled via the parser argument `disable_batch_tta`. This allows users to opt in or out of batch TTA depending on their VRAM capacity and priorities.
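Roughly, the wiring looks like this (a simplified sketch rather than the actual diff; the real class and CLI have many more parameters):

```python
import argparse

class nnUNetPredictor:
    """Simplified stand-in showing only the new flag."""

    def __init__(self, use_mirroring: bool = True, use_batch_tta: bool = True):
        # use_batch_tta only has an effect when mirroring (TTA) is active.
        self.use_mirroring = use_mirroring
        self.use_batch_tta = use_batch_tta

# In the nnUNetv2_predict entry point (simplified): batch TTA is on by default
# and can be switched off, following the opt-out style of flags like --disable_tta.
parser = argparse.ArgumentParser()
parser.add_argument('--disable_batch_tta', action='store_true',
                    help='Disable batched TTA to lower VRAM usage during inference.')
args = parser.parse_args()

predictor = nnUNetPredictor(use_batch_tta=not args.disable_batch_tta)
```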
Please let me know if further adjustments are required.
Best, Pengcheng