uniGradICON

Obtain Displacement Field at Original Input dimension

Open · anudeepk17 opened this issue · 7 comments

Dear Authors, thank you for this great work. I was using your model from source to register my own data. The results are great, but I want to obtain phi_AB (or net.phi_AB_vectorfield) at the original image dimensions instead of 175x175x175. I tried resampling the vector field with itk.resample_image_filter and with torch.nn.functional.interpolate, but in both cases the resulting field could not register the images the way the field at the network size does. I could not figure out how to modify phi_AB, since it is an itk.CompositeTransform object.

Could you help me understand how to get net.phi_AB_vectorfield or phi_AB at the dimensions of the input image? My issue is similar to #15, but I could not find a solution there. Thanks again in advance for your help.
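One possible reason why simply interpolating the vector field to a new grid size fails (an assumption, not confirmed in this thread): if the displacement vectors are stored in voxel units of the 175-voxel network grid, resizing the grid changes where the vectors are sampled but not their units. The vector components then have to be rescaled as well. The following is a minimal 1D numpy sketch that assumes displacements in voxel units and an align-corners style mapping between the two grids. The helper `resize_displacement_1d` is hypothetical and not part of uniGradICON or icon_registration.

```python
import numpy as np


def resize_displacement_1d(disp, new_size):
    """Resample a 1D displacement field given in voxel units of its own grid.

    Assumes the grid endpoints are aligned (first voxel maps to first voxel,
    last voxel maps to last voxel), as with align_corners=True resizing.
    """
    old_size = disp.shape[0]
    # Positions of the new samples, expressed in old-grid voxel coordinates.
    new_pos_in_old = np.linspace(0.0, old_size - 1, new_size)
    resampled = np.interp(new_pos_in_old, np.arange(old_size), disp)
    # One old voxel spans (new_size - 1) / (old_size - 1) new voxels.
    scale = (new_size - 1) / (old_size - 1)
    return resampled * scale


# A uniform shift of 2 voxels on a 175-voxel grid...
disp_175 = np.full(175, 2.0)
# ...corresponds to a shift of 2 * 68 / 174 voxels on a 69-voxel grid.
disp_69 = resize_displacement_1d(disp_175, 69)
print(disp_69[:3])
```

Without the scale factor, the resized field would still shift by 2 voxels on the coarser grid, which is a much larger physical shift. Evaluating the full transform in physical coordinates avoids this bookkeeping altogether.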

anudeepk17, Aug 14 '24 20:08

Hi! Are you looking for the displacement field at original resolution with displacements in physical coordinates? If so, this is doable by converting the itkCompositeTransform to a displacement field as follows:

https://colab.research.google.com/drive/1bo_CWdI4PC7YdMmlVb2jYp0Fd1ee5gW_?usp=sharing

!pip install unigradicon

!wget https://www.hgreer.com/assets/slicer_mirror/RegLib_C01_1.nrrd
!wget https://www.hgreer.com/assets/slicer_mirror/RegLib_C01_2.nrrd

!unigradicon-register --fixed=RegLib_C01_2.nrrd --fixed_modality=mri --moving=RegLib_C01_1.nrrd --moving_modality=mri --transform_out=trans.hdf5 --warped_moving_out=warped_C01_1.nrrd

import itk
import matplotlib.pyplot as plt
import numpy as np

fixed_image = itk.imread("RegLib_C01_2.nrrd")
moving_image = itk.imread("RegLib_C01_1.nrrd")

transform = itk.transformread("trans.hdf5")[0]

dispfield_filter = itk.TransformToDisplacementFieldFilter[itk.Image[itk.Vector[itk.F, 3], 3], itk.D].New()

dispfield_filter.SetTransform(transform)
dispfield_filter.SetReferenceImage(fixed_image)
dispfield_filter.SetUseReferenceImage(True)

dispfield_filter.Update()

displacement_field = dispfield_filter.GetOutput()

displacement_field.GetLargestPossibleRegion().GetSize()

print(np.array(displacement_field).shape)

warped_moving_image = itk.warp_image_filter(
    moving_image,
    output_origin=fixed_image.GetOrigin(),
    output_direction=fixed_image.GetDirection(),
    output_spacing=fixed_image.GetSpacing(),
    displacement_field=displacement_field)
plt.imshow(itk.checker_board_image_filter(fixed_image, warped_moving_image)[50])

Does this work for your use case?
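Conceptually, TransformToDisplacementFieldFilter evaluates the transform at the physical position p of every voxel of the reference (fixed) image and stores the offset T(p) - p. The following minimal numpy sketch shows that idea. It assumes the transform is available as a Python callable acting on (N, 3) arrays of physical points. `displacement_field_from_transform` is a hypothetical helper, not an ITK or uniGradICON function, and the affine map below is only a stand-in for phi_AB.

```python
import numpy as np


def displacement_field_from_transform(transform, size, spacing, origin, direction):
    """Return transform(p) - p for every voxel of a reference grid.

    size, spacing and origin are given in ITK's (x, y, z) order, and direction is
    a 3x3 matrix. The result has shape (*size, 3) and is indexed [i, j, k],
    i.e. NOT in the [k, j, i] order returned by itk.GetArrayFromImage.
    """
    spacing = np.asarray(spacing, dtype=float)
    origin = np.asarray(origin, dtype=float)
    direction = np.asarray(direction, dtype=float)
    # Integer index of every grid point, flattened to shape (N, 3).
    axes = [np.arange(s) for s in size]
    idx = np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1).reshape(-1, 3)
    # ITK index-to-physical mapping: p = origin + direction @ (spacing * index).
    points = origin + (idx * spacing) @ direction.T
    disp = transform(points) - points
    return disp.reshape(*size, 3)


# Hypothetical stand-in for phi_AB: an affine map acting on (N, 3) point arrays.
A = np.diag([1.1, 1.0, 1.0])
t = np.array([2.0, -1.0, 0.5])


def affine(points):
    return points @ A.T + t


field = displacement_field_from_transform(
    affine, size=(40, 67, 69), spacing=(1, 1, 1), origin=(0, 0, 0), direction=np.eye(3)
)
print(field.shape)  # (40, 67, 69, 3)
```

Because the displacement is computed from the reference image's own spacing, origin and direction, the result is expressed in physical units on the original grid. That is why the metadata of the reference image matters.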

HastingsGreer, Aug 15 '24 01:08

Hello, thank you for your reply. This did give me a vector field at the original dimensions, but the Dice scores are much lower than those obtained with the vector field at 175x175x175. Below is the code I use to compute Dice, with comments to clarify my approach as best I can. I obtain phi_AB from the fixed and moving images and then use that phi_AB to warp masks of substructures in my data.

In short: the Dice I obtain with phi_AB at the network shape (175x175x175) is very good, but when I use displacement_field to obtain the warped mask, the Dice drops drastically.

import itk
import nibabel as nib
import numpy as np
import SimpleITK as sitk

# Obtain phi_AB and the warped label according to the obtained phi_AB
phi_AB, phi_BA, net = get_dvf(fixed_path, moving_path)

# Use the obtained phi_AB to register the label mask of a substructure of our data.
# warp_image is similar to warp_command in your code; it uses itk.resample_image_filter to return the warped label image.
# I pass the original-size labels and phi_AB (network size, 175x175x175) as the transform to obtain a warped image at the original input size.
# Keyword arguments match the signature warp_image(moving, fixed, phi_AB, ...).
warped_label = warp_image(moving=moving_label_path, fixed=fixed_label_path, phi_AB=phi_AB)

# Load and calculate dice
fixed_label = nib.load(fixed_label_path).get_fdata() #path to label 
moving_label = nib.load(moving_label_path).get_fdata()
#dim0: R-L; dim1: A-P; dim2: S-I
warped_label = itk.GetArrayFromImage(warped_label)
#warped_label= np.array(warped_label.cpu())
warped_label = warped_label.swapaxes(0, 2)
warped_label = sitk.GetImageFromArray(warped_label)
fixed_label = sitk.GetImageFromArray(fixed_label)
moving_label = sitk.GetImageFromArray(moving_label)
warped_label = sitk.GetArrayFromImage(warped_label)
warped_label_new = np.zeros(np.shape(warped_label))
warped_label_new[warped_label > 255.0 * 0.5] = 255.0
warped_label = sitk.GetImageFromArray(warped_label_new)
dice_175 = compute_metric_dsc(warped_label, fixed_label,auto_crop = False)
dice_pre = compute_metric_dsc(moving_label, fixed_label,auto_crop = False)

# Solution provided by author to get a displacement field vector in original input size of image.
# Read the label and calculate new displacement field and obtain warped image
moving_label=itk.imread(moving_label_path)
fixed_image=itk.imread(fixed_path)
dispfield_filter = itk.TransformToDisplacementFieldFilter[itk.Image[itk.Vector[itk.F, 3], 3], itk.D].New()

dispfield_filter.SetTransform(phi_AB)
dispfield_filter.SetReferenceImage(fixed_image)
dispfield_filter.SetUseReferenceImage(True)

dispfield_filter.Update()

displacement_field = dispfield_filter.GetOutput()

displacement_field.GetLargestPossibleRegion().GetSize()

# print(np.array(displacement_field).shape)

warped_moving_image = itk.warp_image_filter(
    moving_label,
    output_origin=fixed_image.GetOrigin(),
    output_direction=fixed_image.GetDirection(),
    output_spacing=fixed_image.GetSpacing(),
    displacement_field=displacement_field)

# Dice calculation similar to above, this time using the warped_moving_image
fixed_label = nib.load(fixed_label_path).get_fdata() #path to label 
moving_label = nib.load(moving_label_path).get_fdata()

#dim0: R-L; dim1: A-P; dim2: S-I

warped_label=itk.array_from_image(warped_moving_image)
warped_label = warped_label.swapaxes(0, 2)
warped_label = sitk.GetImageFromArray(warped_label)
fixed_label = sitk.GetImageFromArray(fixed_label)
moving_label = sitk.GetImageFromArray(moving_label)

warped_label = sitk.GetArrayFromImage(warped_label)
warped_label_new = np.zeros(np.shape(warped_label))
warped_label_new[warped_label > 255.0 * 0.5] = 255.0
warped_label = sitk.GetImageFromArray(warped_label_new)
dice_interpolated = compute_metric_dsc(warped_label, fixed_label, auto_crop=False)

Here is a table of the different samples I tried. PreRegistration Dice is the Dice before registering the two masks; Dice_175 is the Dice after registering the masks with the network's phi_AB, unmodified; Dice_SameSize is the Dice after warping the masks with the displacement_field method.
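The `compute_metric_dsc` helper used above is not shown in the thread. For comparison, here is a minimal numpy sketch of the Dice similarity coefficient on binarized masks. It assumes foreground voxels are 255 and uses the same 0.5 * 255 threshold as the code above. `dice_score` is a hypothetical name.

```python
import numpy as np


def dice_score(a, b, threshold=255.0 * 0.5):
    """Dice = 2 * |A and B| / (|A| + |B|) for masks binarized at `threshold`."""
    a = np.asarray(a) > threshold
    b = np.asarray(b) > threshold
    denom = a.sum() + b.sum()
    if denom == 0:
        return 1.0  # both masks empty: treated as perfect agreement (a convention)
    return 2.0 * np.logical_and(a, b).sum() / denom


fixed = np.zeros((4, 4, 4))
fixed[1:3, 1:3, 1:3] = 255.0   # 8 foreground voxels
warped = np.zeros((4, 4, 4))
warped[1:3, 1:3, 1:4] = 255.0  # 12 foreground voxels, 8 of them overlapping
print(dice_score(warped, fixed))  # 0.8
```

Both arrays must use the same axis order. Comparing a [z, y, x] array with an [x, y, z] array still returns a number, but a meaningless one.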

anudeepk17, Aug 15 '24 21:08

Could you provide the definitions of the functions get_dvf and warp_image? My suspicion is that somewhere in the pipeline the image metadata (spacing, orientation, and origin) is getting lost.

Also, could you provide the output of the following script for the fixed image, the fixed label, the moving image, and the moving label? This will help me understand the image metadata.

I'm sorry that this is taking so much effort to clear up!

import itk
print(itk.imread(fixed_image_path))
print(itk.imread(moving_image_path))
print(itk.imread(fixed_label_path))
print(itk.imread(moving_label_path))

HastingsGreer, Aug 16 '24 20:08

Hello, no issues at all; I am just glad and thankful for your help and guidance. Here is the code you asked for:

import itk
import icon_registration.itk_wrapper
from unigradicon import get_unigradicon, preprocess

def get_dvf(fixed, moving, save_dvf=None, transform_out=None, fixed_segmentation=None, moving_segmentation=None, io_iterations="None", moving_modality='mri', fixed_modality='mri'):
    ''' fixed               : Path of the fixed image
        moving              : Path of the moving image
        save_dvf            : True to save the transform as an hdf5 file
        transform_out       : Path of the hdf5 file
        fixed_segmentation  : Path of the segmentation map of the fixed image
        moving_segmentation : Path of the segmentation map of the moving image
        io_iterations       : Default "None"; number of fine-tuning steps passed as finetune_steps
        moving_modality     : 'ct' or 'mri'
        fixed_modality      : 'ct' or 'mri'
    '''
    net = get_unigradicon()
    fixed = itk.imread(fixed)
    moving = itk.imread(moving)

    if fixed_segmentation is not None:
        fixed_segmentation = itk.imread(fixed_segmentation)
    else:
        fixed_segmentation = None

    if moving_segmentation is not None:
        moving_segmentation = itk.imread(moving_segmentation)
    else:
        moving_segmentation = None

    if io_iterations == "None":
        io_iterations = None
    else:
        io_iterations = int(io_iterations)

    phi_AB, phi_BA = icon_registration.itk_wrapper.register_pair(
        net,
        preprocess(moving, moving_modality, moving_segmentation), 
        preprocess(fixed, fixed_modality, fixed_segmentation), 
        finetune_steps=io_iterations)
    if save_dvf is not None:
        if transform_out is None:
            transform_out="trans.hdf5"
        itk.transformwrite([phi_AB], transform_out)
    return phi_AB,phi_BA,net

def warp_image(moving, fixed, phi_AB=None, transform=None, interpolator=None, save_img=None, warped_moving_out=None):
        '''
        moving              : Path of the moving image
        fixed               : Path of the fixed image
        phi_AB              : Transform returned from get_dvf()
        transform           : Path of a saved hdf5 transform (used if phi_AB is None)
        interpolator        : 'linear' (default) or 'nearest_neighbor'
        save_img            : True to save the warped image
        warped_moving_out   : Path of the image to be saved

        '''
        fixed = itk.imread(fixed)
        moving = itk.imread(moving)
        if interpolator == "linear" or interpolator is None:
            interpolator = itk.LinearInterpolateImageFunction.New(moving)
        elif interpolator == "nearest_neighbor":
            interpolator = itk.NearestNeighborInterpolateImageFunction.New(moving)
        else:
            raise Exception("interpolator must be 'linear' or 'nearest_neighbor'")
        if transform is not None and phi_AB is None:
            phi_AB = itk.transformread(transform)[0]
        elif phi_AB is None and transform is None:
            raise Exception("Specify either transform path or Phi_AB as returned from get_dvf()")
        # The interpolator selected above (linear or nearest neighbor) is passed to the resampler.
        warped_moving_image = itk.resample_image_filter(
                moving,
                transform=phi_AB,
                interpolator=interpolator,
                use_reference_image=True,
                reference_image=fixed
                )
        if save_img is not None:
            if warped_moving_out is None:
                warped_moving_out="warp.nii.gz"
            itk.imwrite(warped_moving_image, warped_moving_out)
        else:
            return warped_moving_image

The output of the script:

Image (0x6420ee118250)
  RTTI typeinfo:   itk::Image<double, 3u>
  Reference Count: 1
  Modified Time: 241236
  Debug: Off
  Object Name: 
  Observers: 
    none
  Source: (none)
  Source output name: (none)
  Release Data: Off
  Data Released: False
  Global Release Data: Off
  PipelineMTime: 241057
  UpdateMTime: 241235
  RealTimeStamp: 0 seconds 
  LargestPossibleRegion: 
    Dimension: 3
    Index: [0, 0, 0]
    Size: [40, 67, 69]
  BufferedRegion: 
    Dimension: 3
    Index: [0, 0, 0]
    Size: [40, 67, 69]
  RequestedRegion: 
    Dimension: 3
    Index: [0, 0, 0]
    Size: [40, 67, 69]
  Spacing: [1, 1, 1]
  Origin: [0, 0, 0]
  Direction: 
1 0 0
0 1 0
0 0 1

  IndexToPointMatrix: 
1 0 0
0 1 0
0 0 1

  PointToIndexMatrix: 
1 0 0
0 1 0
0 0 1

  Inverse Direction: 
1 0 0
0 1 0
0 0 1

  PixelContainer: 
    ImportImageContainer (0x6420d41bdfd0)
      RTTI typeinfo:   itk::ImportImageContainer<unsigned long, double>
      Reference Count: 1
      Modified Time: 241232
      Debug: Off
      Object Name: 
      Observers: 
        none
      Pointer: 0x6420e2ce1380
      Container manages memory: true
      Size: 184920
      Capacity: 184920

Image (0x6420d3e46e10)
  RTTI typeinfo:   itk::Image<double, 3u>
  Reference Count: 1
  Modified Time: 241603
  Debug: Off
  Object Name: 
  Observers: 
    none
  Source: (none)
  Source output name: (none)
  Release Data: Off
  Data Released: False
  Global Release Data: Off
  PipelineMTime: 241424
  UpdateMTime: 241602
  RealTimeStamp: 0 seconds 
  LargestPossibleRegion: 
    Dimension: 3
    Index: [0, 0, 0]
    Size: [40, 67, 69]
  BufferedRegion: 
    Dimension: 3
    Index: [0, 0, 0]
    Size: [40, 67, 69]
  RequestedRegion: 
    Dimension: 3
    Index: [0, 0, 0]
    Size: [40, 67, 69]
  Spacing: [1, 1, 1]
  Origin: [0, 0, 0]
  Direction: 
1 0 0
0 1 0
0 0 1

  IndexToPointMatrix: 
1 0 0
0 1 0
0 0 1

  PointToIndexMatrix: 
1 0 0
0 1 0
0 0 1

  Inverse Direction: 
1 0 0
0 1 0
0 0 1

  PixelContainer: 
    ImportImageContainer (0x6420d41bdfd0)
      RTTI typeinfo:   itk::ImportImageContainer<unsigned long, double>
      Reference Count: 1
      Modified Time: 241599
      Debug: Off
      Object Name: 
      Observers: 
        none
      Pointer: 0x6420e2ce1380
      Container manages memory: true
      Size: 184920
      Capacity: 184920

Image (0x642014a483d0)
  RTTI typeinfo:   itk::Image<double, 3u>
  Reference Count: 1
  Modified Time: 241970
  Debug: Off
  Object Name: 
  Observers: 
    none
  Source: (none)
  Source output name: (none)
  Release Data: Off
  Data Released: False
  Global Release Data: Off
  PipelineMTime: 241791
  UpdateMTime: 241969
  RealTimeStamp: 0 seconds 
  LargestPossibleRegion: 
    Dimension: 3
    Index: [0, 0, 0]
    Size: [40, 67, 69]
  BufferedRegion: 
    Dimension: 3
    Index: [0, 0, 0]
    Size: [40, 67, 69]
  RequestedRegion: 
    Dimension: 3
    Index: [0, 0, 0]
    Size: [40, 67, 69]
  Spacing: [1, 1, 1]
  Origin: [0, 0, 0]
  Direction: 
1 0 0
0 1 0
0 0 1

  IndexToPointMatrix: 
1 0 0
0 1 0
0 0 1

  PointToIndexMatrix: 
1 0 0
0 1 0
0 0 1

  Inverse Direction: 
1 0 0
0 1 0
0 0 1

  PixelContainer: 
    ImportImageContainer (0x6420d41bdfd0)
      RTTI typeinfo:   itk::ImportImageContainer<unsigned long, double>
      Reference Count: 1
      Modified Time: 241966
      Debug: Off
      Object Name: 
      Observers: 
        none
      Pointer: 0x6420e2ce1380
      Container manages memory: true
      Size: 184920
      Capacity: 184920

Image (0x6420ee118250)
  RTTI typeinfo:   itk::Image<double, 3u>
  Reference Count: 1
  Modified Time: 242337
  Debug: Off
  Object Name: 
  Observers: 
    none
  Source: (none)
  Source output name: (none)
  Release Data: Off
  Data Released: False
  Global Release Data: Off
  PipelineMTime: 242158
  UpdateMTime: 242336
  RealTimeStamp: 0 seconds 
  LargestPossibleRegion: 
    Dimension: 3
    Index: [0, 0, 0]
    Size: [40, 67, 69]
  BufferedRegion: 
    Dimension: 3
    Index: [0, 0, 0]
    Size: [40, 67, 69]
  RequestedRegion: 
    Dimension: 3
    Index: [0, 0, 0]
    Size: [40, 67, 69]
  Spacing: [1, 1, 1]
  Origin: [0, 0, 0]
  Direction: 
1 0 0
0 1 0
0 0 1

  IndexToPointMatrix: 
1 0 0
0 1 0
0 0 1

  PointToIndexMatrix: 
1 0 0
0 1 0
0 0 1

  Inverse Direction: 
1 0 0
0 1 0
0 0 1

  PixelContainer: 
    ImportImageContainer (0x6420d41bdfd0)
      RTTI typeinfo:   itk::ImportImageContainer<unsigned long, double>
      Reference Count: 1
      Modified Time: 242333
      Debug: Off
      Object Name: 
      Observers: 
        none
      Pointer: 0x6420e2ce1380
      Container manages memory: true
      Size: 184920
      Capacity: 184920

Thank you again for your prompt responses and help. I look forward to your reply.

anudeepk17, Aug 18 '24 18:08

This is a real puzzler! I see three possibilities for what is going on:

  1. It is generally best practice to warp label images with itk.NearestNeighborInterpolateImageFunction instead of itk.LinearInterpolateImageFunction. It is worth switching both paths (transform and displacement field) to nearest-neighbor interpolation; this may or may not fix the discrepancy.

  2. Examining your metadata, I realized that we have not extensively tested our approach on images with resolution much lower than 175 x 175 x 175. This case may be exposing a bug in our code. Alternatively, the model may be producing a very high-resolution displacement field that "cheats" by hiding labels between the low-resolution pixels, and forcing the displacement field to the image's resolution defeats that cheat.

  3. Some elements of the code you posted are confusing. In particular, converting between numpy arrays and itk images is tricky, and the calls to swapaxes are easy to get wrong. It is also not clear how the posted code fits together: the call to warp_image in the first code sample leaves interpolator as None, but the code for warp_image throws an error if interpolator is None.
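Point 1 can be seen in one dimension. Linearly interpolating a label map at a position between two different labels produces a value that belongs to neither, and it can even belong to a third label. Nearest-neighbor lookup only ever returns labels that already exist. The following is a minimal numpy sketch in which `np.interp` stands in for ITK's linear interpolator. The helper names are hypothetical.

```python
import numpy as np

labels = np.array([0, 0, 1, 1, 3, 3], dtype=float)  # a 1D label map
grid = np.arange(labels.size, dtype=float)
# Sample half-way between voxels, as a warp with sub-voxel displacements would.
query = np.array([1.5, 3.5])


def linear_lookup(values, positions):
    return np.interp(positions, grid, values)


def nearest_lookup(values, positions):
    # Round to the nearest voxel index; clip to stay inside the image.
    idx = np.clip(np.floor(positions + 0.5).astype(int), 0, values.size - 1)
    return values[idx]


print(linear_lookup(labels, query))   # [0.5 2. ]  (label 2 never existed)
print(nearest_lookup(labels, query))  # [1. 3.]
```

For a single binary 0/255 mask thresholded at 0.5 * 255, as in the code earlier in the thread, the effect is milder than for multi-label maps. The two interpolators can still disagree at boundaries.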

Would you be willing to email me full runnable code and a pair of images from your table? I know that this may be a data sharing issue, but I have reached a dead end with the information I have. I think I need an example I can run and experiment on to resolve this issue.

HastingsGreer, Aug 22 '24 13:08

As I understand it, the transform created by uniGradICON is a composite of three transforms: affine, DVF, and affine. The outer affines project fixed-image points into network coordinates (spacing 1, offset 0) and then into moving space. I think using this transform directly to move images from moving to fixed space (one resampling) is better than resampling the DVF to the original image spacing and then resampling the moving image (two resamplings). Unless there's some other purpose for having the DVF at the fixed image spacing, such as visualisation or computing the Jacobian?
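The following minimal 1D numpy sketch illustrates the general effect behind the "one resampling vs two" concern. It is not a measurement of uniGradICON itself. `np.interp` stands in for linear interpolation. Shifting a signal by one voxel in a single step is exact, while reaching the same shift through two half-voxel linear resamplings smooths the signal.

```python
import numpy as np

signal = np.zeros(11)
signal[5] = 1.0  # a single bright voxel, e.g. a thin structure in a label
x = np.arange(signal.size, dtype=float)


def shift(values, amount):
    # Sample values at x - amount with linear interpolation (edges held constant).
    return np.interp(x - amount, x, values)


once = shift(signal, 1.0)               # one resampling by the full shift
twice = shift(shift(signal, 0.5), 0.5)  # two resamplings composing the same shift

print(once[6], twice[6])  # 1.0 0.5
```

Each linear resampling step acts as a small blur, so chaining resamplings is generally not equivalent to evaluating the composed mapping once.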

clarkbab, May 14 '25 03:05

In any case, there are some tricky things when working with ITK. I'd take a look at these and see if any apply:

  • There's a difference between how nibabel and ITK load NIfTI images. ITK assumes the images it loads are in RAS+ coordinates and converts them to LPS+, so it negates the x/y axes of the direction matrix and the x/y components of the origin. Nibabel doesn't flip the directions (the affine matrix in nibabel) or the offsets.
  • uniGradICON uses ITK to load NIfTI files, so the resulting transform expects images with flipped x/y directions and offsets. I think you're using ITK to load images before feeding them into the transform, so this shouldn't be a problem here.
  • ITK always transposes numpy arrays when calling 'GetArrayFromImage' and sometimes when calling 'GetImageFromArray' (depending on whether you've transposed the numpy array previously and changed the indexing style - see C vs. Fortran indexing). I use a numpy array transpose and copy before calling 'GetImageFromArray' to ensure that the ITK transpose is nullified regardless of what I did with the numpy array in all preceding code.
  • I would suggest writing separate methods for numpy/ITK conversion so that you don't have to remember to transpose all the time when converting.
  • When debugging this stuff I found it useful to apply the full transform to fixed image points (e.g. origin) and see if they ended up in reasonable positions in the moving image.
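The first and third points can be sketched in numpy alone. The sketch assumes a 4x4 NIfTI affine as returned by nibabel (RAS+) and ITK's LPS+ convention. The helper names are hypothetical. No nibabel or ITK calls are made, so the conventions can be checked in isolation. The two array helpers follow the suggestion above of always transposing and copying in dedicated conversion functions.

```python
import numpy as np

RAS_TO_LPS = np.diag([-1.0, -1.0, 1.0])


def nifti_affine_to_itk_metadata(affine):
    """Split a 4x4 RAS+ voxel-to-world affine into ITK-style LPS+ metadata."""
    affine = np.asarray(affine, dtype=float)
    linear = RAS_TO_LPS @ affine[:3, :3]      # flip the x and y world axes
    origin = RAS_TO_LPS @ affine[:3, 3]       # flip x and y of the offset
    spacing = np.linalg.norm(linear, axis=0)  # column norms are the voxel sizes
    direction = linear / spacing              # unit-length direction columns
    return spacing, origin, direction


def itk_array_to_xyz(arr):
    """Arrays from itk.GetArrayFromImage are indexed [z, y, x]; return [x, y, z]."""
    return np.ascontiguousarray(np.transpose(arr))


def xyz_to_itk_array(arr):
    """Inverse of itk_array_to_xyz: [x, y, z] -> C-contiguous [z, y, x] copy."""
    return np.ascontiguousarray(np.transpose(arr))


# Hypothetical NIfTI affine: 2 x 1.5 x 1 mm voxels, axis-aligned.
affine = np.array([[2.0, 0.0, 0.0, -10.0],
                   [0.0, 1.5, 0.0, 20.0],
                   [0.0, 0.0, 1.0, 5.0],
                   [0.0, 0.0, 0.0, 1.0]])
spacing, origin, direction = nifti_affine_to_itk_metadata(affine)
print(spacing, origin)
print(direction)
```

Applying a transform to the fixed image's origin computed this way, and checking that it lands somewhere sensible in the moving image, is one way to follow the last debugging tip.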

clarkbab, May 14 '25 04:05