
Norm of output from `pytorch3d.transforms.rotation_conversions.matrix_to_axis_angle` is not in [0,pi]

Status: Open. davidcjuergens opened this issue 3 years ago; 2 comments.

I'm noticing that converting a set of rotation matrices to axis-angle form produces axis-angle vectors whose norm is greater than pi. That seems wrong, since I expected the rotation angle (the vector's norm) to lie in [0, pi]. I'm using a slightly modified version of rotation_conversions.py; a diff against Meta's original is shown below.

My questions are: (A) Is this behavior expected? (B) If so, what is the best way to normalize the vectors so that their magnitudes always lie in [0, pi]?

Thanks in advance; any help would be appreciated!

Instructions To Reproduce the Issue:

  1. Here is a diff between the rotation_conversions.py I use to reproduce this behavior and Meta's original (meta_rotation_conversions.py).

I created meta_rotation_conversions.py by copying the code from https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html into a file called tmp.py and then running `sed s:"\[docs\]":'':g tmp.py > meta_rotation_conversions.py` to strip the "[docs]" markers left over from the web page.

diff rotation_conversions.py meta_rotation_conversions.py 
11a12,13
> from ..common.datatypes import Device
> 
51d52
<     # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
70a72
> 
133,134d134
<             # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
<             #  `int`.
136,137d135
<             # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
<             #  `int`.
139,140d136
<             # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
<             #  `int`.
142,143d137
<             # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
<             #  `int`.
158c152
<         F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
---
>         F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :  # pyre-ignore[16]
161a156
> 
220a216
> 
305a302
> 
307c304
<     n: int, dtype: Optional[torch.dtype] = None, device = None
---
>     n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
324d320
<     # pyre-fixme[6]: For 2nd param expected `dtype` but got `Optional[dtype]`.
330a327
> 
332c329
<     n: int, dtype: Optional[torch.dtype] = None, device = None
---
>     n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
349a347
> 
351c349
<     dtype: Optional[torch.dtype] = None, device = None
---
>     dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
366a365
> 
381a381
> 
402a403
> 
419a421
> 
436a439
> 
459a463
> 
475a480
> 
491a497
> 
523a530
> 
554a562
> 
578a587
> 
595c604
<     return matrix[..., :2, :].clone().reshape(batch_dim + (6,))
\ No newline at end of file
---
>     return matrix[..., :2, :].clone().reshape(batch_dim + (6,))
  2. Here are the exact commands I ran that produce axis-angle vectors with norm greater than pi:
import torch 
print(torch.__version__)
import rotation_conversions

data = torch.load('ax_ang_outputs.pt', map_location=torch.device('cpu'))
test_R = data['R_pred'][0][0]

print('This is test rotations')
print(test_R)

# get axis angle from matrix 
ax_ang_from_R = rotation_conversions.matrix_to_axis_angle(test_R)

# check norm 
print(torch.norm(ax_ang_from_R, p=2, dim=-1))
  3. Here's what I observed in the output:
1.9.0
This is test rotations
tensor([[[ 0.1494,  0.4837, -0.8624],
         [ 0.7733, -0.6007, -0.2030],
         [-0.6162, -0.6365, -0.4638]],

        [[-0.0714,  0.4292,  0.9004],
         [ 0.9666,  0.2525, -0.0437],
         [-0.2462,  0.8672, -0.4328]],

        [[ 0.5917,  0.2904, -0.7520],
         [ 0.7764,  0.0455,  0.6285],
         [ 0.2167, -0.9558, -0.1986]],

        [[ 0.0833, -0.4428,  0.8927],
         [-0.8289, -0.5281, -0.1846],
         [ 0.5532, -0.7246, -0.4110]],

        [[-0.1795, -0.3386,  0.9237],
         [-0.7774,  0.6242,  0.0778],
         [-0.6029, -0.7041, -0.3752]],

        [[ 0.5805,  0.8141, -0.0176],
         [ 0.2735, -0.2153, -0.9375],
         [-0.7670,  0.5393, -0.3477]],

        [[ 0.4837, -0.4856, -0.7282],
         [-0.7797, -0.6171, -0.1064],
         [-0.3977,  0.6192, -0.6771]],

        [[-0.6256, -0.7073,  0.3291],
         [-0.5258,  0.6939,  0.4920],
         [-0.5763,  0.1348, -0.8060]],

        [[-0.2255,  0.6738,  0.7036],
         [ 0.4855,  0.7039, -0.5185],
         [-0.8446,  0.2247, -0.4859]],

        [[ 0.6165,  0.3816, -0.6888],
         [-0.3325, -0.6667, -0.6670],
         [-0.7137,  0.6402, -0.2841]],

        [[-0.1925, -0.8007, -0.5672],
         [-0.9634,  0.0443,  0.2645],
         [-0.1867,  0.5974, -0.7799]],

        [[-0.8547, -0.0775, -0.5133],
         [-0.2145,  0.9531,  0.2133],
         [ 0.4727,  0.2924, -0.8313]],

        [[-0.2660,  0.8619,  0.4316],
         [-0.9161, -0.3655,  0.1652],
         [ 0.3001, -0.3515,  0.8868]],

        [[ 0.6904,  0.5019, -0.5209],
         [-0.4371,  0.8633,  0.2525],
         [ 0.5764,  0.0534,  0.8154]],

        [[-0.4372, -0.8915, -0.1191],
         [-0.2856,  0.2632, -0.9215],
         [ 0.8528, -0.3688, -0.3696]],

        [[ 0.4659,  0.7022, -0.5384],
         [ 0.1695,  0.5264,  0.8332],
         [ 0.8684, -0.4795,  0.1262]],

        [[-0.2217, -0.6468, -0.7297],
         [ 0.9226,  0.1031, -0.3716],
         [ 0.3156, -0.7557,  0.5739]],

        [[-0.9733, -0.1339,  0.1866],
         [ 0.0948, -0.9742, -0.2048],
         [ 0.2092, -0.1816,  0.9609]],

        [[-0.2019,  0.9766,  0.0736],
         [-0.3784, -0.1471,  0.9139],
         [ 0.9034,  0.1567,  0.3993]],

        [[ 0.3122,  0.2442, -0.9181],
         [ 0.6225,  0.6774,  0.3919],
         [ 0.7177, -0.6939,  0.0595]],

        [[-0.6140, -0.6281, -0.4779],
         [ 0.7845, -0.4196, -0.4565],
         [ 0.0862, -0.6553,  0.7504]],

        [[-0.8525,  0.3715,  0.3678],
         [-0.2567, -0.9104,  0.3245],
         [ 0.4553,  0.1822,  0.8715]],

        [[ 0.1279,  0.9249, -0.3580],
         [-0.0424,  0.3657,  0.9297],
         [ 0.9909, -0.1038,  0.0860]],

        [[ 0.1305, -0.2174, -0.9673],
         [ 0.9185,  0.3937,  0.0354],
         [ 0.3732, -0.8931,  0.2511]],

        [[-0.8646, -0.5018,  0.0267],
         [ 0.4707, -0.8274, -0.3063],
         [ 0.1758, -0.2522,  0.9516]],

        [[-0.4554,  0.8043,  0.3817],
         [-0.3966, -0.5671,  0.7219],
         [ 0.7971,  0.1774,  0.5772]],

        [[ 0.3839,  0.5596, -0.7345],
         [ 0.3787,  0.6300,  0.6780],
         [ 0.8422, -0.5384,  0.0299]],

        [[-0.3423, -0.6588, -0.6699],
         [ 0.9109, -0.0579, -0.4085],
         [ 0.2303, -0.7501,  0.6200]],

        [[-0.9885,  0.0321,  0.1475],
         [-0.0418, -0.9971, -0.0630],
         [ 0.1450, -0.0684,  0.9871]],

        [[-0.5158,  0.8466,  0.1310],
         [ 0.0590,  0.1877, -0.9805],
         [-0.8547, -0.4980, -0.1467]],

        [[-0.2505,  0.3370, -0.9076],
         [-0.8820, -0.4660,  0.0704],
         [-0.3992,  0.8181,  0.4140]],

        [[-0.2778,  0.1515,  0.9486],
         [-0.1057,  0.9767, -0.1869],
         [-0.9548, -0.1522, -0.2553]],

        [[ 0.6881,  0.7063,  0.1661],
         [ 0.0777,  0.1558, -0.9847],
         [-0.7214,  0.6905,  0.0524]],

        [[ 0.3823, -0.2206,  0.8973],
         [-0.2440, -0.9607, -0.1323],
         [ 0.8912, -0.1684, -0.4211]],

        [[ 0.5232,  0.6215,  0.5831],
         [-0.8507,  0.4215,  0.3141],
         [-0.0506, -0.6604,  0.7492]],

        [[-0.0445,  0.9967, -0.0677],
         [-0.8410, -0.0739, -0.5359],
         [-0.5392,  0.0331,  0.8416]],

        [[ 0.7902, -0.5468, -0.2768],
         [-0.5919, -0.7981, -0.1128],
         [-0.1593,  0.2529, -0.9543]],

        [[-0.4070, -0.7162, -0.5670],
         [-0.7906,  0.5871, -0.1740],
         [ 0.4575,  0.3774, -0.8051]],

        [[-0.3221,  0.9386,  0.1233],
         [ 0.8099,  0.3406, -0.4776],
         [-0.4903, -0.0539, -0.8699]],

        [[ 0.6649, -0.2304,  0.7105],
         [ 0.3991,  0.9137, -0.0771],
         [-0.6314,  0.3348,  0.6994]],

        [[-0.1578, -0.6089, -0.7774],
         [-0.5264, -0.6142,  0.5880],
         [-0.8355,  0.5020, -0.2236]],

        [[ 0.9653, -0.2548, -0.0574],
         [ 0.0103, -0.1827,  0.9831],
         [-0.2610, -0.9496, -0.1737]],

        [[-0.0468,  0.6846,  0.7274],
         [ 0.7944,  0.4669, -0.3884],
         [-0.6055,  0.5597, -0.5658]],

        [[ 0.8492, -0.5221, -0.0788],
         [ 0.0201,  0.1810, -0.9833],
         [ 0.5276,  0.8335,  0.1642]],

        [[ 0.8291, -0.5556, -0.0633],
         [ 0.2867,  0.5196, -0.8048],
         [ 0.4800,  0.6491,  0.5901]],

        [[-0.6534, -0.1525, -0.7415],
         [ 0.6651,  0.3521, -0.6585],
         [ 0.3615, -0.9234, -0.1286]],

        [[-0.1768,  0.9773, -0.1165],
         [-0.4400,  0.0274,  0.8976],
         [ 0.8804,  0.2100,  0.4251]],

        [[ 0.1008, -0.7674, -0.6332],
         [ 0.3219,  0.6273, -0.7091],
         [ 0.9414, -0.1323,  0.3103]],

        [[-0.8693,  0.1492,  0.4712],
         [-0.0189, -0.9627,  0.2700],
         [ 0.4939,  0.2258,  0.8397]],

        [[-0.3912, -0.3320, -0.8583],
         [-0.4360,  0.8882, -0.1448],
         [ 0.8104,  0.3176, -0.4922]],

        [[-0.9286, -0.1917, -0.3177],
         [ 0.3708, -0.4515, -0.8116],
         [ 0.0121, -0.8715,  0.4903]],

        [[ 0.1223,  0.0621, -0.9906],
         [-0.9195, -0.3687, -0.1366],
         [-0.3737,  0.9275,  0.0120]],

        [[ 0.1735,  0.8650,  0.4709],
         [ 0.8893, -0.3431,  0.3025],
         [ 0.4232,  0.3662, -0.8287]],

        [[ 0.4478, -0.6575, -0.6059],
         [ 0.7323,  0.6585, -0.1734],
         [ 0.5130, -0.3661,  0.7764]],

        [[-0.4742, -0.1710,  0.8636],
         [ 0.6836, -0.6896,  0.2389],
         [ 0.5547,  0.7037,  0.4440]],

        [[-0.1458,  0.5584, -0.8167],
         [-0.7840, -0.5686, -0.2489],
         [-0.6034,  0.6040,  0.5207]],

        [[-0.4864,  0.8186,  0.3055],
         [ 0.8697,  0.4870,  0.0800],
         [-0.0833,  0.3046, -0.9488]],

        [[-0.1426, -0.9897,  0.0134],
         [ 0.9831, -0.1432, -0.1144],
         [ 0.1152, -0.0031,  0.9933]]], grad_fn=<SelectBackward>)
tensor([3.4340, 2.2470, 4.4279, 3.5236, 2.0547, 2.0843, 2.7027, 2.6237, 2.0987,
        2.3012, 3.4104, 3.6644, 4.3308, 0.8168, 3.8307, 1.5115, 1.8466, 3.0259,
        4.2176, 1.5462, 2.2674, 3.4726, 1.7825, 1.6834, 2.6265, 3.9048, 1.5489,
        1.9715, 3.1787, 2.4000, 4.0031, 1.8528, 1.6227, 3.1633, 1.2165, 4.5735,
        2.9470, 3.7639, 2.7536, 0.8776, 3.2080, 4.5156, 2.1808, 1.4734, 1.0822,
        3.9158, 4.3418, 1.5516, 3.2293, 4.1915, 2.8080, 2.2360, 3.1000, 1.1137,
        2.6060, 4.0728, 2.9137, 1.7175], grad_fn=<NormBackward1>)

(davidcjuergens, Sep 10 '22 18:09)
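One way to sanity-check the norms printed above without relying on the library code is to recover the rotation angle from the trace: for any rotation matrix, tr(R) = 1 + 2 cos(theta), and arccos returns theta in [0, pi]. The sketch below uses plain PyTorch; `rotation_angle_from_trace` is a hypothetical helper, not a PyTorch3D function, and the three matrices are copied from the printout (rounded to 4 decimals).

```python
import math

import torch


def rotation_angle_from_trace(R: torch.Tensor) -> torch.Tensor:
    """Rotation angle in [0, pi] for a batch of (..., 3, 3) rotation matrices.

    Hypothetical helper (not part of PyTorch3D). Uses tr(R) = 1 + 2*cos(theta).
    """
    trace = R.diagonal(dim1=-2, dim2=-1).sum(dim=-1)
    # Clamp guards against values slightly outside [-1, 1] caused by rounding.
    cos_theta = ((trace - 1.0) / 2.0).clamp(-1.0, 1.0)
    return torch.acos(cos_theta)


# First three matrices from the printout above (4-decimal values).
R = torch.tensor([
    [[0.1494, 0.4837, -0.8624], [0.7733, -0.6007, -0.2030], [-0.6162, -0.6365, -0.4638]],
    [[-0.0714, 0.4292, 0.9004], [0.9666, 0.2525, -0.0437], [-0.2462, 0.8672, -0.4328]],
    [[0.5917, 0.2904, -0.7520], [0.7764, 0.0455, 0.6285], [0.2167, -0.9558, -0.1986]],
])
# Corresponding norms reported by matrix_to_axis_angle above.
reported_norms = torch.tensor([3.4340, 2.2470, 4.4279])

theta = rotation_angle_from_trace(R)
# Distance of each reported norm from theta and from 2*pi - theta.
d_same = (reported_norms - theta).abs()
d_wrap = (reported_norms - (2 * math.pi - theta)).abs()
print(theta, d_same, d_wrap)
```

With these values, theta comes out at roughly 2.85, 2.25 and 1.86 rad. The first and third reported norms (3.4340 and 4.4279) agree with 2*pi - theta to within about 1e-3, and the second (2.2470) agrees with theta itself. That pattern is consistent with the conversion returning the same rotation expressed the long way round for some inputs.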

I've written a function that I think properly normalizes the magnitudes to [0, pi]. Here's my attempt; please let me know if you see any issues with this solution.

import numpy as np
import torch


def th_min_angle(start, end, radians=False):
    """
    Finds the angle you would add to <start> in order to get to <end>
    on the shortest path. The result lies in [-180, 180) degrees,
    or in [-pi, pi) if radians=True.
    """
    a, b, c = (np.pi, 2 * np.pi, 3 * np.pi) if radians else (180, 360, 540)
    shortest_angle = ((((end - start) % b) + c) % b) - a
    return shortest_angle


def normalize_ax_ang(V):
    """
    Gets the axis-angle representation normalized so that its norm is in [0, pi].

    If the original AA vector has norm greater than pi, flips the direction of the
    vector and rescales it to the magnitude that represents the same rotation
    going around the other side of the circle.
    """
    V_norm = torch.norm(V, p=2, dim=-1).detach()

    # normalize AA vectors to unit norm (an all-zero vector gives NaN here)
    V_magnitude_1 = V / torch.norm(V, p=2, dim=-1)[..., None]

    # calculate the "good norms": the signed shortest angle from 0 to the original angle
    good_norms = (
        th_min_angle(torch.zeros_like(V_norm), V_norm * 180 / np.pi, radians=False)
        * np.pi / 180
    )

    # rescale the unit AA vectors by the good norms; e.g. for pi < norm(V) < 2*pi
    # the good norm is negative, which flips the direction of V
    V_w_good_norms = V_magnitude_1 * good_norms[..., None]

    return V_w_good_norms

(davidcjuergens, Sep 10 '22 19:09)
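If the goal is only to map each vector to an equivalent one with norm in [0, pi], a shorter route may be possible. A rotation by theta about a unit axis u equals a rotation by theta - 2*pi about u, which is a rotation by 2*pi - theta about -u. So it should be enough to wrap the norm into [-pi, pi) with `torch.remainder` and rescale the unit axis. This also avoids the degree round-trip and the NaN that `normalize_ax_ang` produces for an all-zero vector. This is a sketch; `wrap_axis_angle` and its `eps` parameter are hypothetical names.

```python
import math

import torch


def wrap_axis_angle(v: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Return an equivalent axis-angle vector whose norm lies in [0, pi].

    Hypothetical helper, not a PyTorch3D API. v has shape (..., 3).
    """
    angle = torch.linalg.norm(v, dim=-1, keepdim=True)
    # Map the angle into [-pi, pi); a negative result flips the axis direction.
    wrapped = torch.remainder(angle + math.pi, 2 * math.pi) - math.pi
    # Unit axis; clamping keeps zero vectors at zero instead of producing NaN.
    axis = v / angle.clamp_min(eps)
    return axis * wrapped


v = torch.tensor([
    [0.0, 0.0, 3.4340],  # norm > pi: becomes about -2.849 along z
    [0.0, 1.0, 0.0],     # norm < pi: unchanged
    [0.0, 0.0, 0.0],     # zero rotation: stays zero
])
out = wrap_axis_angle(v)
print(out)
print(torch.linalg.norm(out, dim=-1))
```

Unlike `normalize_ax_ang`, this version does not detach the norm, so gradients flow through both the axis and the wrapped angle.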

Your input test_R doesn't look like rotation matrices.

(bottler, Sep 10 '22 23:09)
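Whether a batch really contains rotation matrices can be checked numerically: a proper rotation matrix satisfies R R^T = I and det(R) = +1. A minimal sketch in plain PyTorch follows; `is_rotation_matrix` and its `atol` tolerance are hypothetical, with a loose default because the printed values are rounded to 4 decimals.

```python
import torch


def is_rotation_matrix(R: torch.Tensor, atol: float = 1e-3) -> torch.Tensor:
    """Per-matrix check that R (shape (..., 3, 3)) is orthonormal with det = +1.

    Hypothetical helper; atol is loose because the printed values have 4 decimals.
    """
    eye = torch.eye(3, dtype=R.dtype, device=R.device)
    # Orthonormality: R @ R^T should be the identity.
    err = ((R @ R.transpose(-1, -2)) - eye).abs().amax(dim=(-2, -1))
    orthonormal = err <= atol
    # Proper rotation (not a reflection): the determinant should be +1.
    proper = (torch.linalg.det(R) - 1.0).abs() <= atol
    return orthonormal & proper


# First matrix from the printout above, plus a reflection for contrast.
first = torch.tensor([[0.1494, 0.4837, -0.8624],
                      [0.7733, -0.6007, -0.2030],
                      [-0.6162, -0.6365, -0.4638]])
reflection = torch.diag(torch.tensor([1.0, 1.0, -1.0]))
print(is_rotation_matrix(torch.stack([first, reflection])))
```

On the first matrix of the printout the check passes within atol=1e-3; running it over the whole `test_R` tensor would flag any entries that are not orthonormal or are reflections.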