
Adding a new `nnUNetTrainer` variant for extensive data augmentations on `GPU`

Open NathanMolinier opened this issue 3 months ago • 5 comments

Description

Hello nnUNet Community,

I am currently working on a project exploring how data augmentations can improve model performance for 3D segmentation tasks on MRI images, using nnUNet as a baseline for my experiments.

As part of this effort, I introduced a new trainer called nnUNetTrainerDAExt (shared in another issue, and potentially a future PR), which incorporates a broader range of augmentations and has shown promising improvements during inference.

However, the main drawback of this trainer—as with other augmentation-based trainers in the variants folder—is that augmentations are performed on the CPU. This creates a significant bottleneck, especially when combining multiple augmentations, while the GPUs remain underutilized. For instance, with nnUNetTrainerDAExt, depending on the number of CPU cores available (since transforms use multiprocessing), one epoch can take anywhere between 700 seconds and over 3000 seconds—leading to training times that can stretch beyond a month.

To address this issue, I developed a new GPU-based augmentation trainer. In my tests, this trainer achieved a 3× reduction in epoch time compared to the CPU-based version (from ~700s down to ~230s on my cluster).
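The "beyond a month" figure follows from simple arithmetic. This is a back-of-the-envelope check that assumes nnUNet's default of 1000 training epochs (`num_epochs` in `nnUNetTrainer`) and a constant duration per epoch. Runs with a custom epoch count will differ.

```python
# Back-of-the-envelope training time.
# Assumptions: nnUNet's default of 1000 epochs and a constant duration per epoch.
SECONDS_PER_DAY = 24 * 60 * 60


def training_days(seconds_per_epoch, num_epochs=1000):
    """Total wall-clock training time, in days, for a given epoch duration."""
    return seconds_per_epoch * num_epochs / SECONDS_PER_DAY


# Epoch times quoted in the issue (seconds per epoch).
for label, seconds in [("CPU DA, 700 s/epoch", 700),
                       ("CPU DA, 3000 s/epoch", 3000),
                       ("GPU DA, 230 s/epoch", 230)]:
    print(f"{label}: {training_days(seconds):.1f} days")

print(f"speed-up: {700 / 230:.1f}x")
```

This prints about 8.1, 34.7 and 2.7 days, and a 3.0x speed-up. Only the 3000 s/epoch case crosses the one-month mark.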

At the moment, this trainer doesn't include as many augmentation options as nnUNetTrainerDAExt. I am actively working on expanding it, either by implementing new transforms or by integrating existing ones from libraries such as Kornia or torchvision. Admittedly, 3D transforms are harder to find than 2D ones.

If anyone is interested in contributing to the development of this trainer, your help would be greatly appreciated!

Related issues

  • https://github.com/MIC-DKFZ/nnUNet/issues/1453

NathanMolinier, Sep 25 '25 15:09

I am very interested in this issue as well, because I think moving augmentation to the GPU makes a lot of sense. However, I guess we need to keep at least one transform on the CPU that does a random crop, to minimize CPU-GPU transfer. Otherwise we would transfer the whole volume rather than a patch, right? From what I can see, the cropping in batchgenerators is done at the same time as all the spatial augmentations. To minimize the overhead of the grid_sample function on the CPU, we probably want an explicit, faster crop transform that crops a slightly bigger patch than the one the user asks for.
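For the "slightly bigger patch" idea, the crop only has to be large enough that the final patch still fits after spatial augmentation. Here is a minimal sketch for the simplest case only: 2D rotation, with the final patch assumed to be center-cropped after rotating. Under those assumptions, the initial crop must cover the axis-aligned bounding box of the rotated final patch. nnUNet derives its own `initial_patch_size` internally, and that also handles 3D and scaling. The helper names here are hypothetical and this is not nnUNet's code.

```python
import math


def rotated_extent(a, b, max_angle):
    """Largest extent, along the axis of side `a`, of an a-by-b rectangle rotated
    by any angle in [-max_angle, max_angle] (radians, 0 <= max_angle <= pi/2).

    The extent at angle t is a*cos(t) + b*sin(|t|). It grows until t = atan(b/a)
    and shrinks afterwards, so the maximum is at min(max_angle, atan(b/a)).
    """
    t = min(max_angle, math.atan2(b, a))
    return a * math.cos(t) + b * math.sin(t)


def initial_patch_size_2d(final_size, max_angle):
    """Hypothetical helper: size of the patch to crop on the CPU so that a
    final_size patch, rotated by up to max_angle, still lies fully inside it."""
    h, w = final_size
    return (math.ceil(rotated_extent(h, w, max_angle)),
            math.ceil(rotated_extent(w, h, max_angle)))


print(initial_patch_size_2d((192, 192), math.radians(30)))
```

For a 192x192 final patch and rotations of up to 30 degrees, this gives (263, 263). Only that crop would need to cross the CPU-GPU boundary, not the whole volume.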

An alternative, I think, is simply to rewrite grid_sample on the CPU side to use multiple threads. See https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/GridSampler.cpp#L41: this function is not parallelized over voxels.
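A cheaper variant of that idea can be done from Python rather than by patching the ATen kernel. Note that this is a swapped-in workaround, not the rewrite proposed above. The minimal sketch below splits the batch into chunks and calls `F.grid_sample` on each chunk from a thread pool. It only parallelizes over the batch dimension. It also only helps if PyTorch releases the GIL inside the kernel (its operator bindings generally do) and is not already saturating the cores with its own intra-op threads. `threaded_grid_sample` is a hypothetical name.

```python
from concurrent.futures import ThreadPoolExecutor

import torch
import torch.nn.functional as F


def threaded_grid_sample(inp, grid, num_threads=4, **kwargs):
    """Run F.grid_sample on batch chunks in parallel threads (CPU workaround).

    Parallelizes over the batch dimension only; the per-voxel loop inside the
    ATen kernel is unchanged.
    """
    inp_chunks = torch.chunk(inp, num_threads, dim=0)
    grid_chunks = torch.chunk(grid, num_threads, dim=0)  # same split as inp
    with ThreadPoolExecutor(max_workers=len(inp_chunks)) as pool:
        outs = list(pool.map(lambda pair: F.grid_sample(pair[0], pair[1], **kwargs),
                             zip(inp_chunks, grid_chunks)))
    return torch.cat(outs, dim=0)


# Usage: a random 3D batch (B, C, D, H, W) and an identity sampling grid.
x = torch.randn(8, 1, 16, 16, 16)
theta = torch.eye(3, 4).unsqueeze(0).repeat(8, 1, 1)
grid = F.affine_grid(theta, list(x.shape), align_corners=False)
y = threaded_grid_sample(x, grid, num_threads=4, mode="bilinear", align_corners=False)
```

Whether this beats a single call depends on the PyTorch build and the core count, so it is worth benchmarking against plain `F.grid_sample` with the same arguments.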

etienne87, Nov 12 '25 09:11

Hey @etienne87! Right now I keep the patch created by the data loader and do not transfer the full volume. All my GPU transformations are then applied to the patches, and I ensure that they all have the same shape after the augmentations run. For example, after cropping, I pad back to the original patch shape.
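The crop-then-pad step keeps every batch at the same patch shape. Here is a minimal sketch in PyTorch, not the trainer's actual code. It assumes (B, C, D, H, W) tensors, one random crop box shared by the whole batch and by data/seg, and centered zero padding. The thread does not say how the padding is placed. `random_crop_then_pad` and `seg_pad_value` are hypothetical names.

```python
import torch
import torch.nn.functional as F


def random_crop_then_pad(data, seg, crop_size, seg_pad_value=0):
    """Randomly crop a (B, C, D, H, W) batch to crop_size, then pad back to the
    original spatial shape so the network always sees the same patch shape.

    Hypothetical sketch: one crop box is shared by the whole batch and by data/seg;
    padding is centered, zeros for data and seg_pad_value for labels (assumption).
    """
    spatial = data.shape[2:]
    # Random start per axis so the crop box fits inside the patch.
    starts = [int(torch.randint(0, s - c + 1, (1,))) for s, c in zip(spatial, crop_size)]
    box = tuple(slice(st, st + c) for st, c in zip(starts, crop_size))
    data_c, seg_c = data[(..., *box)], seg[(..., *box)]

    # F.pad takes (before, after) pairs starting from the LAST dimension.
    pad = []
    for s, c in reversed(list(zip(spatial, crop_size))):
        total = s - c
        pad += [total // 2, total - total // 2]
    return F.pad(data_c, pad), F.pad(seg_c, pad, value=seg_pad_value)
```

Drawing one box per sample would need a loop or gather-based indexing. The shared box keeps the sketch short.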

Feel free to have a look at my implementation and maybe give it a try. I recently updated the README to explain how to run the trainer.

NathanMolinier, Nov 12 '25 18:11

Ah, right! I did not realize there was already cropping happening inside nnUNetDataLoader with the "initial_patch_size" and "final_patch_size" arguments. I will try your variant and benchmark it against the naive solution of accelerating grid_sample. Just so I understand: why do we need to develop new GPU augmentations? Isn't batchgenerators-v2 already using PyTorch?

etienne87, Nov 13 '25 10:11

why do we need the development of new GPU augmentations? isn't batchgenerators-v2 already using pytorch?

You're absolutely right: batchgenerators-v2 does use PyTorch. However, the transforms are run by the data loader, which currently runs only on the CPU. As a result they never execute on the GPU, and we lose the benefit of parallel GPU computation.

In my case, rather than completely rewriting the data loader, I minimized its workload, keeping it mainly responsible for extracting patches from the data. I then moved all augmentations to run right before the patches are fed into the network. As I mentioned in my original post, when using extensive data augmentation, relying solely on the CPU quickly becomes a major bottleneck that significantly slows down training.

Also, because the transforms were designed to run inside the data loader, they operate on individual samples and do not account for the additional batch dimension.
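Inside the data loader, a transform sees one sample at a time. Once augmentation moves to the train step, it receives a whole (B, C, D, H, W) batch, so random parameters have to be drawn per sample rather than once per call. The minimal sketch below uses mirroring, chosen only because it is the simplest spatial transform; it is not batchgenerators code. It draws one coin flip per sample and axis and applies the same decision to image and segmentation. It assumes `seg` has the same number of dimensions as `data`, and `random_flip_batch` is a hypothetical name.

```python
import torch


def random_flip_batch(data, seg, p=0.5, axes=(2, 3, 4)):
    """Mirror each sample of a (B, C, D, H, W) batch independently along each axis.

    One random decision is drawn per sample and per axis, on the batch's device,
    and the same decision is applied to the image and its segmentation.
    """
    b = data.shape[0]
    for axis in axes:
        flip = torch.rand(b, device=data.device) < p       # (B,) booleans
        mask = flip.view(b, *([1] * (data.dim() - 1)))     # broadcastable to (B, C, D, H, W)
        data = torch.where(mask, torch.flip(data, dims=(axis,)), data)
        seg = torch.where(mask, torch.flip(seg, dims=(axis,)), seg)
    return data, seg
```

Computing the flipped tensor for the whole batch and selecting with `torch.where` avoids a Python loop over samples. That kind of data-parallel formulation suits the GPU well.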

NathanMolinier, Nov 13 '25 15:11

I wanted to reproduce your idea on a fork of nnUNet here

I also forked batchgenerators-v2 and fixed the spatial augmentation code to take the device into account.
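Below is a minimal sketch of a batched, device-aware spatial transform in plain PyTorch (`F.affine_grid` + `F.grid_sample`). It is not the code from the batchgenerators-v2 fork. Every tensor it creates is allocated on `data.device`, and each sample gets its own rotation angle. Assumptions: (B, C, D, H, W) tensors, rotation only in the H-W plane, and zero padding outside the patch. Because the rotation is applied in normalized [-1, 1] coordinates, non-square H x W patches are slightly distorted, so a real transform would correct for the aspect ratio. `random_rotate_batch` is a hypothetical name.

```python
import math

import torch
import torch.nn.functional as F


def random_rotate_batch(data, seg, max_angle=math.pi / 6):
    """Rotate each (B, C, D, H, W) sample by its own random angle in the H-W plane.

    Angles, affine matrices and the sampling grid are all created on data.device,
    so the transform runs wherever the batch lives (CPU or GPU).
    """
    b = data.shape[0]
    angles = (torch.rand(b, device=data.device) * 2 - 1) * max_angle
    cos, sin = torch.cos(angles), torch.sin(angles)
    theta = torch.zeros(b, 3, 4, device=data.device, dtype=data.dtype)
    theta[:, 0, 0], theta[:, 0, 1] = cos, -sin   # grid x (W axis)
    theta[:, 1, 0], theta[:, 1, 1] = sin, cos    # grid y (H axis)
    theta[:, 2, 2] = 1.0                         # grid z (D axis) unchanged
    grid = F.affine_grid(theta, list(data.shape), align_corners=False)
    data_r = F.grid_sample(data, grid, mode="bilinear", align_corners=False)
    # Labels must not be interpolated: nearest neighbour, then back to the label dtype.
    seg_r = F.grid_sample(seg.float(), grid, mode="nearest", align_corners=False)
    return data_r, seg_r.to(seg.dtype)
```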

I get good timings, which I think validates the approach (I am testing on a dataset with 300x300x300 volumes and 4 labels):

  • 0.204 s for getting a batch (data loading + crop_and_pad_nd)
  • 0.034 s for GPU augmentation (move to device + train transforms, only spatial for now)
  • 0.232 s for the train step
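When timing GPU work such as the 0.034 s augmentation step, keep in mind that CUDA kernels are launched asynchronously. A timer read without `torch.cuda.synchronize()` can therefore report launch time rather than execution time. The thread does not say how these numbers were measured. The helper below is one way to do it; `timed` is a hypothetical name.

```python
import time

import torch


def timed(fn, *args, device=torch.device("cpu")):
    """Call fn(*args) and return (result, elapsed seconds).

    On CUDA, synchronize before and after so the timer covers kernel execution,
    not just kernel launch. On CPU, no synchronization is needed.
    """
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    start = time.perf_counter()
    out = fn(*args)
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    return out, time.perf_counter() - start


# Usage: time a flip on whatever device the batch lives on.
batch = torch.randn(2, 1, 32, 32, 32)
flipped, seconds = timed(torch.flip, batch, (2,), device=batch.device)
print(f"flip took {seconds * 1e3:.3f} ms")
```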

Right now, weirdly, getting the batch is almost the bottleneck, and I don't understand why. Setting allowed_num_processes to 0 is much better than leaving it at get_allowed_n_proc_DA(), which gives almost 2x fewer iterations/s.

I wonder why that is. I also wonder whether, with GPU transforms, we could simply "repeat" the same batch multiple times with different augmentations, at least for early epochs. This could make training even faster if data loading is a known bottleneck.
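The batch-repetition idea can be sketched as a generator that pays the loading cost once and yields several differently augmented views of the same batch. This is a hypothetical sketch; `repeat_augmented` and its parameters are illustrative names. Consecutive optimizer steps then see correlated samples, which is one reason to restrict this to early epochs.

```python
def repeat_augmented(batches, n_repeats, augment):
    """Yield each loaded batch n_repeats times, each time with a fresh augmentation.

    Loading cost is paid once per n_repeats training steps. `augment` must return
    a new object rather than modify the batch in place, otherwise the repeats
    would compound each other's augmentations.
    """
    for batch in batches:
        for _ in range(n_repeats):
            yield augment(batch)
```

For example, `list(repeat_augmented([1, 2], 3, lambda b: b * 10))` gives `[10, 10, 10, 20, 20, 20]`. In a trainer, `batches` would be the data loader and `augment` the GPU transform pipeline.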

EDIT: I found why! Inter-process communication (IPC) is the bottleneck in my case. I target a patch size of 192x192x192, which means a patch of ~300^3 before the final augmentation, which is much bigger. Since the data loader does little besides loading, I figured I could just use plain Python threads instead of multiprocessing, which lowers this runtime significantly (from 0.5 s to 0.5 ms).
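Some rough arithmetic shows why moving ~300^3 patches between processes hurts. It assumes one float32 channel per sample; real batches multiply this by the batch size and channel count, plus the segmentation. With multiprocessing, each such array has to cross a process boundary, typically by being serialized and copied. `patch_megabytes` is a hypothetical helper.

```python
import math


def patch_megabytes(shape, bytes_per_voxel=4, channels=1, batch_size=1):
    """Size in MB (10**6 bytes) of a patch batch; float32 (4 bytes) by default."""
    return math.prod(shape) * bytes_per_voxel * channels * batch_size / 1e6


pre_aug = patch_megabytes((300, 300, 300))   # patch before the final augmentation
final = patch_megabytes((192, 192, 192))     # target patch size
print(f"{pre_aug:.1f} MB vs {final:.1f} MB per sample and channel ({pre_aug / final:.1f}x)")
```

That is 108.0 MB versus about 28.3 MB, roughly 3.8x more data per sample to move than the final patch.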

I had to write a replacement for NonDetMultiThreadedAugmenter (12 processes) for GPU augmentations: a ThreadedGPUAugmenter.
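The thread-only idea can be sketched with the standard library alone. Worker threads call a loading function and push batches into a bounded `queue.Queue`, and the main thread pops them and applies the (GPU) transform. Threads share one address space, so a batch is handed over by reference instead of being pickled and copied, which removes the IPC cost. This is a hypothetical sketch, not the ThreadedGPUAugmenter from the fork; `ThreadedPrefetcher`, `load_batch` and `transform` are illustrative names. Pure-Python loading code that holds the GIL will not run in parallel, though NumPy and PyTorch release it for many operations.

```python
import queue
import threading


class ThreadedPrefetcher:
    """Hypothetical sketch: background threads load batches, the main thread transforms them."""

    def __init__(self, load_batch, transform, num_threads=4, max_prefetch=8):
        self._load_batch = load_batch            # e.g. read a case + crop a patch (CPU)
        self._transform = transform              # e.g. .to("cuda") + GPU augmentations
        self._queue = queue.Queue(maxsize=max_prefetch)
        self._stop = threading.Event()
        self._threads = [threading.Thread(target=self._worker, daemon=True)
                         for _ in range(num_threads)]
        for t in self._threads:
            t.start()

    def _worker(self):
        while not self._stop.is_set():
            batch = self._load_batch()
            # Retry with a timeout so close() is noticed even when the queue is full.
            while not self._stop.is_set():
                try:
                    self._queue.put(batch, timeout=0.1)
                    break
                except queue.Full:
                    continue

    def __iter__(self):
        return self

    def __next__(self):
        batch = self._queue.get()        # same process: no pickling, no copy
        return self._transform(batch)

    def close(self):
        self._stop.set()
        for t in self._threads:
            t.join()
```

In a trainer, `transform` would move the batch to the GPU and apply the augmentations there, so the CPU threads only do I/O and cropping.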

With this thread-only solution (no multiprocessing), I match the original CPU data-loading throughput on two use-case datasets:

  • high-resolution 3D arteries (256x256x256 patches): 4 it/s (CPU or GPU)
  • mid-resolution organs (80x160x160 patches): 15 it/s (CPU or GPU)

I have not timed a CPU-constrained scenario. My guess is that GPU-mode throughput will stay constant, while CPU-mode throughput might drop a little if the CPU is saturated by other training runs.

etienne87, Nov 14 '25 15:11