
procrustes alignment for pytorch

Open heth27 opened this issue 1 year ago • 3 comments

🚀 The feature

Orthogonal procrustes alignment

Motivation, pitch

Procrustes alignment is a staple when calculating metrics for 3D human pose estimation, but there seems to be no library that offers this function for PyTorch, so I guess everyone just maintains their own version.

SciPy has a variant: https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.procrustes.html
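
For reference, here is a minimal sketch of what the SciPy function returns. The point values are made up for illustration. `scipy.spatial.procrustes` gives back standardized copies of both inputs plus a disparity value rather than an explicit scale, rotation and translation, and it works on NumPy arrays, not tensors.

```python
import numpy as np
from scipy.spatial import procrustes

rng = np.random.default_rng(0)
pts_a = rng.normal(size=(17, 3))  # made-up reference points

# pts_b is pts_a rotated about the z axis, scaled and shifted
c, s = np.cos(0.7), np.sin(0.7)
rot = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
pts_b = 2.5 * pts_a @ rot.T + np.array([1.0, -2.0, 0.5])

# mtx_a, mtx_b: centered, unit-norm versions of the inputs, with mtx_b fitted to mtx_a
# disparity: sum of squared differences between mtx_a and mtx_b
mtx_a, mtx_b, disparity = procrustes(pts_a, pts_b)
print(disparity)
```

Since `pts_b` is an exact similarity transform of `pts_a`, the disparity should come out at roughly zero, but the transform itself is not returned.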

Alternatives

No response

Additional context

This is the implementation I'm using; I don't know if it is any good.

import torch


def procrustes(pts1: torch.Tensor, pts2: torch.Tensor):
    assert pts1.shape == pts2.shape, f"{pts1.shape} != {pts2.shape}"
    assert pts1.ndim == 2 and pts1.shape[-1] == 3, f"{pts1.shape}"
    # estimate a sim3 transformation (scale s, rotation R, centroids t1 and t2)
    # that aligns pts2 to pts1: pts1 ≈ s * (pts2 - t2) @ R.T + t1
    # center both point clouds on their centroids
    t1 = pts1.mean(dim=0)
    t2 = pts2.mean(dim=0)
    pts1 = pts1 - t1[None, :]
    pts2 = pts2 - t2[None, :]

    # normalize each cloud by its RMS distance from the centroid
    s1 = pts1.square().sum(dim=-1).mean().sqrt()
    s2 = pts2.square().sum(dim=-1).mean().sqrt()
    pts1 = pts1 / s1
    pts2 = pts2 / s2
    try:
        # SVD of the 3x3 cross-covariance, in double precision for stability
        U, _, Vh = torch.linalg.svd((pts1.T @ pts2).double())
    except RuntimeError:
        print("Procrustes failed: SVD did not converge!")
        # fall back to the identity transform, keeping the 4-value return signature
        eye = torch.eye(3, device=pts1.device, dtype=pts1.dtype)
        return torch.ones_like(s1), eye, torch.zeros_like(t1), torch.zeros_like(t2)
    # build rotation matrix; if U @ Vh is a reflection, flip the last
    # singular direction (U @ diag(1, 1, -1) @ Vh) to get a proper rotation
    d = torch.ones(3, dtype=U.dtype, device=U.device)
    if torch.linalg.det(U @ Vh) < 0:
        d[-1] = -1.0
    R = ((U * d) @ Vh).to(pts1.dtype)
    # scale is the ratio of the RMS sizes of the two clouds
    s = s1 / s2

    # apply as s * (pts2 - t2) @ R.T + t1, as in the example usage
    return s, R, t1, t2

Example usage:

s, R, mean_1, mean_2 = procrustes(coords_3d_true, coords_3d_prediction)
procrustes_aligned = torch.einsum("jd, od -> jo", coords_3d_prediction - mean_2, s * R) + mean_1
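
One way to sanity-check an alignment like this is to apply a known similarity transform to random points and confirm that the aligned prediction lands back on the target. The sketch below is not part of the original post: `align_to` is a hypothetical, condensed helper that folds the centering, scaling and SVD steps into one function, and `pa_mpjpe` is an assumed name for the mean per-joint error after alignment. It uses the same RMS-ratio scale as the snippet above, which is not necessarily the least-squares optimal scale when the prediction is noisy.

```python
import math

import torch


def align_to(target: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: similarity-align pred (J, 3) onto target (J, 3)."""
    mu_t, mu_p = target.mean(dim=0), pred.mean(dim=0)
    t0, p0 = target - mu_t, pred - mu_p
    # RMS distance from the centroid, used as the size of each cloud
    s_t = t0.square().sum(dim=-1).mean().sqrt()
    s_p = p0.square().sum(dim=-1).mean().sqrt()
    # SVD of the 3x3 cross-covariance of the normalized clouds
    U, _, Vh = torch.linalg.svd(((t0 / s_t).T @ (p0 / s_p)).double())
    # flip the last singular direction if U @ Vh would be a reflection
    d = torch.ones(3, dtype=U.dtype, device=U.device)
    if torch.linalg.det(U @ Vh) < 0:
        d[-1] = -1.0
    R = ((U * d) @ Vh).to(pred.dtype)
    return (s_t / s_p) * p0 @ R.T + mu_t


def pa_mpjpe(target: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
    """Assumed metric name: mean per-joint distance after alignment."""
    return (align_to(target, pred) - target).norm(dim=-1).mean()


torch.manual_seed(0)
target = torch.randn(17, 3, dtype=torch.float64)  # made-up joint positions

# pred is target under a known rotation about z, a scale and a translation
c, s = math.cos(0.7), math.sin(0.7)
R_true = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]], dtype=torch.float64)
shift = torch.tensor([1.0, -2.0, 0.5], dtype=torch.float64)
pred = 2.5 * target @ R_true.T + shift

error = pa_mpjpe(target, pred)
print(float(error))
```

Because `pred` differs from `target` only by a similarity transform, the error after alignment should be zero up to floating-point noise. A clearly nonzero value would point to a bug in the rotation or scale handling.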

heth27 avatar Aug 09 '24 15:08 heth27

If this is better suited for e.g. torchmetrics (https://lightning.ai/docs/torchmetrics/stable/) this would also be good to know

heth27 avatar Aug 11 '24 23:08 heth27

Hi @heth27 and thank you for the feature request. Torchvision doesn't really have holistic support for 3D data in general, so I'm not sure Procrustes alignment would be in scope. We typically add such metrics when they directly relate to one of the CV tasks that torchvision supports (classification, detection, etc.), but 3D human pose is not yet in scope. Thank you for providing a snippet, I hope it can be useful to users looking for this exact feature.

NicolasHug avatar Aug 12 '24 10:08 NicolasHug

If this is better suited for e.g. torchmetrics (lightning.ai/docs/torchmetrics/stable) this would also be good to know

It might be in scope for torchmetrics, although note that it isn't owned by the PyTorch org, so we don't have any weight in the decision process over there.

NicolasHug avatar Aug 12 '24 10:08 NicolasHug