
Robust multimodal brain registration via keypoints

KeyMorph: Robust Multi-modal Registration via Keypoint Detection

KeyMorph is a deep learning-based image registration framework that relies on automatically extracting corresponding keypoints.

Updates

  • [Apr 2024] Released a foundation model of KeyMorph for brain MRIs, trained on over 100K images at full resolution (256^3). Instructions under "Foundation model".
  • [Dec 2023] Journal paper extension of MIDL paper published in Medical Image Analysis. Instructions under "IXI-trained, half-resolution models".
  • [Feb 2022] Conference paper published in MIDL 2021.

Requirements

Install the packages with pip install -r requirements.txt.

You may need to install PyTorch separately, depending on your GPU and CUDA version; see the PyTorch installation page.

Downloading Trained Weights

You can find all trained weights under Releases. Download them and put them in the ./weights/ folder.

Registering brain volumes

Foundation model

The foundation model is trained on over 100,000 brain MR images at full resolution (256x256x256). The script automatically min-max normalizes the images and resamples them to 1 mm isotropic resolution.
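Min-max normalization simply rescales each volume's intensities to [0, 1]. A minimal NumPy sketch of the idea (a hypothetical helper for illustration, not the repo's preprocessing code):

```python
import numpy as np

def minmax_normalize(vol: np.ndarray) -> np.ndarray:
    """Rescale a volume's intensities to the [0, 1] range."""
    vmin, vmax = vol.min(), vol.max()
    return (vol - vmin) / (vmax - vmin)

vol = np.random.rand(8, 8, 8) * 1000 + 50  # fake intensity volume
norm = minmax_normalize(vol)
# intensities now lie in [0, 1]; shape is unchanged
```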

To register a single pair of volumes:

python scripts/register.py \
    --registration_model keymorph \
    --num_keypoints 256 \
    --backbone truncatedunet \
    --moving ./example_data/images/IXI_000001_0000.nii.gz \
    --fixed ./example_data/images/IXI_000002_0000.nii.gz \
    --load_path ./weights/foundation-numkey256-256x256x256.tar \
    --moving_seg ./example_data/labels/IXI_000001_0000.nii.gz \
    --fixed_seg ./example_data/labels/IXI_000002_0000.nii.gz \
    --list_of_aligns affine tps_0 \
    --list_of_metrics mse harddice \
    --save_eval_to_disk \
    --visualize

Description of important flags:

  • --moving and --fixed are paths to the moving and fixed images.
  • --moving_seg and --fixed_seg are optional, but required if you want the script to report Dice scores.
  • --list_of_aligns specifies the types of alignment to perform. Options are rigid, affine, and tps_<lambda> (TPS with hyperparameter value equal to lambda). lambda=0 corresponds to exact keypoint alignment, while lambda=10 is very similar to affine.
  • --list_of_metrics specifies the metrics to report. Options are mse, harddice, softdice, hausd, jdstd, jdlessthan0. To compute Dice scores and surface distances, --moving_seg and --fixed_seg must be provided.
  • --save_eval_to_disk saves all outputs to disk. The default location is ./register_output/.
  • --visualize plots a matplotlib figure of the moving, fixed, and registered images overlaid with corresponding keypoints.
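To make the alignment options concrete: once corresponding keypoints are extracted, the affine case reduces to a least-squares fit between the two point sets. A NumPy sketch of that fit (illustrative only, not KeyMorph's implementation):

```python
import numpy as np

def fit_affine(points_m: np.ndarray, points_f: np.ndarray) -> np.ndarray:
    """Least-squares affine transform (3x4) mapping moving to fixed keypoints.

    points_m, points_f: (N, 3) arrays of corresponding keypoints.
    """
    n = points_m.shape[0]
    X = np.hstack([points_m, np.ones((n, 1))])      # (N, 4) homogeneous coords
    sol, *_ = np.linalg.lstsq(X, points_f, rcond=None)  # (4, 3)
    return sol.T                                     # (3, 4) = [linear | translation]

# Recover a known affine from noiseless correspondences
rng = np.random.default_rng(0)
pts = rng.normal(size=(16, 3))
A_true = np.array([[1.1, 0.0, 0.0, 2.0],
                   [0.0, 0.9, 0.1, -1.0],
                   [0.0, 0.0, 1.0, 0.5]])
pts_f = pts @ A_true[:, :3].T + A_true[:, 3]
A_est = fit_affine(pts, pts_f)
assert np.allclose(A_est, A_true, atol=1e-8)
```

The rigid and TPS options differ only in which family of transforms is fit to the same point pairs; TPS additionally regularizes the nonlinear part with the lambda hyperparameter.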

You can also replace filenames with directories to register all images in the directory. Note that the script expects corresponding image and segmentation pairs to have the same filename.

python scripts/register.py \
    --registration_model keymorph \
    --num_keypoints 256 \
    --backbone truncatedunet \
    --moving ./example_data/images/ \
    --fixed ./example_data/images/ \
    --load_path ./weights/foundation-numkey256-256x256x256.tar \
    --moving_seg ./example_data/labels/ \
    --fixed_seg ./example_data/labels/ \
    --list_of_aligns affine tps_0 \
    --list_of_metrics mse harddice \
    --save_eval_to_disk \
    --visualize

IXI-trained, half-resolution models

All other model weights are trained at half resolution (128x128x128) on the (smaller) IXI dataset. The script automatically min-max normalizes the images. To register two volumes with our best-performing model:

python scripts/register.py \
    --half_resolution \
    --registration_model keymorph \
    --num_keypoints 512 \
    --backbone conv \
    --moving ./example_data/images_half/IXI_001_128x128x128.nii.gz \
    --fixed ./example_data/images_half/IXI_002_128x128x128.nii.gz \
    --load_path ./weights/numkey512_tps0_dice.4760.h5 \
    --moving_seg ./example_data/labels_half/IXI_001_128x128x128.nii.gz \
    --fixed_seg ./example_data/labels_half/IXI_002_128x128x128.nii.gz \
    --list_of_aligns affine tps_0 \
    --list_of_metrics mse harddice \
    --save_eval_to_disk \
    --visualize

TL;DR in code

The crux of the code is in the forward() function in keymorph/model.py, which performs one forward pass through the entire KeyMorph pipeline.

Here's a pseudo-code version of the function:

def forward(img_f, img_m, seg_f, seg_m, network, optimizer, kp_aligner, lmbda, unsupervised=False):
    '''Forward pass for one mini-batch step.
    Variables with suffixes (_f, _m, _a) denote (fixed, moving, aligned).

    Args:
        img_f, img_m: Fixed and moving intensity images (bs, 1, l, w, h)
        seg_f, seg_m: Fixed and moving one-hot segmentation maps (bs, num_classes, l, w, h)
        network: Keypoint extractor network
        optimizer: Optimizer, stepped once per call
        kp_aligner: Rigid, affine, or TPS keypoint alignment module
        lmbda: TPS regularization hyperparameter
        unsupervised: If True, train with MSE loss; otherwise with soft Dice loss
    '''
    optimizer.zero_grad()

    # Extract keypoints
    points_f = network(img_f)
    points_m = network(img_m)

    # Align via keypoints
    grid = kp_aligner.grid_from_points(points_m, points_f, img_f.shape, lmbda=lmbda)
    img_a, seg_a = utils.align_moving_img(grid, img_m, seg_m)

    # Compute losses
    mse = MSELoss()(img_f, img_a)
    soft_dice = DiceLoss()(seg_a, seg_f)

    if unsupervised:
        loss = mse
    else:
        loss = soft_dice

    # Backward pass
    loss.backward()
    optimizer.step()

The network variable is a CNN with a center-of-mass layer that extracts keypoints from the input images. The kp_aligner variable is a keypoint alignment module; its grid_from_points() function returns a flow-field grid encoding the transformation to apply to the moving image. The transformation can be rigid, affine, or nonlinear (TPS).
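The center-of-mass layer can be illustrated in a few lines: each output channel of the CNN is treated as a normalized heatmap, and the keypoint for that channel is the expected voxel coordinate under the heatmap. Because this is an expectation rather than an argmax, it is differentiable. A NumPy sketch (illustrative, not the repo's implementation):

```python
import numpy as np

def center_of_mass(heatmaps: np.ndarray) -> np.ndarray:
    """Differentiable keypoints from heatmaps.

    heatmaps: (K, D, H, W) non-negative activations, one channel per keypoint.
    Returns: (K, 3) expected (z, y, x) coordinates, one row per channel.
    """
    K = heatmaps.shape[0]
    flat = heatmaps.reshape(K, -1)
    probs = flat / flat.sum(axis=1, keepdims=True)  # normalize each channel
    # Coordinate grid covering every voxel, shape (D*H*W, 3)
    coords = np.stack(np.meshgrid(*[np.arange(s) for s in heatmaps.shape[1:]],
                                  indexing="ij"), axis=-1).reshape(-1, 3)
    return probs @ coords  # expectation of coordinates under each heatmap

# A delta-like heatmap peaked at voxel (2, 3, 4) yields exactly that keypoint
hm = np.zeros((1, 8, 8, 8))
hm[0, 2, 3, 4] = 1.0
print(center_of_mass(hm))  # → [[2. 3. 4.]]
```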

Training KeyMorph

Use scripts/run.py to train KeyMorph. Some example bash commands are provided in bash_scripts/.

I'm in the process of updating the code to make it more user-friendly, and will update this repository soon. In the meantime, feel free to open an issue if you have any training questions.

Issues

This repository is being actively maintained. Feel free to open an issue for any problems or questions.

Legacy code

For a legacy version of the code, see our legacy branch.

References

If this code is useful to you, please consider citing our papers. The first conference paper contains the unsupervised, affine version of KeyMorph. The second, follow-up journal paper contains the unsupervised/supervised, affine/TPS versions of KeyMorph.

Evan M. Yu, et al. "KeyMorph: Robust Multi-modal Affine Registration via Unsupervised Keypoint Detection." (MIDL 2021).

Alan Q. Wang, et al. "A Robust and Interpretable Deep Learning Framework for Multi-modal Registration via Keypoints." (Medical Image Analysis 2023).