displacement field and warping label maps
I am using the Python interface to register two images:

```python
self.model(moving_image.cuda(), fixed_image.cuda())
deformed_image = self.model.warped_image_A.cpu().squeeze()
deformation = self.model.phi_AB_vectorfield.cpu().squeeze()
```
When plotting the deformed image it looks great! However, when plotting the deformation there is hardly any deformation visible, which is inconsistent with the warped image. The same holds for the warped label maps: they barely move. I am warping them with F.grid_sample() after normalizing the displacement field.
Do you have an idea why this is not working? Or how can I use the Python interface to warp the label maps similarly to the warped image result?
Thanks a lot in advance! :)
Thank you for your question! Getting the deformation field into a format with which you can use F.grid_sample isn't quite trivial. Could you post the code which generates the figures you are concerned about?
Thanks for the quick answer! :)
I am using the code below to plot the deformation field as a warped grid; there I can't see any deformation, while the warped image is clearly deformed:
```python
deformation = model.phi_AB_vectorfield.cpu().squeeze()
deformation = torch.permute(deformation, (1, 2, 3, 0))
...
# in plotting code:
ax = fig.add_subplot(3, c, r + 4)
axes = [1, 2]
# extract middle slice from displacement volume
fieldAx = deformation[..., axes].take(half_slice, axis=0)
plot_deformation_field(ax, fieldAx)
```

Plotting function for plotting the displacement field as a warped grid:

```python
def plot_deformation_field(ax: plt.Axes, disp: np.ndarray, interval: Optional[int] = 3,
                           title: Optional[str] = None, color: Optional[str] = 'cornflowerblue') -> None:
    assert disp.shape[0] == 2, "Displacement field should have shape (2, H, W)"
    # convert displacement from unit
    # disp[0, ...] = float(disp.shape[1] - 1) * disp[0, ...] / 2.0
    # disp[1, ...] = float(disp.shape[2] - 1) * disp[1, ...] / 2.0
    H = disp.shape[1]
    W = disp.shape[2]
    y, x = np.meshgrid(np.arange(H), np.arange(W), indexing='ij')
    # Create sampling grid
    sample_y = y[::interval, ::interval] - disp[0, ::interval, ::interval]
    sample_x = x[::interval, ::interval] - disp[1, ::interval, ::interval]
    # Plot deformed grid
    ax.plot(sample_x, sample_y, color=color, linewidth=0.34)
    ax.plot(sample_x.T, sample_y.T, color=color, linewidth=0.34)
    ax.set_frame_on(False)
```
Is the returned deformation in voxel units or is it normalized? I already tried denormalizing before plotting but this doesn't seem to be correct.
Best, Anna
Hi! The short answer is that phi_AB_vectorfield is already a coordinate field, not a displacement field (a coordinate field is what you get when you add a displacement field to the identity map). The best way to plot the deformation at slice N from PyTorch is:
Just the deformation field:

```python
plt.contour(phi_AB_vectorfield[0, 1, N, :, :].cpu().detach())
plt.contour(phi_AB_vectorfield[0, 2, N, :, :].cpu().detach())
```

Deformation field and warped image together:

```python
plt.imshow(warped_image_A[0, 0, N, :, :].cpu().detach())
plt.contour(phi_AB_vectorfield[0, 1, N, :, :].cpu().detach())
plt.contour(phi_AB_vectorfield[0, 2, N, :, :].cpu().detach())
```
The long answer is that the transform is natively a Python function, phi_AB, and phi_AB_vectorfield is just a cached copy of the derived value phi_AB(model.identity_map), which we already computed for the loss and often want to reuse in performance-sensitive applications.
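In code, that relationship looks roughly like this (a minimal sketch, assuming a model that has already run a forward pass):

```python
# phi_AB is the transform as a python function; phi_AB_vectorfield is that
# function evaluated at every voxel's coordinate and cached during the forward pass.
recomputed = model.phi_AB(model.identity_map)
assert torch.allclose(recomputed, model.phi_AB_vectorfield)
```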
I think what this is revealing is a hole in the documentation, so I'll write it up here and then add it to the readthedocs once I've helped you with your questions.
The basic concept needed is that in icon, every voxel in an image is associated with a coordinate in [0, 1]^3. This is conveniently accessible in model.identity_map: the voxel at image_A[0, 0, 23, 45, 67] has floating-point coordinate model.identity_map[0, :, 23, 45, 67]. This is useful internally so that we can use the same coordinates throughout the multiscale registration performed inside the neural network: the different scales have their identity_map set to match their expected image size.
With these coordinates, any tensor of shape [B, N, H, W, D] (an image or a transform!) can be interpreted via interpolation as a function from coordinates in R^3 to values in R^N, and we expose this interpretation as model.as_function(tensor), which converts a tensor into such a function.
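As a quick illustration of this convention (a hedged sketch, reusing the names above): evaluating an image-as-function at the identity map simply reproduces the image, because identity_map stores each voxel's own coordinate.

```python
image_fn = model.as_function(image_A)      # coordinates -> intensities, via interpolation
resampled = image_fn(model.identity_map)   # sample the image at every voxel's own coordinate
assert torch.allclose(resampled, image_A, atol=1e-4)
```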
Then, transforms are typically represented as Python functions that take in a tensor of shape [Batch, n_dimensions, H, W, D] and return a tensor of the same shape. Such a function can then be applied to warp an image or to compose transforms in a way that closely matches the notation used in papers:
$I^A \circ \phi^{AB}$
maps to
lambda coords: model.as_function(image_A)(phi_AB(coords))
which is a function from coordinates to intensities. Internally, we pass images and transforms around in this form so that, as they are warped with multiple transforms, we only ever resample the image once at the end, maintaining sharpness. To turn this warped image from a function back into a tensor, we evaluate it at the coordinates stored in the identity map:
model.as_function(image_A)(phi_AB(model.identity_map))
or, to save recomputing the last term,
model.as_function(image_A)(phi_AB_vectorfield)
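This is exactly how the warped image you already plotted is produced, so (as a hedged sanity check, with image_A being the moving image you passed to the model) re-evaluating it should match model.warped_image_A up to floating-point noise:

```python
rewarped = model.as_function(image_A)(model.phi_AB_vectorfield)
assert torch.allclose(rewarped, model.warped_image_A, atol=1e-5)
```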
So, to warp a label map of the same shape as the image passed to the network,
model.as_function(label_A)(phi_AB_vectorfield)
would mostly work, although it would use linear interpolation, which is not correct for label maps. We should add an argument to as_function to allow specifying nearest-neighbor interpolation (a workaround is sketched below). In general, we recommend warping label maps using the ITK interface whenever possible, as this allows the user to specify the interpolation freely, and allows warping label maps at their original resolution, respecting spacing and orientation metadata.
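In the meantime, one workaround (my own sketch, not part of the icon API; it assumes label_A is an integer label map with the same shape as image_A, on the same device as the model) is to warp each label's binary mask with the existing linear interpolation and take the argmax afterwards, which behaves like nearest-neighbor interpolation for labels:

```python
import torch

num_labels = int(label_A.max()) + 1
warped_channels = []
for k in range(num_labels):
    mask = (label_A == k).float()  # binary mask for label k, shape [B, 1, H, W, D]
    warped_channels.append(model.as_function(mask)(model.phi_AB_vectorfield))
# Stack the warped masks along the channel dimension and pick the most likely label.
warped_label_A = torch.argmax(torch.cat(warped_channels, dim=1), dim=1, keepdim=True)
```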
All of this is a lot of complexity compared to just using grid_sample, but unfortunately it's needed for internal development so that we can correctly implement the architectures and losses described in our papers.
It turns out all the necessary tools are already available in your code; they are just a bit hidden.
I assumed you wrote your 3D Slicer extension in Python, so I looked there to see how you implemented the interface between your model and ITK (as 3D Slicer uses ITK).
To convert the model's phi_AB into an ITK transform, you call itk_wrapper.create_itk_transform(phi_AB, model.identity_map, image_A, image_B); the remaining question is how to turn that into a displacement field that can be used with PyTorch. Here is how you can do that:
Do all of this after inference:

```python
import itk
import torch

phi_AB = self.model.phi_AB(self.model.identity_map)
AB = itk_wrapper.create_itk_transform(
    phi_AB, self.model.identity_map, moving_image_itk, fixed_image_itk)
```

create_itk_transform returns a composite transform (a linear transform composed with a displacement field composed with a linear transform), so we need to flatten this transform with a small helper:

```python
def flatten_with_reference_image(
    composite: "itk.TransformBase",
    reference_image: "itk.Image",
) -> "itk.DisplacementFieldTransform[itk.F, 3]":
    dim = 3
    VecPixel = itk.Vector[itk.F, dim]
    OutImage = itk.Image[VecPixel, dim]
    # Sample the composite transform onto a dense displacement field defined
    # on the reference image's grid.
    filt = itk.TransformToDisplacementFieldFilter[OutImage, itk.D].New()
    filt.SetTransform(composite)
    filt.UseReferenceImageOn()
    filt.SetReferenceImage(reference_image)
    filt.Update()
    disp_tx = itk.DisplacementFieldTransform[itk.F, dim].New()
    disp_tx.SetDisplacementField(filt.GetOutput())
    return disp_tx


AB = flatten_with_reference_image(AB, moving_image_itk)

# Now we just need to extract the displacement field array from the flattened transform.
displacement = AB.GetDisplacementField()
displacement = torch.from_numpy(itk.array_from_image(displacement))
```

That's it!
However, I'm using unit displacements and the convention used by Nibabel, so I have to transform the displacement field like so:

```python
def displacement_to_unit_displacement(displacement: torch.Tensor) -> torch.Tensor:
    """
    Convert a displacement field to a unit displacement field.

    The standard unit of displacement is a half-image, so a displacement vector of
    magnitude 2 means that the displacement distance is equal to the side length of
    the displaced image.
    """
    disp = torch.zeros_like(displacement)
    for dim in range(displacement.shape[-1]):
        disp[..., dim] = 2.0 * displacement[..., dim] / \
            float(displacement.shape[-dim - 2] - 1)
    return disp


# Reorder the axes and flip signs to match the Nibabel convention,
# then convert to unit displacements.
displacement = displacement.permute(2, 1, 0, 3).contiguous()
displacement[..., 0] = -displacement[..., 0]
displacement[..., 1] = -displacement[..., 1]
displacement = displacement_to_unit_displacement(displacement)
displacement = displacement[..., [2, 1, 0]]
```
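To close the loop back to the original question, this is roughly how the result can then be applied with F.grid_sample (a hedged sketch with my own conventions; it assumes moving is the moving image as a [1, 1, D, H, W] tensor and displacement is the [D, H, W, 3] unit-displacement tensor from above, with its channels already in grid_sample's (x, y, z) order):

```python
import torch
import torch.nn.functional as F

# Identity sampling grid in grid_sample's normalized [-1, 1] coordinates.
identity = F.affine_grid(
    torch.eye(3, 4).unsqueeze(0), size=moving.shape, align_corners=True
)
# Add the unit displacement to the identity grid to get the sampling locations.
grid = identity + displacement.unsqueeze(0)
# Use mode="nearest" for label maps and "bilinear" for intensity images.
warped = F.grid_sample(moving, grid, mode="nearest", align_corners=True)
```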