keymorph
keymorph copied to clipboard
Bug with batch_size > 1?
I am testing your code like this:
import torch
import numpy as np
from keymorph.keypoint_aligners import RigidKeypointAligner
# Minimal reproduction: random keypoint sets of shape (bs, num_points, 3)
# for the moving (points_m) and fixed (points_f) side, with bs > 1.
bs, num_points = 8, 10
points_m = torch.rand(bs, num_points, 3)
points_f = torch.rand(bs, num_points, 3)
# With bs > 1 this constructor raises the RuntimeError shown in the
# traceback (raised from AffineTransform._square).
keypoint_aligner = RigidKeypointAligner(
    points_m=points_m,
    points_f=points_f,
    dim=3
)
(snakes) eperot@frbucaw06dl:~/registration/monai-registration> python test_rigid_kp_aligner.py
Traceback (most recent call last):
File "/home/eperot/registration/monai-registration/test_rigid_kp_aligner.py", line 9, in <module>
keypoint_aligner = RigidKeypointAligner(
^^^^^^^^^^^^^^^^^^^^^
File "/home/eperot/.conda/envs/snakes/lib/python3.11/site-packages/keymorph/keypoint_aligners.py", line 68, in __init__
inverse_transform_matrix = self._square(
^^^^^^^^^^^^^
File "/home/eperot/.conda/envs/snakes/lib/python3.11/site-packages/keymorph/transformations.py", line 34, in _square
square[:, : self.dim, : self.dim + 1] = matrix
~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: The expanded size of the tensor (1) must match the existing size (8) at non-singleton dimension 0. Target sizes: [1, 3, 4]. Tensor sizes: [8, 3, 4]
I think the `_square` function inside `AffineTransform` is wrong:
def _square(self, matrix):
    """Append the homogeneous row ``[0, ..., 0, 1]`` to a batch of matrices.

    Args:
        matrix: tensor of shape ``(B, dim, dim + 1)``. A single unbatched
            ``(dim, dim + 1)`` matrix is treated as a batch of one.

    Returns:
        Tensor of shape ``(B, dim + 1, dim + 1)`` with the same dtype and
        device as ``matrix``; the first ``dim`` rows are ``matrix`` and the
        last row is ``[0, ..., 0, 1]``.
    """
    # Preserve the historical result for an unbatched input: it used to be
    # broadcast into a leading batch dimension of size 1.
    if matrix.dim() == 2:
        matrix = matrix.unsqueeze(0)
    batch_size = matrix.shape[0]
    # Build the identity on matrix's dtype/device; a default (CPU, float32)
    # eye would downcast float64 inputs and return CPU tensors for CUDA ones.
    square = torch.eye(self.dim + 1, dtype=matrix.dtype, device=matrix.device)
    # One identity per batch element (the old code allocated only one).
    square = square.unsqueeze(0).repeat(batch_size, 1, 1)
    square[:, : self.dim, : self.dim + 1] = matrix
    return square
You can replace it with:
def _square(self, matrix):
    """Append the homogeneous row ``[0, ..., 0, 1]`` to a batch of matrices.

    Args:
        matrix: tensor of shape ``(B, dim, dim + 1)``. A single unbatched
            ``(dim, dim + 1)`` matrix is treated as a batch of one.

    Returns:
        Tensor of shape ``(B, dim + 1, dim + 1)`` with the same dtype and
        device as ``matrix``; the first ``dim`` rows are ``matrix`` and the
        last row is ``[0, ..., 0, 1]``.
    """
    # Without this, matrix.shape[0] of an unbatched input would be `dim`
    # and the output would wrongly get `dim` batch entries.
    if matrix.dim() == 2:
        matrix = matrix.unsqueeze(0)
    batch_size = matrix.shape[0]
    # Build the identity on matrix's dtype/device; a default (CPU, float32)
    # eye would downcast float64 inputs and return CPU tensors for CUDA ones.
    square = torch.eye(self.dim + 1, dtype=matrix.dtype, device=matrix.device)
    square = square.unsqueeze(0).repeat(batch_size, 1, 1)
    square[:, : self.dim, : self.dim + 1] = matrix
    return square
And later, if I call `affine_grid`, it also fails.
Is the code hardcoded for a batch_size of 1? If so, isn't this a problem at training time?
Anyway, a simple fix can be made using `F.affine_grid` while respecting your convention (the matrix is obtained from ijk points):
def zyx_to_xyz_affine(affine_zyx):
    """Reorder a homogeneous affine matrix from (z, y, x) to (x, y, z) axes.

    Equivalent to ``P @ affine_zyx @ P``, where ``P`` is the permutation
    matrix that reverses the spatial axes and keeps the homogeneous axis
    last (``P`` is its own inverse). Done by index permutation rather than
    matrix products, so entries are moved exactly and NaN/inf values are
    not spread across the matrix through ``0 * inf`` terms.

    Args:
        affine_zyx: tensor of shape ``(..., n, n)`` with ``n = dim + 1``
            (4 for 3-D, 3 for 2-D); leading batch dimensions are kept.

    Returns:
        Tensor with the same shape, dtype and device, with rows and
        columns permuted to (x, y, z) order.
    """
    n = affine_zyx.shape[-1]
    # Spatial axes reversed (n-2, ..., 0); homogeneous axis n-1 stays last.
    order = torch.arange(n - 2, -1, -1, device=affine_zyx.device)
    order = torch.cat([order, order.new_tensor([n - 1])])
    # (P A P)[i, j] == A[order[i], order[j]]
    return affine_zyx.index_select(-2, order).index_select(-1, order)
def get_flow_field(self, grid_shape):
    """Build an ``affine_grid`` sampling grid from the fitted inverse transform.

    ``self.inverse_transform_matrix`` is converted from (z, y, x) to
    (x, y, z) order with ``zyx_to_xyz_affine``. Its last (homogeneous) row
    is dropped, and the result is passed to
    ``torch.nn.functional.affine_grid``. NOTE(review): this assumes the
    matrix is a batch of square homogeneous matrices expressed in the
    normalized coordinates ``affine_grid`` works in; confirm against the
    caller.

    Args:
        grid_shape: output size forwarded to ``affine_grid``, e.g.
            ``(N, C, D, H, W)``.

    Returns:
        The sampling grid produced by ``affine_grid`` with
        ``align_corners=True``.
    """
    matrix = zyx_to_xyz_affine(self.inverse_transform_matrix)
    # Drop the homogeneous row to get theta of shape (N, dim, dim + 1).
    # Slicing with -1, not a hard-coded 3, avoids assuming a 3-D transform.
    theta = matrix[:, :-1]
    grid = torch.nn.functional.affine_grid(theta, grid_shape, align_corners=True)
    return grid
EDIT: I found another, simpler and less intrusive fix:
def get_inverse_transformed_points(self, points):
    """Transform points in fixed space to moving space using the fitted matrix.

    Computes ``p_m = A p_f`` for every point. ``A`` is
    ``self.inverse_transform_matrix`` with its last (homogeneous) row
    dropped. Each batch element is mapped with its own matrix, so any batch
    size works.

    NOTE(review): an earlier docstring also described an
    ``align_in_real_world_coords`` mode (``p_m = A_m^-1 A^-1 A_f p_f``).
    This implementation does not read that flag and always applies the
    direct mapping.

    Args:
        points: tensor of shape ``(B, N, dim)``.

    Returns:
        Tensor of shape ``(B, N, dim)`` with the transformed points.

    Raises:
        ValueError: if ``points`` is not 3-dimensional.
    """
    if points.dim() != 3:
        raise ValueError(
            f"points must have shape (B, N, dim), got {tuple(points.shape)}"
        )
    # (B, dim + 1, dim + 1) -> (B, dim, dim + 1)
    transform_matrix = self.inverse_transform_matrix[:, :-1, :]
    # Homogeneous coordinates; ones_like keeps the points' dtype and device.
    ones = torch.ones_like(points[..., :1])
    points = torch.cat([points, ones], dim=-1)
    # Batched product: out[b, p, r] = sum_c A[b, r, c] * x[b, p, c]
    return torch.einsum("brc,bpc->bpr", transform_matrix, points)