nnUNet
Inference acceleration (skimage.transform.resize is slow)
Hi @FabianIsensee, Thanks for your excellent work.
I found that resample_data_or_seg() is slow. I have tested three different interpolation methods. The resample times are below.
OS:
Windows 10, 64 GB RAM, Python 3.7, RTX1050 (2 GB)
Test data:
original:
size (1, 1747, 512, 512), spacing (0.625, 0.94140601, 0.94140601)
resampled:
size (1, 1747, 494, 494), spacing (0.625, 0.97656202, 0.97656202)
Interpolation method 1: skimage.transform.resize(): 173 s
Interpolation method 2: scipy.ndimage.zoom(): 118 s
Interpolation method 3: torch.nn.functional.interpolate(): 1.39 s
The torch interpolate function is the fastest. It works well for image data and segmentation softmax data, but not for seg data (it causes anti-aliasing artifacts at label boundaries).
So at inference we can use the torch interpolate() function to speed up resampling. But when preprocessing the training data, we should still use the skimage resize() function.
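To make the point about seg data concrete, here is a small 1D sketch (not taken from nnU-Net; the array and the sample positions are made up for illustration). Linear interpolation of an integer label map produces values that are not valid labels, while nearest-neighbour lookup only ever returns labels that were already present.

```python
import numpy as np

# Toy 1D "segmentation" with two classes, 0 and 3 (illustrative values only).
labels = np.array([0, 0, 3, 3], dtype=np.float32)
old_x = np.arange(labels.size)
new_x = np.linspace(0, labels.size - 1, 7)  # upsample from 4 to 7 samples

# Linear interpolation blends neighbouring labels at the class boundary.
linear = np.interp(new_x, old_x, labels)

# Nearest-neighbour lookup can only return labels that already exist.
nearest = labels[np.rint(new_x).astype(int)]

print("linear :", linear)
print("nearest:", nearest)
```

The linear result contains an in-between value at the 0/3 boundary, which does not correspond to any class. This is why a label map needs either nearest-neighbour resampling or a per-class scheme, while image and softmax data tolerate linear blending.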
My code is below (preprocessing.py, resample_data_or_seg()):
# needed at the top of preprocessing.py
import time
import torch
import torch.nn.functional as F

if np.any(shape != new_shape):
    if do_separate_z:
        ... ...
    else:
        reshaped = []
        ### add torch interpolate function ####
        ### if is seg data, use skimage resize() ###
        ### else use torch interpolate() ###
        use_pytorch_interpolate = not is_seg
        if use_pytorch_interpolate:
            print("no separate z, interpolate mode trilinear")
            time_start = time.time()
            # data is (c, x, y, z); F.interpolate expects (batch, c, x, y, z)
            data_torch = torch.from_numpy(data).to(torch.float32)
            data_torch = torch.unsqueeze(data_torch, 0)
            del data  # drop the local reference to the numpy copy
            new_size = tuple(new_shape.tolist())
            print("new size", new_size)
            reshaped_final_data = F.interpolate(data_torch, size=new_size, mode='trilinear', align_corners=False)
            reshaped_final_data = torch.squeeze(reshaped_final_data, 0)
            reshaped_final_data = reshaped_final_data.numpy()
            time_end = time.time()
            print("torch resample time:", time_end - time_start)
        else:
            print("no separate z, mode", order)
            time_start = time.time()
            for c in range(data.shape[0]):
                reshaped.append(resize_fn(data[c], new_shape, order, cval=cval, **kwargs)[None])
            reshaped_final_data = np.vstack(reshaped)
            time_end = time.time()
            print("resample time:", time_end - time_start)
    return reshaped_final_data.astype(dtype_data)
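The same idea can be sketched as a standalone helper that runs outside nnU-Net. The function name resample_with_torch is hypothetical, and using mode="nearest" for seg data is an assumption of this sketch; the post above keeps skimage for segmentations instead.

```python
import numpy as np
import torch
import torch.nn.functional as F


def resample_with_torch(data, new_shape, is_seg=False):
    """Resample a (c, x, y, z) numpy array to the spatial size new_shape.

    Hypothetical helper, not part of nnU-Net.
    """
    dtype_in = data.dtype
    # F.interpolate expects (batch, channels, x, y, z), so add a batch axis.
    tensor = torch.from_numpy(np.ascontiguousarray(data)).to(torch.float32)[None]
    size = tuple(int(s) for s in new_shape)
    if is_seg:
        # Assumption of this sketch: nearest neighbour keeps labels intact.
        out = F.interpolate(tensor, size=size, mode="nearest")
    else:
        out = F.interpolate(tensor, size=size, mode="trilinear", align_corners=False)
    # Remove the batch axis and restore the input dtype.
    return out[0].numpy().astype(dtype_in)


image = np.random.rand(1, 6, 8, 10).astype(np.float32)
print(resample_with_torch(image, (9, 7, 12)).shape)
```

Trilinear interpolation is first order, so the result will not match a third-order spline resize exactly.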
Hey, the problem with torch.nn.functional.interpolate is that it only allows linear or nearest neighbor resampling, but we need third-order spline interpolation. I would therefore like to keep things as they are. If you know of any other library that can do 3D spline interpolation and is faster than skimage, please let me know :-) skimage's resize just calls map_coordinates from scipy, just like zoom does, so in theory these should be the same (all else being equal).
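For reference, a resize built directly on scipy's map_coordinates might look like the sketch below. The helper name resize_spline and the half-voxel-centre coordinate mapping are assumptions of this sketch; the exact convention inside skimage may differ between versions.

```python
import numpy as np
from scipy.ndimage import map_coordinates


def resize_spline(volume, new_shape, order=3):
    """Resample a 3D array to new_shape with spline interpolation.

    Hypothetical helper for illustration only.
    """
    new_shape = tuple(int(s) for s in new_shape)
    # Map every output voxel centre back into input index space
    # (half-voxel-centre convention, an assumption of this sketch).
    axes = [
        (np.arange(n) + 0.5) * (o / n) - 0.5
        for n, o in zip(new_shape, volume.shape)
    ]
    coords = np.meshgrid(*axes, indexing="ij")
    # order=3 is the cubic spline; order=1 is linear; order=0 is nearest.
    return map_coordinates(volume, coords, order=order, mode="nearest")


vol = np.random.rand(6, 8, 10)
print(resize_spline(vol, (9, 7, 12)).shape)
```

Switching order between 1 and 3 on the same volume is a quick way to see how much of the runtime comes from the spline order rather than from the library wrapper.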
Hello Fabian,
Great work, thanks!
Have you ever considered using cucim.skimage for resampling? cucim: https://github.com/rapidsai/cucim
It is a GPU-based implementation of the scikit-image API using cupy, so it comes with the same features as skimage's resize.
from cucim.skimage.transform import resize
import cupy as cp

def run_cucim(img, target_shape, order):
    img = cp.asarray(img)
    resampled_img = resize(img, output_shape=target_shape, order=order, mode="edge", anti_aliasing=False)
    resampled_img = cp.asnumpy(resampled_img)  # Alternative: resampled_img = cp.float32(resampled_img.get())
    return resampled_img
Here are some benchmark values I got for a CT with size (512, 512, 768) resampled to size (520, 520, 1024):
- cucim.resize: 0.85 (0.09) s [on GPU]
- SimpleITK.resample: 7.65 (0.12) s [on CPU with sitk.ProcessObject.SetGlobalDefaultNumberOfThreads(MAX_THREADS)]
- skimage.resize: 54.85 (0.77) s [default as I could not find any multithreading option]
- scipy.zoom: 61.59 (0.50) s [default as I could not find any multithreading option]
Values are mean (std) for n=10 runs
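A timing harness that reports numbers in this mean (std) form could look like the sketch below. It is not the script used for the values above; the helper name benchmark is hypothetical, and the use of the sample standard deviation is an assumption.

```python
import statistics
import time


def benchmark(fn, *args, n_runs=10, **kwargs):
    """Return (mean, std) wall-clock seconds of fn(*args, **kwargs).

    Hypothetical helper; n_runs must be at least 2 for the sample std.
    """
    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        fn(*args, **kwargs)
        times.append(time.perf_counter() - start)
    return statistics.mean(times), statistics.stdev(times)


mean_s, std_s = benchmark(sorted, list(range(100_000)), n_runs=10)
print(f"{mean_s:.4f} ({std_s:.4f}) s")
```

When timing GPU code, the measured function should include the copy back to host memory (as run_cucim does with cp.asnumpy), otherwise asynchronous kernel launches can make the GPU look faster than it is.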
SimpleITK might also be an option as it does support spline interpolation.
Hi @dhaberl thanks for pointing this out! That seems very interesting and warrants a closer look! There are some things I would need to consider:
- so far preprocessing does not require a GPU and I am not sure whether I want to include this constraint
- running the resampling code with multiprocessing is not strictly necessary as we use multiple workers during preprocessing, each processing one image. So the throughput should be the same
- currently, preprocessing and segmentation export during inference are handled via background workers. So as long as the GPU is the bottleneck (which it usually is) skimage is actually not a problem
- CUDA resampling only works as long as there is enough GPU memory. This could get really annoying really quickly
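The GPU memory concern in the last point can be checked with a back-of-the-envelope estimate. The sketch below uses hypothetical helper names and assumes float32 voxels. It counts only the input and output arrays, so it is a lower bound; any workspace the resampling library allocates comes on top.

```python
import math


def min_resample_bytes(in_shape, out_shape, bytes_per_voxel=4):
    """Lower bound on memory needed to hold the input and output volumes."""
    return (math.prod(in_shape) + math.prod(out_shape)) * bytes_per_voxel


def fits_on_gpu(in_shape, out_shape, free_bytes, bytes_per_voxel=4):
    """True if input plus output fit into free_bytes (library workspace ignored)."""
    return min_resample_bytes(in_shape, out_shape, bytes_per_voxel) <= free_bytes


# Shapes from the first post in this thread, on a 2 GiB card.
in_shape, out_shape = (1747, 512, 512), (1747, 494, 494)
print(min_resample_bytes(in_shape, out_shape) / 1e9, "GB needed at minimum")
print(fits_on_gpu(in_shape, out_shape, free_bytes=2 * 1024**3))
```

For the volume from the first post this comes to roughly 3.5 GB, more than the 2 GB card mentioned there, so a GPU path would need a CPU fallback or chunked processing for cases like this.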