
bug with batch_size > 1 ?

Open etienne87 opened this issue 2 months ago • 8 comments

I am testing your code like this:

import torch
import numpy as np

from keymorph.keypoint_aligners import RigidKeypointAligner

bs, num_points = 8, 10
points_m = torch.rand(bs, num_points, 3)
points_f = torch.rand(bs, num_points, 3)

keypoint_aligner = RigidKeypointAligner(
    points_m=points_m,
    points_f=points_f,
    dim=3,
)
(snakes) eperot@frbucaw06dl:~/registration/monai-registration> python test_rigid_kp_aligner.py
Traceback (most recent call last):
  File "/home/eperot/registration/monai-registration/test_rigid_kp_aligner.py", line 9, in <module>
    keypoint_aligner = RigidKeypointAligner(
                       ^^^^^^^^^^^^^^^^^^^^^
  File "/home/eperot/.conda/envs/snakes/lib/python3.11/site-packages/keymorph/keypoint_aligners.py", line 68, in __init__
    inverse_transform_matrix = self._square(
                               ^^^^^^^^^^^^^
  File "/home/eperot/.conda/envs/snakes/lib/python3.11/site-packages/keymorph/transformations.py", line 34, in _square
    square[:, : self.dim, : self.dim + 1] = matrix
    ~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: The expanded size of the tensor (1) must match the existing size (8) at non-singleton dimension 0.  Target sizes: [1, 3, 4].  Tensor sizes: [8, 3, 4]

I think the "_square" function inside "AffineTransform" is wrong:

# this code assumes a batch size of 1
def _square(self, matrix):
    square = torch.eye(self.dim + 1)[None]
    square[:, : self.dim, : self.dim + 1] = matrix
    return square

You can replace it with:

def _square(self, matrix):
    batch_size = matrix.shape[0]
    square = torch.eye(self.dim + 1).unsqueeze(0).repeat(batch_size, 1, 1)
    square[:, : self.dim, : self.dim + 1] = matrix
    return square
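
For reference, here is a minimal standalone sketch of the same batched padding, written as a free function so it can be run without the library. The name `square_batched` and the `dim` argument are hypothetical, and creating the identity on the input's device and dtype is my own addition, not something the original method does.

```python
import torch


def square_batched(matrix: torch.Tensor, dim: int = 3) -> torch.Tensor:
    """Pad a batch of (B, dim, dim + 1) affines to (B, dim + 1, dim + 1)."""
    batch_size = matrix.shape[0]
    # Identity created on the same device/dtype as the input (an assumption,
    # added so the result does not silently end up on the CPU).
    eye = torch.eye(dim + 1, device=matrix.device, dtype=matrix.dtype)
    # One independent copy of the identity per batch element.
    square = eye.unsqueeze(0).repeat(batch_size, 1, 1)
    # Overwrite the top `dim` rows; the last row stays [0, ..., 0, 1].
    square[:, :dim, : dim + 1] = matrix
    return square


affines = torch.rand(8, 3, 4)
squared = square_batched(affines)
print(squared.shape)
```

The top three rows of each output matrix should equal the input affine, and the last row should remain the homogeneous row.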

Later, if I call affine_grid, it also fails.

Is the code hardcoded for a batch_size of 1? If yes, isn't this a problem at training time?

Anyway, a simple fix is to use F.affine_grid while respecting your convention (the matrix is obtained from ijk points):

def zyx_to_xyz_affine(affine_zyx):
    # Permutation matrix: ZYX -> XYZ (reverse the axes)
    P = torch.Tensor([
        [0, 0, 1, 0],
        [0, 1, 0, 0],
        [1, 0, 0, 0],
        [0, 0, 0, 1]
    ]).to(affine_zyx)

    # Convert: P @ affine_zyx @ P
    affine_xyz = P @ affine_zyx @ P
    return affine_xyz

def get_flow_field(self, grid_shape):
    matrix = zyx_to_xyz_affine(self.inverse_transform_matrix)
    grid = torch.nn.functional.affine_grid(matrix[:, :3], grid_shape, align_corners=True)
    return grid
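
A small sanity check of the axis conversion, as a sketch on random data: since the permutation matrix is its own inverse, applying the converted affine to an axis-reversed point should give the axis-reversed result of the original affine. The test tensors below are made up for illustration, and this only checks the permutation, not the normalized-coordinate convention that affine_grid expects.

```python
import torch


def zyx_to_xyz_affine(affine_zyx: torch.Tensor) -> torch.Tensor:
    # P swaps the first and third axes and is its own inverse (P @ P = I).
    P = torch.tensor(
        [[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]],
        dtype=affine_zyx.dtype,
        device=affine_zyx.device,
    )
    return P @ affine_zyx @ P


torch.manual_seed(0)
# Batch of two random homogeneous affines in ZYX order.
affine_zyx = torch.eye(4).repeat(2, 1, 1)
affine_zyx[:, :3, :] = torch.rand(2, 3, 4)

# One homogeneous column vector per batch element, in ZYX order.
point_zyx = torch.rand(2, 4, 1)
point_zyx[:, 3] = 1.0

out_zyx = affine_zyx @ point_zyx
# Reverse the spatial axes of the point, keep the homogeneous coordinate.
point_xyz = point_zyx[:, [2, 1, 0, 3]]
out_xyz = zyx_to_xyz_affine(affine_zyx) @ point_xyz

print(torch.allclose(out_xyz, out_zyx[:, [2, 1, 0, 3]], atol=1e-6))
```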

EDIT: I found another, simpler and less intrusive fix:

def get_inverse_transformed_points(self, points):
    """Transforms a set of points in fixed space to moving space using the fitted matrix.
    If align_in_real_world_coords is False, computes:
        p_m = A p_f.
    If align_in_real_world_coords is True, points must be in voxel coordinates, computes:
        p_m = A_m^-1 A^-1 A_f p_f.
    """
    batch_size, num_points, _ = points.shape
    transform_matrix = self.inverse_transform_matrix[:, :-1, :]

    # Transform real-world coordinates to the moving image space using the registration affine
    # Convert to homogeneous coordinates
    ones = torch.ones(batch_size, num_points, 1).to(points.device)
    points = torch.cat([points, ones], dim=2)
    # this is wrong, does not work for batch_size > 1
    # points = torch.bmm(transform_matrix, points.permute(0, 2, 1)).permute(0, 2, 1)
    points = torch.einsum('brc,bpc->bpr', transform_matrix, points)

    return points
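
To see why the einsum version helps, here is a small sketch on random tensors (not the library's data). As far as I can tell, the two expressions agree when the matrix and the points share the same batch size. The difference is that torch.einsum broadcasts a batch dimension of size 1 while torch.bmm does not. My assumption is that the failure above comes from the matrix still having a batch dimension of 1.

```python
import torch

torch.manual_seed(0)
batch_size, num_points = 8, 10
# Homogeneous points of shape (B, P, 4).
points_h = torch.cat(
    [torch.rand(batch_size, num_points, 3), torch.ones(batch_size, num_points, 1)],
    dim=2,
)

# Case 1: the matrix has the same batch size as the points.
matched = torch.rand(batch_size, 3, 4)
via_bmm = torch.bmm(matched, points_h.permute(0, 2, 1)).permute(0, 2, 1)
via_einsum = torch.einsum("brc,bpc->bpr", matched, points_h)
same_result = torch.allclose(via_bmm, via_einsum, atol=1e-6)

# Case 2: the matrix has a batch dimension of 1 (hypothesized failing case).
single = torch.rand(1, 3, 4)
try:
    torch.bmm(single, points_h.permute(0, 2, 1))
    bmm_failed = False
except RuntimeError:
    bmm_failed = True  # bmm requires equal batch sizes
# einsum broadcasts the size-1 batch dimension across all 8 point sets.
broadcast = torch.einsum("brc,bpc->bpr", single, points_h)

print(same_result, bmm_failed, broadcast.shape)
```

Note that in the broadcast case every batch element is transformed by the same matrix, which may or may not be what is wanted when each pair should get its own fitted transform.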

etienne87 avatar Oct 16 '25 07:10 etienne87