Update a Subject's affine matrix during RandomAffine transformation
🚀 Feature
When applying the RandomAffine transformation to a Subject, the affine matrix should be updated.
Motivation
While the affine matrix is often not used when training a model, there are cases in which it is. For example, to train a model that is conditioned on the voxel spacing, we must track the voxel spacing through every data augmentation we apply (non-linear transforms do not allow this, of course, but affine transforms do). The simplest and most versatile way to achieve this is to update the affine matrix after each spatial augmentation (as is already done in the Resample preprocessing transform).
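A minimal sketch of why the affine is enough to track spacing: the voxel spacing is the norm of each column of the affine's 3x3 block (the same computation as get_pixdim_from_affine in the code further down), so if a zoom of the voxel grid is folded into the affine, the spacing follows automatically. The affine values are made up, and the zoom is taken about the voxel origin for simplicity.

```python
import numpy as np

def voxel_spacing(affine: np.ndarray) -> np.ndarray:
    """Voxel spacing = Euclidean norm of each column of the 3x3 linear part."""
    return np.linalg.norm(affine[:3, :3], axis=0)

# Made-up 1 x 1 x 2 mm affine
affine = np.diag([1.0, 1.0, 2.0, 1.0])
affine[:3, 3] = [-90.0, -126.0, -72.0]

# Zoom in by a factor of 2 about the voxel origin: index i -> 2 * i
zoom = np.diag([2.0, 2.0, 2.0, 1.0])

# New index -> world = (old index -> world) composed with the inverse zoom
new_affine = affine @ np.linalg.inv(zoom)

print(voxel_spacing(affine))      # [1. 1. 2.]
print(voxel_spacing(new_affine))  # [0.5 0.5 1. ]
```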
Pitch
Improve the RandomAffine class so as to update the affine matrix of the subject under any randomly sampled combination of scaling, rotation, and translation.
Alternatives
Here is my attempt. As far as I can tell, it does the job. My classes, which inherit from RandomAffine and Affine, are named MyRandomAffine and MyAffine.
The important changes are in the MyAffine class, which is instantiated and called in the apply_transform method of MyRandomAffine; specifically, in the new method get_new_affine_matrix, which is called at the end of MyAffine's apply_transform method. In short, get_new_affine_matrix works as follows:
- Get the randomly sampled scaling, rotation and translation parameters (by this point they have already been used in the SimpleITK transformations).
- The SimpleITK transformations map the original voxel indices to new voxel indices. We need to find the point about which this mapping occurs and make it the origin in voxel space. We create the `offset_voxels` matrix for this.
- We then compute the forward transformation that was performed by the SimpleITK transforms. The rotation and translation are applied first, and the scaling afterwards. Notice how we pre- and post-multiply the transformation matrices by the `reset_voxels` and `offset_voxels` matrices respectively (strictly, I think this may not be necessary for the rotation and translation transform), because we want to perform these transformations about the same point that SimpleITK did, but we then want to recover our original coordinates, since they are what the original affine matrix was based upon.
- We get the backward transformation as the inverse of the forward transformation and pre-multiply it by the original affine. This gives a mapping from the output voxel indices to world space, i.e. our new affine matrix.
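The steps above can be sketched in plain NumPy. This is a minimal illustration under simplifying assumptions (a made-up 1 mm affine, a rotation about z only, and center="image" approximated as half the grid shape, as in the code below). It does not reproduce SimpleITK's conventions; it only checks the algebra: because the new affine is the original affine times the inverse of the forward mapping, a voxel and its forward-mapped index land on the same world point. The helper names about_center and updated_affine are hypothetical.

```python
import numpy as np

def about_center(matrix: np.ndarray, center: np.ndarray) -> np.ndarray:
    """Make a 4x4 voxel-space transform act about `center` (in voxel units)."""
    offset = np.eye(4)
    offset[:3, 3] = -center        # like offset_voxels: move `center` to the voxel origin
    reset = np.linalg.inv(offset)  # like reset_voxels: move it back afterwards
    return reset @ matrix @ offset

def updated_affine(affine, rot_trans, scale, center):
    """Rotation/translation first, then scaling; new affine = affine @ inverse(forward)."""
    forward = about_center(scale, center) @ about_center(rot_trans, center)
    return affine @ np.linalg.inv(forward), forward

# Made-up example: 1 mm grid, 30 degrees about z, 5-voxel shift along x, zoom of 1.5
theta = np.deg2rad(30)
rot_trans = np.eye(4)
rot_trans[:3, :3] = [[np.cos(theta), -np.sin(theta), 0],
                     [np.sin(theta), np.cos(theta), 0],
                     [0, 0, 1]]
rot_trans[:3, 3] = [5, 0, 0]
scale = np.diag([1.5, 1.5, 1.5, 1.0])

affine = np.eye(4)
affine[:3, 3] = [-90, -126, -72]
center = np.array([181, 217, 181]) / 2  # half the grid shape, like center="image"

new_affine, forward = updated_affine(affine, rot_trans, scale, center)

# A voxel and its forward-mapped index must correspond to the same world point
index = np.array([63, 75, 90, 1.0])
assert np.allclose(affine @ index, new_affine @ (forward @ index))
```

Note that the scaling about the centre leaves the centre voxel fixed, which is exactly why the offset/reset conjugation is needed.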
Here is the code:
```python
import numpy as np
import torchio as tio
import torch
from torchio import Subject
from torchio.constants import TYPE, INTENSITY
from torchio.transforms.augmentation.spatial.random_affine import RandomAffine, Affine, get_borders_mean
from torchio.data.io import nib_to_sitk
from numbers import Number


def get_pixdim_from_affine(affine: np.ndarray):
    rot = affine[:-1, :-1]
    return np.sqrt(np.sum(rot**2, axis=0))


class MyRandomAffine(RandomAffine):
    def __init__(self, update_affine: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.update_affine = update_affine

    def apply_transform(self, subject: Subject) -> Subject:
        scaling_params, rotation_params, translation_params = self.get_params(
            self.scales,
            self.degrees,
            self.translation,
            self.isotropic,
        )
        arguments = {
            'scales': scaling_params.tolist(),
            'degrees': rotation_params.tolist(),
            'translation': translation_params.tolist(),
            'center': self.center,
            'default_pad_value': self.default_pad_value,
            'image_interpolation': self.image_interpolation,
            'label_interpolation': self.label_interpolation,
            'check_shape': self.check_shape,
        }
        transform = MyAffine(update_affine=self.update_affine, **self.add_include_exclude(arguments))
        transformed = transform(subject)
        assert isinstance(transformed, Subject)
        return transformed


class MyAffine(Affine):
    def __init__(self, update_affine: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.update_affine = update_affine

    def apply_transform(self, subject: Subject) -> Subject:
        if self.check_shape:
            subject.check_consistent_spatial_shape()
        default_value: float
        for image in self.get_images(subject):
            transform = self.get_affine_transform(image)
            transformed_tensors = []
            for tensor in image.data:
                sitk_image = nib_to_sitk(
                    tensor[np.newaxis],
                    image.affine,
                    force_3d=True,
                )
                if image[TYPE] != INTENSITY:
                    interpolation = self.label_interpolation
                    default_value = 0
                else:
                    interpolation = self.image_interpolation
                    if self.default_pad_value == 'minimum':
                        default_value = tensor.min().item()
                    elif self.default_pad_value == 'mean':
                        default_value = get_borders_mean(
                            sitk_image,
                            filter_otsu=False,
                        )
                    elif self.default_pad_value == 'otsu':
                        default_value = get_borders_mean(
                            sitk_image,
                            filter_otsu=True,
                        )
                    else:
                        assert isinstance(self.default_pad_value, Number)
                        default_value = float(self.default_pad_value)
                transformed_tensor = self.apply_affine_transform(
                    sitk_image,
                    transform,
                    interpolation,
                    default_value,
                )
                transformed_tensors.append(transformed_tensor)
            image.set_data(torch.stack(transformed_tensors))
            # New: update the affine of each image after resampling
            if self.update_affine:
                new_affine = self.get_new_affine_matrix(image)
                image.affine = new_affine
        return subject

    def get_new_affine_matrix(self, image: tio.Image) -> np.ndarray:
        # get the scaling, rotation and translation parameters
        scaling = np.asarray(self.scales).copy()
        rotation = np.asarray(self.degrees).copy()
        translation = np.asarray(self.translation).copy()
        # get the original affine matrix
        original_affine = image.affine
        # get matrix to offset voxel indices so that the voxel origin ([0,0,0] in voxel space)
        # is at the location about which the transformation is applied
        if self.center == "image":
            voxel_origin = np.array(image.spatial_shape) / 2
        elif self.center == "origin":
            voxel_origin = (np.linalg.inv(original_affine) @ np.array([[0, 0, 0, 1]]).T).flatten()[:3]
        offset_voxels = np.eye(4)
        offset_voxels[:3, -1] = -voxel_origin
        reset_voxels = np.linalg.inv(offset_voxels)
        # forward transform of voxels
        rot = self.get_rotation_matrix(rotation)
        trans = translation / get_pixdim_from_affine(original_affine)  # convert translation to voxel units
        rot_trans_mat = np.hstack([rot, trans.reshape(-1, 1)])
        rot_trans_mat = np.vstack([rot_trans_mat, [0, 0, 0, 1]])
        rot_trans_mat = reset_voxels @ rot_trans_mat @ offset_voxels
        scale = np.eye(4)
        scale[:3, :3] = np.diag(scaling)
        scale_mat = reset_voxels @ scale @ offset_voxels
        forward_voxel_transform = scale_mat @ rot_trans_mat
        # inverse transform of voxels
        backward_voxel_transform = np.linalg.inv(forward_voxel_transform)
        new_affine = original_affine @ backward_voxel_transform
        return new_affine

    @staticmethod
    def get_rotation_matrix(deg):
        # rotation angles in degrees about x, y and z; composed as Rx @ Ry @ Rz
        x_rad, y_rad, z_rad = np.deg2rad(deg)
        x_mat = np.array([[1, 0, 0],
                          [0, np.cos(x_rad), -np.sin(x_rad)],
                          [0, np.sin(x_rad), np.cos(x_rad)]])
        y_mat = np.array([[np.cos(y_rad), 0, np.sin(y_rad)],
                          [0, 1, 0],
                          [-np.sin(y_rad), 0, np.cos(y_rad)]])
        z_mat = np.array([[np.cos(z_rad), -np.sin(z_rad), 0],
                          [np.sin(z_rad), np.cos(z_rad), 0],
                          [0, 0, 1]])
        return x_mat @ y_mat @ z_mat
```
Here is some code to visually confirm that it works:
```python
import matplotlib.pyplot as plt

subject_in = tio.datasets.Colin27()
# Play around with these params
transform = MyRandomAffine(scales=(2, 2, 1, 1, 1, 1), degrees=(0, 0, 0, 0, 47, 47), translation=(8, 8, -30, -30, 0, 0), update_affine=True)
# Do the transform
subject_out = transform(subject_in)
# Get original and new affines
original_affine = subject_in["t1"].affine
new_affine = subject_out["t1"].affine
print("Original affine")
print(np.round(original_affine, 4), "\n")
print("New affine")
print(np.round(new_affine, 4))
# Set point of interest in voxel indices of original image
index1 = np.array([[63, 75, 90, 1]]).T
world = original_affine @ index1  # original index to world
# Get index in transformed image
index2 = np.linalg.inv(new_affine) @ world  # world to new index

############################## Plot ##############################
fig, ax = plt.subplots(1, 2, figsize=(20, 10))

### Image in
original_origin_index = np.linalg.inv(original_affine) @ np.array([[0, 0, 0, 1]]).T
# Image data
image1 = subject_in["t1"]["data"][0, :, :, index1[2][0]]
ax[0].imshow(image1, cmap="gray")
# Point of interest
ax[0].scatter(index1[1][0], index1[0][0], s=20, c='red', marker='o')
ax[0].text(index1[1][0] + 15, index1[0][0] - 15, f"{index1[0][0]}, {index1[1][0]}", fontsize=12, color='red')
# Origin
ax[0].scatter(original_origin_index[1][0], original_origin_index[0][0], s=20, c='green', marker='o')
ax[0].text(original_origin_index[1][0] + 15, original_origin_index[0][0] - 15, f"{original_origin_index[0][0]}, {original_origin_index[1][0]}", fontsize=12, color='green')
# Formatting
ax[0].minorticks_on()
ax[0].xaxis.set_minor_locator(plt.MultipleLocator(5))
ax[0].yaxis.set_minor_locator(plt.MultipleLocator(5))
ax[0].grid(which='both', color='blue', linestyle='-', linewidth=0.5)
ax[0].grid(which='minor', color='cyan', linestyle='-', linewidth=0.5, alpha=0.25)

### Image out
new_origin_index = np.linalg.inv(new_affine) @ np.array([[0, 0, 0, 1]]).T
# Image data
image2 = subject_out["t1"]["data"][0, :, :, int(index2[2][0])]
ax[1].imshow(image2, cmap="gray")
# Point of interest
ax[1].scatter(index2[1][0], index2[0][0], s=20, c='red', marker='o')
ax[1].text(index2[1][0] + 15, index2[0][0] - 15, f"{index2[0][0]:.2f}, {index2[1][0]:.2f}", fontsize=12, color='red')
# Origin
ax[1].scatter(new_origin_index[1][0], new_origin_index[0][0], s=20, c='green', marker='o')
ax[1].text(new_origin_index[1][0] + 15, new_origin_index[0][0] - 15, f"{new_origin_index[0][0]:.2f}, {new_origin_index[1][0]:.2f}", fontsize=12, color='green')
# Formatting
ax[1].minorticks_on()
ax[1].xaxis.set_minor_locator(plt.MultipleLocator(5))
ax[1].yaxis.set_minor_locator(plt.MultipleLocator(5))
ax[1].grid(which='both', color='blue', linestyle='-', linewidth=0.5)
ax[1].grid(which='minor', color='cyan', linestyle='-', linewidth=0.5, alpha=0.25)
plt.show()
```
This outputs a side-by-side plot of the original (left) and transformed (right) slices, with the point of interest in red and the origin in green.
Additional context
Let me just say that I absolutely love the TorchIO library and use it in my pipelines wherever possible!
I think another feature missing from the RandomAffine transform is the option to specify scales as a target voxel spacing instead of scaling factors. This would be easy to implement with just a couple of lines of code. I didn't think it was worth a separate issue, but I thought I would mention it since I use it in my pipeline.
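A minimal sketch of what "scales as target voxel spacing" could look like, assuming TorchIO's convention that scaling factors greater than 1 zoom in: making content look as if it were sampled at target spacing t from current spacing c needs a factor of c / t per axis. Both helper names are hypothetical; spacing_range_to_scales returns the six-value (min, max) per-axis format used for scales in the example above.

```python
import numpy as np

def spacing_to_scales(current_spacing, target_spacing):
    """Hypothetical helper: per-axis factors that resample content from
    `current_spacing` to `target_spacing` (assumes factors > 1 zoom in)."""
    current = np.asarray(current_spacing, dtype=float)
    target = np.asarray(target_spacing, dtype=float)
    return current / target

def spacing_range_to_scales(current_spacing, min_spacing, max_spacing):
    """Hypothetical helper: target-spacing range -> (min, max) per-axis scales tuple."""
    # The coarsest target spacing gives the smallest factor, and vice versa
    low = spacing_to_scales(current_spacing, max_spacing)
    high = spacing_to_scales(current_spacing, min_spacing)
    return tuple(float(v) for pair in zip(low, high) for v in pair)

# 1 x 1 x 2 mm image, sampled as if acquired at 0.5 to 2 mm spacing
print(spacing_range_to_scales([1, 1, 2], 0.5, 2.0))  # (0.5, 2.0, 0.5, 2.0, 1.0, 4.0)
```

The tuple could then be passed as the scales argument of RandomAffine. Without an affine update, though, the header spacing of the output would stay the same as the input's.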
Hello, thanks for the proposal. This looks interesting, but I am not sure it won't break other applications.
I guess it is a question of application. The way I use RandomAffine is to simulate different subject sizes or orientations. For this, I want the volume's affine to stay identical, because only the content changes (i.e. a different point of view...).
If we do what you propose, we will end up with subjects having different affines (and maybe different matrix sizes), which is a problem for concatenating subjects in a batch... This is the main limitation I see. Don't you have any problems handling subjects with different affines?
That being said, I can understand the need. Another application where it would be useful is learning a coregistration task. In this case, the exact affine applied to the subject (by RandomAffine) needs to be kept. But I would prefer to keep this information in the subject's history instead of changing the subject's affine (for the previous reason). This is actually already the case, since we store the input parameters chosen by the random process (scale, rotation and translation), so we can reconstruct the affine transformation (as you do in your code).
I hope this makes sense.
Hi, thanks for your reply.
Regarding breaking applications, I think it could be done in a way that it wouldn't. I would suggest an argument update_affine defaulting to False.
I agree that it is indeed a question of application and I think that your usage is the most typical case, hence I think it should be optional.
The matrix size would not change. The only thing my code changes is the affine attribute of each tio.ScalarImage or tio.LabelMap in the subject. I think it is only a problem if multiple images within a single subject have different affines; different subjects are safe (if not expected) to have different affines.
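To illustrate the point about batching, here is a minimal NumPy sketch with two made-up subjects that share a matrix size but have different affines: the image arrays stack as usual, and the affines stack into a separate (B, 4, 4) array from which per-sample spacing can be read. This only mimics what a collate function could do; it makes no claim about how TorchIO's own loader collates affines.

```python
import numpy as np

# Two made-up subjects: same matrix size, different affines
data_a = np.zeros((1, 64, 64, 32), dtype=np.float32)
data_b = np.ones((1, 64, 64, 32), dtype=np.float32)
affine_a = np.diag([1.0, 1.0, 2.0, 1.0])
affine_b = np.diag([0.5, 0.5, 1.0, 1.0])  # e.g. after a 2x zoom with the affine updated

# Stacking only requires equal array shapes; the affines go into their own array
batch_data = np.stack([data_a, data_b])         # shape (2, 1, 64, 64, 32)
batch_affines = np.stack([affine_a, affine_b])  # shape (2, 4, 4)

# Per-sample voxel spacing (column norms of each 3x3 block), e.g. for a
# spacing-conditional model
spacings = np.linalg.norm(batch_affines[:, :3, :3], axis=1)  # shape (2, 3)
```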
Yes, I am sure it would be useful for coregistration tasks in some way, although I have not worked on these.
Ah, I did not realise that the sampled parameters were kept in the history! I guess this means that one could implement my code post transformation if they desired to have the updated affine. I think it would probably be easier for most users to just have an option to update the affine during the transformation though. However, there is then a problem, because if a user does try to recover the affine themselves from the subject's history, then it will be wrong, since we have updated it! This would take some care on the user's part. Perhaps a separate function which takes in a subject and updates the subject's affine from its history, before optionally modifying its history to remove the transform?
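A minimal sketch of such a separate function, under stated assumptions: the recorded parameters are passed as plain dicts with 'scales', 'degrees' and 'translation' (in mm) rather than read from TorchIO's history objects, whose exact structure is not assumed here; center="image" is approximated as half the grid shape; and the rotation uses the same Rx @ Ry @ Rz composition as get_rotation_matrix above, which has not been checked against SimpleITK's convention. affine_from_records is a hypothetical name.

```python
import numpy as np

def rotation_matrix(degrees):
    """Rx @ Ry @ Rz, the same composition order as get_rotation_matrix above."""
    ax, ay, az = np.deg2rad(degrees)
    rx = np.array([[1, 0, 0], [0, np.cos(ax), -np.sin(ax)], [0, np.sin(ax), np.cos(ax)]])
    ry = np.array([[np.cos(ay), 0, np.sin(ay)], [0, 1, 0], [-np.sin(ay), 0, np.cos(ay)]])
    rz = np.array([[np.cos(az), -np.sin(az), 0], [np.sin(az), np.cos(az), 0], [0, 0, 1]])
    return rx @ ry @ rz

def affine_from_records(affine, spatial_shape, records):
    """Hypothetical post-hoc update: replay recorded affine parameters in order,
    assuming center="image" and an unchanged matrix size."""
    affine = np.asarray(affine, dtype=float)
    center = np.asarray(spatial_shape, dtype=float) / 2
    offset = np.eye(4)
    offset[:3, 3] = -center          # voxel centre -> voxel origin
    reset = np.linalg.inv(offset)    # and back
    for record in records:
        spacing = np.linalg.norm(affine[:3, :3], axis=0)  # spacing of the current affine
        rot_trans = np.eye(4)
        rot_trans[:3, :3] = rotation_matrix(record['degrees'])
        rot_trans[:3, 3] = np.asarray(record['translation'], dtype=float) / spacing
        scale = np.diag([*record['scales'], 1.0])
        # rotation/translation first, then scaling, both about the centre
        # (the offset @ reset pair between the two steps cancels out)
        forward = reset @ scale @ rot_trans @ offset
        affine = affine @ np.linalg.inv(forward)
    return affine
```

If the conventions match the code above, replaying the parameters of every RandomAffine applied to an image should give the same affine as update_affine=True; removing those entries from the history afterwards is not covered here.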