
Inference acceleration (skimage.transform.resize is slow)

BigPandaCPU opened this issue on Jun 29, 2022 · 3 comments

Hi @FabianIsensee, Thanks for your excellent work.

I found that resample_data_or_seg() is slow, so I tested three different interpolation methods. The resample times are below.

OS / hardware:

    Windows 10, 64 GB RAM, Python 3.7, RTX 1050 (2 GB)

Test data:

    original:   size (1, 1747, 512, 512),  spacing (0.625, 0.94140601, 0.94140601)
    resampled:  size (1, 1747, 494, 494),  spacing (0.625, 0.97656202, 0.97656202)
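
(For reference, the resampled in-plane size follows from the spacing change: new_size = old_size × old_spacing / new_spacing, e.g. 512 × 0.94140601 / 0.97656202 ≈ 493.6, which rounds to 494.)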

Interpolation method 1: skimage.transform.resize(): 173 s
Interpolation method 2: scipy.ndimage.zoom(): 118 s
Interpolation method 3: torch.nn.functional.interpolate(): 1.39 s
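
For reference, a minimal sketch of how such a comparison could be timed (my reconstruction, not the original benchmark script; the synthetic volume matches the test data shape above):

import time
import numpy as np
import torch
import torch.nn.functional as F
from scipy.ndimage import zoom
from skimage.transform import resize

data = np.random.rand(1747, 512, 512).astype(np.float32)  # synthetic CT-sized volume
new_shape = (1747, 494, 494)

t0 = time.time()
resize(data, new_shape, order=3, mode="edge", anti_aliasing=False)
print("skimage.transform.resize:", time.time() - t0, "s")

t0 = time.time()
zoom(data, [n / o for n, o in zip(new_shape, data.shape)], order=3)
print("scipy.ndimage.zoom:", time.time() - t0, "s")

t0 = time.time()
x = torch.from_numpy(data)[None, None]  # trilinear mode needs a 5D (N, C, D, H, W) tensor
F.interpolate(x, size=new_shape, mode="trilinear", align_corners=False)
print("torch.nn.functional.interpolate:", time.time() - t0, "s")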

torch's interpolate() is by far the fastest, and it works well for image data and segmentation softmax data, but not for integer segmentation maps: linear interpolation blurs the label boundaries (an anti-aliasing effect) and produces invalid in-between values, as the example below shows.

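A tiny illustration (my addition, not from the original post) of why linear interpolation corrupts label maps:

import torch
import torch.nn.functional as F

# A 1x1x1x4 "segmentation" containing only the labels 0 and 2
seg = torch.tensor([[[[0., 0., 2., 2.]]]])
up = F.interpolate(seg, size=(1, 8), mode="bilinear", align_corners=False)
print(up)
# In-between values such as 0.5 and 1.5 appear at the label boundary;
# rounding them yields the label 1, which never existed in the input.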

So during inference we can use torch's interpolate() to speed up resampling, but when preprocessing the training data we should still use skimage's resize().

My code is below (preprocessing.py, resample_data_or_seg()):

    # (Assumes time, numpy as np, torch, and torch.nn.functional as F are
    # imported at the top of preprocessing.py.)
    if np.any(shape != new_shape):
        if do_separate_z:
            ...  # separate-z branch unchanged
        else:
            # Use torch's interpolate() for image data; fall back to skimage's
            # resize() for segmentation maps, where linear interpolation would
            # corrupt the labels.
            use_pytorch_interpolate = not is_seg

            if use_pytorch_interpolate:
                print("no separate z, interpolate mode trilinear")
                time_start = time.time()
                # trilinear mode needs a 5D (N, C, D, H, W) input tensor
                data_torch = torch.from_numpy(data).to(torch.float32)
                data_torch = torch.unsqueeze(data_torch, 0)
                del data
                new_size = tuple(new_shape.tolist())
                print("new size", new_size)
                reshaped_final_data = F.interpolate(data_torch, size=new_size,
                                                    mode='trilinear', align_corners=False)
                reshaped_final_data = torch.squeeze(reshaped_final_data, 0).numpy()
                time_end = time.time()
                print("torch resize time:", time_end - time_start)
            else:
                print("no separate z, order", order)
                time_start = time.time()
                reshaped = []
                for c in range(data.shape[0]):
                    reshaped.append(resize_fn(data[c], new_shape, order,
                                              cval=cval, **kwargs)[None])
                reshaped_final_data = np.vstack(reshaped)
                time_end = time.time()
                print("resample time:", time_end - time_start)
        return reshaped_final_data.astype(dtype_data)
  

BigPandaCPU · Jun 29, 2022

Hey, the problem with torch.nn.functional.interpolate is that it only offers linear or nearest-neighbor resampling, but we need third-order splines. I would therefore like to keep things as they are. If you know any other library that can do 3D spline interpolation and is faster than skimage, please let me know :-) skimage's resize just calls scipy's map_coordinates, exactly like zoom does, so in theory the two should perform the same (given everything else is equal).
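
A quick check of that equivalence claim (my sketch, not from the thread; the two functions use slightly different defaults, so the grid conventions have to be matched explicitly):

import numpy as np
from scipy.ndimage import zoom
from skimage.transform import resize

img = np.random.rand(64, 64, 64)
new_shape = (60, 60, 60)
factors = [n / o for n, o in zip(new_shape, img.shape)]

a = resize(img, new_shape, order=3, mode="edge", anti_aliasing=False)
b = zoom(img, factors, order=3, mode="nearest", grid_mode=True)
print(np.abs(a - b).max())  # should be very close to zero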

FabianIsensee · Aug 23, 2022

Hello Fabian,

Great work, thanks!

Have you ever considered using cucim.skimage for resampling? cucim: https://github.com/rapidsai/cucim

It is a GPU-based implementation of the scikit-image API using CuPy, so it comes with the same features as skimage's resize.

from cucim.skimage.transform import resize
import cupy as cp

def run_cucim(img, target_shape, order):
    img = cp.asarray(img)  # move the array to the GPU
    resampled_img = resize(img, output_shape=target_shape, order=order,
                           mode="edge", anti_aliasing=False)
    # Copy the result back to host memory
    resampled_img = cp.asnumpy(resampled_img)  # alternative: resampled_img = cp.float32(resampled_img.get())
    return resampled_img
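
Hypothetical usage, with shapes matching the benchmark below:

import numpy as np

img = np.random.rand(512, 512, 768).astype(np.float32)
resampled = run_cucim(img, target_shape=(520, 520, 1024), order=3)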

Here are some benchmark values I got for a CT of size (512, 512, 768) resampled to (520, 520, 1024):

  • cucim.resize: 0.85 (0.09) s [on GPU]
  • SimpleITK.resample: 7.65 (0.12) s [on CPU with sitk.ProcessObject.SetGlobalDefaultNumberOfThreads(MAX_THREADS)]
  • skimage.resize: 54.85 (0.77) s [defaults; I could not find any multithreading option]
  • scipy.zoom: 61.59 (0.50) s [defaults; I could not find any multithreading option]

Values are mean (std) over n=10 runs.

SimpleITK might also be an option, as it supports spline interpolation.
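
A rough sketch (my illustration, not dhaberl's code) of such a SimpleITK resampling with the cubic B-spline interpolator:

import SimpleITK as sitk

def run_sitk(img: sitk.Image, target_size):
    # target_size is in SimpleITK's (x, y, z) index order
    new_spacing = [sp * sz / tsz for sp, sz, tsz in
                   zip(img.GetSpacing(), img.GetSize(), target_size)]
    return sitk.Resample(
        img,
        [int(s) for s in target_size],
        sitk.Transform(),       # identity transform
        sitk.sitkBSpline,       # cubic B-spline interpolation
        img.GetOrigin(),
        new_spacing,
        img.GetDirection(),
        0.0,                    # default value outside the image
        img.GetPixelID(),
    )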

dhaberl · Aug 25, 2022

Hi @dhaberl, thanks for pointing this out! That seems very interesting and warrants a closer look! There are some things I would need to consider:

  • so far, preprocessing does not require a GPU, and I am not sure whether I want to introduce that constraint
  • running the resampling code with multiprocessing is not strictly necessary, as we already use multiple workers during preprocessing, each handling one image, so throughput should be the same
  • currently, preprocessing and segmentation export during inference are handled by background workers, so as long as the GPU is the bottleneck (which it usually is), skimage is actually not a problem
  • CUDA resampling only works as long as there is enough GPU memory; this could get really annoying really quickly (a possible fallback is sketched below)
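
As a sketch of how that last point could be handled (my addition, not Fabian's code; it assumes the run_cucim helper from above):

import cupy as cp
from skimage.transform import resize as cpu_resize

def resize_with_fallback(img, target_shape, order):
    try:
        # GPU path (run_cucim as defined in the earlier comment)
        return run_cucim(img, target_shape, order)
    except cp.cuda.memory.OutOfMemoryError:
        # Not enough GPU memory: fall back to the CPU implementation
        return cpu_resize(img, target_shape, order=order,
                          mode="edge", anti_aliasing=False)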

FabianIsensee · Aug 31, 2022