torchkbnufft
KB-NUFFT hangs when used in PyTorch DataLoader with num_workers > 0
Hi,
I'm trying to perform on-the-fly data undersampling in my PyTorch dataset. To do this, I perform a Toeplitz NUFFT in the __getitem__ function of my Dataset class. This works as expected. Now I want to do batching, so I wrap the PyTorch Dataset in a PyTorch DataLoader. This works as expected when num_workers=0. However, when num_workers is non-zero, computation of the NUFFT seemingly enters an infinite loop.
Expected behaviour
Performing a NUFFT in parallel using multiple workers should result in undersampled images.
Observed behaviour
Sampling the dataloader results in a hanging script, seemingly entering an infinite loop.
Extra information
- A minimally-working example of this behaviour is attached below.
- This behaviour is observed with Torch-kbnufft version 1.3.0 and PyTorch version 1.12.
- CUDA is not used in this example, but it also happens when the NUFFT is computed on the GPU.
- It is not limited to a Toeplitz NUFFT but also happens with the table NUFFT.
- Density compensation has no influence on the observed behaviour.
Minimal example
import torch
from skimage.data import shepp_logan_phantom
from skimage.transform import rescale
import numpy as np
import torchkbnufft as tkbn
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
NUM_WORKERS=0 # set to > 0 to hang
class DataUndersampler(Dataset):
    def __init__(self, undersampling_factor, spatial_scaling_factor):
        self.image = shepp_logan_phantom()
        self.image_shape = self.image.shape
        self.undersampling_factor = undersampling_factor
        self.spatial_scaling_factor = spatial_scaling_factor
        # Need the original size to determine the undersampling factor and NUFFT operators
        self.orig_size = self.image_shape
        self.image_shape = (self.image_shape[0] // self.spatial_scaling_factor, self.image_shape[1] // self.spatial_scaling_factor)
        # Create an oversampled grid
        spokelength = self.image_shape[0] * 2
        self.grid_size = (spokelength, spokelength)
        # Generate a lot of spokes; a starting spoke is picked at random in __getitem__
        nspokes = 2000
        # Sample enough spokes to achieve the undersampling factor
        self.spokes_to_sample = int((self.orig_size[0] * np.pi / 2) / self.undersampling_factor)
        # Generate a golden-angle radial trajectory
        ga = np.deg2rad(180 / ((1 + np.sqrt(5)) / 2))
        kx = np.zeros(shape=(spokelength, nspokes))
        ky = np.zeros(shape=(spokelength, nspokes))
        ky[:, 0] = np.linspace(-np.pi, np.pi, spokelength)
        for i in range(1, nspokes):
            kx[:, i] = np.cos(ga) * kx[:, i - 1] - np.sin(ga) * ky[:, i - 1]
            ky[:, i] = np.sin(ga) * kx[:, i - 1] + np.cos(ga) * ky[:, i - 1]
        self.ky = np.transpose(ky)
        self.kx = np.transpose(kx)
        # 1D Ram-Lak filter, needed for density compensation. Density is a linear function
        # of the distance to the center of k-space...
        ram_lak = np.abs(np.linspace(-1, 1, spokelength + 1))
        ram_lak = ram_lak[:-1]
        # ... except for the center, which we know exactly how often we sample,
        # namely as many times as the number of spokes
        middle_idx = len(ram_lak) // 2
        ram_lak[middle_idx] = 1 / (2 * self.spokes_to_sample)
        self.ram_lak = ram_lak

    def __len__(self):
        return 1

    def __getitem__(self, index):
        if self.spatial_scaling_factor > 1:
            img = rescale(self.image, 1 / self.spatial_scaling_factor).astype(complex)
        else:
            img = self.image.astype(complex)
        if self.undersampling_factor > 1:
            img_tensor = torch.from_numpy(img).unsqueeze(0).unsqueeze(0)
            toep_ob = tkbn.ToepNufft()
            # Pick a random starting spoke
            offset = np.random.choice(range(self.ky.shape[0] - self.spokes_to_sample))
            # and select as many subsequent spokes as needed to reach the desired undersampling factor
            # todo: use continuing trajectories for cine?
            selected_ky = self.ky[offset:offset + self.spokes_to_sample].flatten()
            selected_kx = self.kx[offset:offset + self.spokes_to_sample].flatten()
            ktraj = torch.tensor(np.stack((selected_ky, selected_kx), axis=0))
            # Repeat and reshape the Ram-Lak filter so every spoke is density-compensated
            ram_lak_t = torch.from_numpy(np.tile(self.ram_lak, self.spokes_to_sample)).unsqueeze(0).unsqueeze(0)
            # Calculate the efficient Toeplitz kernel used to compute the NUFFT (with density compensation)
            dcomp_kernel = tkbn.calc_toeplitz_kernel(ktraj, self.image_shape, weights=ram_lak_t, norm='ortho', numpoints=(3, 3))
            # And in a single step, go to radial k-space and back to image space
            img = toep_ob(img_tensor, dcomp_kernel, norm='ortho').abs().squeeze().numpy()
            # Renormalize the output, because undersampling can change the scale
            img /= np.max(img)
        return np.abs(img)
# Create dataset
dset = DataUndersampler(
    undersampling_factor=2,
    spatial_scaling_factor=1)
# this works fine
undersampled_img = dset[0]
# From here, the observed behaviour emerges.
dloader = DataLoader(dset, shuffle=False, batch_size=1, num_workers=NUM_WORKERS)
# this statement hangs when num_workers > 0
undersampled_img = next(iter(dloader))
Hello @maartenterpstra, I think the issue is due to the table-based NUFFT, which is used inside tkbn.calc_toeplitz_kernel.
Does every sample have a different trajectory? If they're all the same, you could apply the NUFFT outside the dataloader.
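A rough sketch of that workaround, assuming a single fixed trajectory shared by all samples. Here ktraj, dcomp, im_size, and dset are placeholders taken from the example above, and the dataset is assumed to return complex images of shape (coil, height, width) so the collated batch has shape (batch, coil, height, width):

# Build the Toeplitz NUFFT objects once in the main process, outside the Dataset
toep_ob = tkbn.ToepNufft()
dcomp_kernel = tkbn.calc_toeplitz_kernel(ktraj, im_size, weights=dcomp, norm='ortho')

# Workers only return fully sampled images (no NUFFT inside __getitem__)
dloader = DataLoader(dset, batch_size=4, num_workers=4)

# Apply the Toeplitz NUFFT to the whole batch in the main process
for images in dloader:
    undersampled = toep_ob(images, dcomp_kernel, norm='ortho').abs()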
Hi @mmuckley. I was also thinking that as a workaround I could compute the NUFFT for a single batch outside the dataloader. In general, every sample has a different trajectory but the same number of spokes. Would this be possible?
Hello @maartenterpstra, it may be more efficient to loop over the list or use a batched NUFFT. The batched NUFFT is good for a large number of small NUFFTs. You can see how to use it here.
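For reference, a minimal sketch of how I understand the batched NUFFT to work, with one trajectory per batch element. The shapes and sizes below are made up for illustration; if a batched ktraj is not supported in your version, looping over samples with a single operator works the same way:

import math
import torch
import torchkbnufft as tkbn

im_size = (256, 256)
batch_size = 8
klength = 4000

nufft_ob = tkbn.KbNufft(im_size=im_size)
adjnufft_ob = tkbn.KbNufftAdjoint(im_size=im_size)

# One complex image and one trajectory per batch element
image = torch.randn(batch_size, 1, *im_size, dtype=torch.complex128)
ktraj = (torch.rand(batch_size, 2, klength) - 0.5) * 2 * math.pi

# Forward NUFFT: each batch element uses its own trajectory
kdata = nufft_ob(image, ktraj, norm='ortho')
# Adjoint NUFFT back to image space
recon = adjnufft_ob(kdata, ktraj, norm='ortho')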
I also opened #74 as a potential enhancement with a pointer to where the code controls threading if you'd be interested in that route.
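In the meantime, a general PyTorch pattern that sometimes helps with hangs caused by intra-op threading inside DataLoader workers is to restrict each worker process to a single thread via worker_init_fn. This is only a sketch of that general pattern, not a confirmed fix for this specific issue:

import torch
from torch.utils.data import DataLoader

def single_threaded_worker(worker_id):
    # Limit intra-op threading inside each worker process
    torch.set_num_threads(1)

dloader = DataLoader(dset, batch_size=1, num_workers=4,
                     worker_init_fn=single_threaded_worker)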