AutoDiff not working for SE3 optimization with batched input 3D points

Open cxlcl opened this issue 1 year ago • 13 comments

Hi, nice work and thanks for making the repository public!

As a test, I was trying to optimize a matrix variable T=[A|t], given N corresponding 3D points p_src and p_tgt, with p_tgt = A@p_src + t. Since a single matrix variable applies to all points, the transformation of p_src and the cost function are built in one batch; that is, a single cost function, built with AutoDiffCostFunction, covers all the points. The optimization fails with an error in the Jacobian calculation. However, when each observation is added to the objective separately (i.e., one cost function per observation), the error goes away.

The error message for the batched version:

File "/root/theseus/theseus/theseus_layer.py", line 88, in forward
    vars, info = _forward(
  File "/root/theseus/theseus/theseus_layer.py", line 148, in _forward
    info = optimizer.optimize(**optimizer_kwargs)
  File "/root/theseus/theseus/optimizer/optimizer.py", line 43, in optimize
    return self._optimize_impl(**kwargs)
  File "/root/theseus/theseus/optimizer/nonlinear/nonlinear_optimizer.py", line 346, in _optimize_impl
    info = self._init_info(
  File "/root/theseus/theseus/optimizer/nonlinear/nonlinear_optimizer.py", line 121, in _init_info
    last_err = self.objective.error_squared_norm() / 2
  File "/root/theseus/theseus/core/objective.py", line 388, in error_squared_norm
    self.error(input_tensors=input_tensors, also_update=also_update) ** 2
  File "/root/theseus/theseus/core/objective.py", line 375, in error
    [cf.weighted_error() for cf in self._get_iterator()], dim=1
  File "/root/theseus/theseus/core/objective.py", line 518, in _get_iterator
    self.update_vectorization_if_needed()
  File "/root/theseus/theseus/core/objective.py", line 510, in update_vectorization_if_needed
    self._vectorization_run()
  File "/root/theseus/theseus/core/vectorizer.py", line 334, in _vectorize
    self._handle_singleton_wrapper(schema, cost_fn_wrappers)
  File "/root/theseus/theseus/core/vectorizer.py", line 304, in _handle_singleton_wrapper
    ) = wrapper.cost_fn.weighted_jacobians_error()
  File "/root/theseus/theseus/core/cost_function.py", line 63, in weighted_jacobians_error
    jacobian, err = self.jacobians()
  File "/root/theseus/theseus/core/cost_function.py", line 218, in jacobians
    jacobians_full = [jac[aux_idx, :, aux_idx, :] for jac in jacobians_raw]
  File "/root/theseus/theseus/core/cost_function.py", line 218, in <listcomp>
    jacobians_full = [jac[aux_idx, :, aux_idx, :] for jac in jacobians_raw]
IndexError: index 1 is out of bounds for dimension 1 with size 1

The code to reproduce the error:

#test function
def test_opt_mSE3(Npts=16,):
    import torch
    import theseus as th
    from theseus.geometry.point_types import Point3 
    from theseus.geometry.se3 import SE3
    import pytorch3d.transforms as py3d_trans 
    init_T = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]).unsqueeze(0).float()
    cam_rot = py3d_trans.random_quaternions(1).reshape(1,4)
    cam_tr =  torch.tensor([1.,1.,1.]).reshape(1,3)
    cam_T = torch.cat((cam_tr, cam_rot), dim=1)
    se3_gt= SE3(cam_T, name='pose_gt')
    pose= SE3(tensor=init_T, name='pose') 
    p_src = torch.rand((Npts,3))
    p_tgt = se3_gt.transform_from(p_src).tensor 
    #build cost function
    def error_fn(optim_vars, aux_vars):
        p_src, p_tgt = aux_vars
        pose = optim_vars[0]
        p_trans = pose.transform_from(p_src)
        diff = (p_trans - p_tgt).tensor
        return diff
    cost_function = th.AutoDiffCostFunction(
        [pose],
        error_fn,
        3,
        aux_vars=[Point3(p_src, name='p_src'), 
                  Point3(p_tgt, name='p_tgt')],
    )

    #add cost function to objective
    objective = th.Objective()
    objective.add(cost_function)
    theseus_inputs = {
        'p_src':Point3(p_src, name='p_src'), 
        'p_tgt':Point3(p_tgt, name='p_tgt'), 
        'pose':pose,
        }
    objective.update(theseus_inputs)
    print(f"initial error: {objective.error_squared_norm().sum()}")

    #create optimizer and begin optimization
    optimizer = th.GaussNewton(
        objective,
        max_iterations=100,
        step_size=0.1,)
    th_layer = th.TheseusLayer(optimizer)
    res, _ = th_layer.forward(theseus_inputs,)
    print(res)
if __name__ == "__main__":
    test_opt_mSE3()

The code without batching (this works):

def test_opt_mSE3_singles(Npts=100, use_auto_diff=True):
    #add one point at a time 
    import torch
    import theseus as th
    from theseus.geometry.point_types import Point3 
    from theseus.geometry.se3 import SE3
    import pytorch3d.transforms as py3d_trans
    init_T = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]).unsqueeze(0).float()
    cam_rot = py3d_trans.random_quaternions(1).reshape(1,4)
    cam_tr =  torch.tensor([1.,1.,1.]).reshape(1,3) #torch.rand((1, 3), ) * 2 + torch.tensor([-1, -1, -1], )
    cam_T = torch.cat((cam_tr, cam_rot), dim=1) # Bx7
    se3_gt= SE3(cam_T, name='pose_gt') 
    pose= SE3(tensor=init_T, name='pose') 

    objective = th.Objective()
    if use_auto_diff:
        def error_fn(optim_vars, aux_vars):
            p_src, p_tgt = aux_vars
            pose = optim_vars[0]
            p_trans = pose.transform_from(p_src)
            diff = (p_trans - p_tgt).tensor
            return diff

        theseus_inputs = {}
        for i in range(Npts): #one cost function for one point
            p_src = torch.rand((1,3))
            p_tgt = se3_gt.transform_from(p_src).tensor
            cost_function = th.AutoDiffCostFunction(
                [pose],
                error_fn,
                3,
                aux_vars=[Point3(p_src, name=f'p_src_{i}'), 
                          Point3(p_tgt, name=f'p_tgt_{i}')],
            )
            theseus_inputs.update({f'p_src_{i}': p_src, f'p_tgt_{i}': p_tgt})
            objective.add(cost_function)
    else:
        raise NotImplementedError()

    optimizer = th.GaussNewton(
        objective,
        max_iterations=20,
        step_size=1,
    ) 

    th_layer = th.TheseusLayer(optimizer)
    print('\nth_layer')
    print(th_layer)

    print('opt..')
    res, _ = th_layer.forward(theseus_inputs,)

cxlcl avatar Aug 30 '22 14:08 cxlcl

Just figured it out. It turns out that the Objective class treats the problem as having multiple batch elements (objective.batch_size=Npts in this case), while the optimized variable has a batch size of one.
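
For readers hitting the same IndexError, here is a minimal sketch in plain PyTorch of what this mismatch looks like. The shapes are assumptions inferred from the traceback (a dense Jacobian laid out as error batch x error dim x variable batch x dof, indexed along both batch dimensions), and `take_batch_diagonal` is a hypothetical helper, not Theseus code.

```python
import torch

# Hypothetical sizes: 16 points treated as 16 batch elements, 3-dim error, 6-dof pose.
batch_size, err_dim, dof = 16, 3, 6
aux_idx = torch.arange(batch_size)


def take_batch_diagonal(jac: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    # Same indexing pattern as the line shown in the traceback.
    return jac[idx, :, idx, :]


# The error has batch size 16, but the pose variable has batch size 1.
jac_mismatch = torch.zeros(batch_size, err_dim, 1, dof)
try:
    take_batch_diagonal(jac_mismatch, aux_idx)
    mismatch_raises = False
except IndexError:
    mismatch_raises = True

# When both batch sizes agree, the same indexing succeeds.
jac_ok = torch.zeros(batch_size, err_dim, batch_size, dof)
diag = take_batch_diagonal(jac_ok, aux_idx)

print(mismatch_raises, tuple(diag.shape))
```

This is only meant to show why a batch-size-1 variable cannot be indexed with Npts batch indices.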

The trick is to view all the points in the input as a single-batch vector. Below is the modified code for optimizing a rigid transformation T, s.t. p_tgt = T p_src:

def test_opt_mSE3(Npts=20, use_auto_diff=True):
    import theseus as th
    import torch
    from theseus.geometry.se3 import SE3
    import pytorch3d.transforms as py3d_trans 

    #create data
    init_T = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]).unsqueeze(0).float()
    cam_rot = py3d_trans.random_quaternions(1).reshape(1,4)
    cam_tr =  torch.tensor([1.,1.,1.]).reshape(1,3)
    cam_T = torch.cat((cam_tr, cam_rot), dim=1)
    se3_gt= SE3(cam_T, name='pose_gt')

    #pose_6d
    pose_t, pose_axis_angle = torch.zeros(1,3), torch.zeros(1,3)
    pose = torch.cat((pose_t, pose_axis_angle), dim=1)#6D pose

    p_src = torch.rand((Npts,3))
    p_tgt = se3_gt.transform_from(p_src).tensor
    p_src_vec = p_src.reshape(1,-1)
    p_tgt_vec = p_tgt.reshape(1,-1)

    #build cost function
    if use_auto_diff:
        def error_fn(optim_vars, aux_vars, ):
            p_src, p_tgt = aux_vars
            pose_6d= optim_vars[0]
            rot = py3d_trans.axis_angle_to_matrix(pose_6d[:, 3:])  # 1,3,3 (broadcast over the N points)
            t = pose_6d[:,:3]
            p_trans = rot @ p_src.tensor.reshape(-1,3,1) + t.unsqueeze(-1)
            diff = p_trans.reshape(1,-1) - p_tgt.tensor
            return diff

        cost_function = th.AutoDiffCostFunction(
            [th.Vector(tensor=pose, name='pose', )], 
            error_fn,
            3*Npts, # view all points in the batch as a single vector
            aux_vars=[th.Vector(tensor=p_src_vec, name='p_src'), 
                      th.Vector(tensor=p_tgt_vec, name='p_tgt')],
        )
    else:
        raise NotImplementedError()
        

    #add cost function to objective
    objective = th.Objective()
    objective.add(cost_function)
    theseus_inputs = {
        'p_src':th.Vector(tensor=p_src_vec, name='p_src'),
        'p_tgt':th.Vector(tensor=p_tgt_vec, name='p_tgt'),
        }
    objective.update(theseus_inputs)
    
    print(f"initial error: {objective.error_squared_norm().sum()}")

    #create optimizer and begin optimization
    optimizer = th.GaussNewton(
        objective,
        max_iterations=10,
        step_size=1,
    )
    th_layer = th.TheseusLayer(optimizer)
    res, _ = th_layer.forward(theseus_inputs,)
    print(f"final error: {objective.error_squared_norm().sum()}")
    
    print('\n===est. pose===')
    rot = py3d_trans.axis_angle_to_matrix(res['pose'][:, 3:]).squeeze(0)
    T_est = torch.cat((rot, res['pose'][:, :3].squeeze(0).reshape(3,1)), dim=1)
    print(T_est)
    print('\n===gt pose===')
    print(se3_gt.tensor)


if __name__ == "__main__":
    test_opt_mSE3()

cxlcl avatar Aug 30 '22 21:08 cxlcl

Ah glad you worked out a fix @cxlcl and thanks for posting it so others can benefit!

mhmukadam avatar Aug 30 '22 21:08 mhmukadam

Hi @mhmukadam, could we have a batched version of transform_from() on points (shape=(batch, N, 3)) for Lie group elements in Theseus? I think the example provided by @cxlcl doesn't benefit from the Lie group differentiability that Theseus provides.

MickShen7558 avatar Sep 12 '22 05:09 MickShen7558

Hi @MickShen7558, transform_from() already supports batched points (cc @fantaosha). Is the use case different from the current implementation?

mhmukadam avatar Sep 12 '22 14:09 mhmukadam

Hi @mhmukadam, I think the current implementation only supports transformation where B Lie group elements operate on B points. What if I want B Lie group elements to transform B point clouds where each point cloud contains N points?
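
To make the request concrete, this is the operation I mean, sketched on plain tensors with no Theseus types. It assumes each pose is stored as a (3, 4) matrix [R | t], like the SE3 tensors earlier in this thread; `transform_point_clouds` is a hypothetical name, not an existing Theseus function.

```python
import torch


def transform_point_clouds(T: torch.Tensor, points: torch.Tensor) -> torch.Tensor:
    """Apply B rigid transforms to B point clouds.

    T:      (B, 3, 4) tensor, each entry laid out as [R | t].
    points: (B, N, 3) tensor, N points per cloud.
    Returns a (B, N, 3) tensor holding R_b @ p + t_b for every point p in cloud b.
    """
    R, t = T[:, :, :3], T[:, :, 3]
    rotated = torch.einsum("bij,bnj->bni", R, points)
    return rotated + t.unsqueeze(1)  # broadcast each translation over its N points


B, N = 2, 5
# Identity rotations with a translation of ones, so every point shifts by +1.
T = torch.cat([torch.eye(3).expand(B, 3, 3), torch.ones(B, 3, 1)], dim=2)
points = torch.rand(B, N, 3)
out = transform_point_clouds(T, points)
print(out.shape)
```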

MickShen7558 avatar Sep 12 '22 17:09 MickShen7558

Hi @MickShen7558. We were discussing this a bit and @fantaosha agreed to add support for this pretty soon. The case where this is supported for torch.Tensor type inputs should be easy to add. Would that work for your purposes? Adding this feature for Point3 input types is a bit trickier, since this class requires data to have exactly 2 dimensions (batch and size 3 data).

luisenp avatar Sep 12 '22 17:09 luisenp

@mhmukadam @luisenp This feature (the ability to apply Lie group transformations to pointclouds, and differentiate through it) is super critical for my application. Is there any progress on this issue?

richardrl avatar May 08 '23 00:05 richardrl

Hi @richardrl. As a matter of fact, we are about to merge this functionality in #513, as part of our new labs.lie package. Below is a simple example that uses this functionality in an optimization problem (you would need to check out the branch linked above).

Let me know if you have any questions.

import theseus as th
import torch
import theseus.labs.lie.functional as lieF

r1 = lieF.SO3.rand(1)  # this is going to be our target rotation
p = torch.randn(1, 10, 3)  # batch_size x num_points x dim
p_r1 = lieF.SO3.transform_from(r1, p)  # batch_size x num_points x dim

# Use in Theseus to recover r1 via optimization, starting from a random rotation
r_opt = th.SO3(tensor=lieF.SO3.rand(1), name="r")  # start with random R
p_cloud = th.Variable(p, "p_cloud")
p_cloud_target = th.Variable(p_r1, "p_cloud_target")

# The error to optimize is a transform_from between R_opt and the original point cloud
def err_fn(optim_vars, aux_vars):
    r_ = optim_vars[0]
    p_, p_target = aux_vars
    p_r_ = lieF.SO3.transform_from(r_.tensor, p_.tensor)
    return (
        (p_r_ - p_target.tensor).abs().sum(dim=1)
    )  # batch_size x dim (sums point cloud dim)


obj = th.Objective()
obj.add(
    th.AutoDiffCostFunction(
        (r_opt,), err_fn, 3, aux_vars=(p_cloud, p_cloud_target), autograd_mode="dense"
    )
)
layer = th.TheseusLayer(th.LevenbergMarquardt(obj))
layer.forward(optimizer_kwargs={"verbose": True, "damping": 0.01})

print((r_opt.tensor - r1).abs().max())  # Check that R now matches original

cc @MickShen7558 @cxlcl @fantaosha

luisenp avatar May 08 '23 13:05 luisenp

Thanks for your example. However, when I tried it, I got the following error:

AttributeError: 'Point3' object has no attribute 'clone'

Here is the code:

import theseus as th
import torch

r1 = th.SO3.rand(1)   # this is going to be our target Rotation
p = torch.randn(10, 3)  # batch_size x num_points x dim
p_r1 = th.SO3.rotate(r1, p)  # # batch_size x num_points x dim

# Use in Theseus to recover r1 via optimization, starting from a random rotation
r_opt = th.SO3(tensor=th.SO3.rand(1), name="r")  # start with random R
p_cloud = th.Variable(p, "p_cloud")
p_cloud_target = th.Variable(p_r1, "p_cloud_target")

# The error to optimize is a transform_from between R_opt and the original point cloud
def err_fn(optim_vars, aux_vars):
    r_ = optim_vars[0]
    p_, p_target = aux_vars
    p_r_ = th.SO3.transform_from(r_.tensor, p_.tensor)
    return (
        (p_r_ - p_target.tensor).abs().sum(dim=1)
    )  # batch_size x dim (sums point cloud dim)


obj = th.Objective()
obj.add(
    th.AutoDiffCostFunction(
        (r_opt,), err_fn, 3, aux_vars=(p_cloud, p_cloud_target), autograd_mode="dense"
    )
)
layer = th.TheseusLayer(th.LevenbergMarquardt(obj))
layer.forward(optimizer_kwargs={"verbose": True, "damping": 0.01})

print((r_opt.tensor - r1).abs().max())  # Check that R now matches original

I also have a question. In another of your examples, se2_inverse.py, there is:

# Activate the Lie group update
with th.set_lie_tangent_enabled(use_lie_tangent):
    optim.step()

Do I need this to ensure that r_opt stays in SO3?

FanWu-fan avatar Aug 07 '23 19:08 FanWu-fan

Hi @FanWu-fan, can you post a full stack trace for the error?

Regarding set_lie_tangent_enabled: inside that context, optim.step() does not simply add the gradient tensor to the Lie group tensor; it first projects the gradient to the tangent space and then applies a retract operation.
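
As a rough illustration of that idea for SO3, here is a sketch in plain PyTorch. It is not the Theseus implementation: the right-perturbation convention and the helper names `hat`, `project_to_tangent`, and `retract` are assumptions chosen for this example.

```python
import torch


def hat(w: torch.Tensor) -> torch.Tensor:
    """Map a 3-vector to its 3x3 skew-symmetric matrix."""
    wx, wy, wz = w
    zero = torch.zeros((), dtype=w.dtype)
    return torch.stack([
        torch.stack([zero, -wz, wy]),
        torch.stack([wz, zero, -wx]),
        torch.stack([-wy, wx, zero]),
    ])


def project_to_tangent(R: torch.Tensor, G: torch.Tensor) -> torch.Tensor:
    """Gradient w.r.t. w of loss(R @ exp(hat(w))) at w = 0, given G = dloss/dR."""
    A = R.T @ G
    return torch.stack([A[2, 1] - A[1, 2], A[0, 2] - A[2, 0], A[1, 0] - A[0, 1]])


def retract(R: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Move along the tangent direction w; the result is still a rotation."""
    return R @ torch.linalg.matrix_exp(hat(w))


torch.manual_seed(0)
R_gt = torch.linalg.matrix_exp(hat(torch.randn(3)))  # random target rotation
p = torch.randn(10, 3)
q = p @ R_gt.T  # rotated points

R = torch.eye(3, requires_grad=True)
loss_before = ((p @ R.T - q) ** 2).sum()
loss_before.backward()

with torch.no_grad():
    w = project_to_tangent(R, R.grad)
    R_new = retract(R, -1e-3 * w)  # small step against the tangent gradient
    loss_after = ((p @ R_new.T - q) ** 2).sum()

print(loss_before.item(), loss_after.item(), torch.det(R_new).item())
```

A plain Euclidean step `R - lr * R.grad` would generally leave the set of rotation matrices, which is what the retract step avoids.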

luisenp avatar Aug 07 '23 19:08 luisenp

@luisenp, Here is the code and error:

import theseus as th
import torch

r1 = th.SO3.rand(1)   # this is going to be our target Rotation
p = torch.randn(10, 3)  # batch_size x num_points x dim
p_r1 = th.SO3.rotate(r1, p)  # # batch_size x num_points x dim

# Use in Theseus to recover r1 via optimization, starting from a random rotation
r_opt = th.SO3(tensor=th.SO3.rand(1), name="r")  # start with random R
p_cloud = th.Variable(p, "p_cloud")
p_cloud_target = th.Variable(p_r1, "p_cloud_target")

# The error to optimize is a transform_from between R_opt and the original point cloud
def err_fn(optim_vars, aux_vars):
    r_ = optim_vars[0]
    p_, p_target = aux_vars
    p_r_ = th.SO3.transform_from(r_.tensor, p_.tensor)
    return (
        (p_r_ - p_target.tensor).abs().sum(dim=1)
    )  # batch_size x dim (sums point cloud dim)


obj = th.Objective()
obj.add(
    th.AutoDiffCostFunction(
        (r_opt,), err_fn, 3, aux_vars=(p_cloud, p_cloud_target), autograd_mode="dense"
    )
)
layer = th.TheseusLayer(th.LevenbergMarquardt(obj))
layer.forward(optimizer_kwargs={"verbose": True, "damping": 0.01})

print((r_opt.tensor - r1).abs().max())  # Check that R now matches original

# Traceback (most recent call last):
#   File "/home/fan/Rotation_py/try2.py", line 29, in <module>
#     layer = th.TheseusLayer(th.LevenbergMarquardt(obj))
#   File "/home/fan/miniconda3/envs/t11/lib/python3.10/site-packages/theseus/theseus_layer.py", line 39, in __init__
#     Vectorize(self.objective, empty_cuda_cache=empty_cuda_cache)
#   File "/home/fan/miniconda3/envs/t11/lib/python3.10/site-packages/theseus/core/vectorizer.py", line 136, in __init__
#     vectorized_cost_fn = base_cost_fn.copy(keep_variable_names=False)
#   File "/home/fan/miniconda3/envs/t11/lib/python3.10/site-packages/theseus/core/cost_function.py", line 135, in copy
#     super().copy(new_name=new_name, keep_variable_names=keep_variable_names),
#   File "/home/fan/miniconda3/envs/t11/lib/python3.10/site-packages/theseus/core/theseus_function.py", line 95, in copy
#     new_fn = self._copy_impl(new_name=new_name)
#   File "/home/fan/miniconda3/envs/t11/lib/python3.10/site-packages/theseus/core/cost_function.py", line 403, in _copy_impl
#     aux_vars=[v.copy() for v in self.aux_vars],
#   File "/home/fan/miniconda3/envs/t11/lib/python3.10/site-packages/theseus/core/cost_function.py", line 403, in <listcomp>
#     aux_vars=[v.copy() for v in self.aux_vars],
#   File "/home/fan/miniconda3/envs/t11/lib/python3.10/site-packages/theseus/core/variable.py", line 31, in copy
#     return Variable(self.tensor.clone(), name=new_name)
# AttributeError: 'Point3' object has no attribute 'clone'

According to your reply and the example se2_inverse.py, I need to explicitly call optim.step(). What if I want to use th.TheseusLayer instead? Here is the code to compute the rotation matrix R, where P2 = R@P1 and P2 is a vector observed by the sensor.

import torch
import theseus as th
from theseus.geometry.point_types import Point3 
from theseus.geometry.se3 import SE3
import pytorch3d.transforms as py3d_trans


def test_opt_mSO3(so3_gt,r_opt,Npts=320,num_iters=1000,use_lie_tangent=False):
 
    print("so3-gt shape: ",so3_gt.shape) # 1x3x3
    p_src = torch.rand((Npts,3))
    print("p_src shape: ", p_src.shape) # Nptsx3
    p_tgt = so3_gt.rotate(p_src)
    p_tgt_noise = torch.randn(Npts,3)
    p_tgt.tensor = p_tgt.tensor + p_tgt_noise
    print("p_tgt shape: ", p_tgt.shape) # Nptsx3

    r_opt.tensor.requires_grad=True
    #build cost function
    optim = torch.optim.Adam([r_opt.tensor], lr=1)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optim, milestones=[1000, 3000], gamma=0.1
    )
    for i in range(num_iters):
        optim.zero_grad()
        cf = th.Difference(r_opt.rotate(p_src), p_tgt, th.ScaleCostWeight(1.0))
        loss = cf.error().norm()
        if i % 100 == 0:
            print(
                "iter {}: loss is {:.10f}, determinant is {:.10f}".format(
                    i, loss.item(), torch.det(r_opt.tensor[0])
                )
            )
        loss.backward()

        # Activate the Lie group update
        with th.set_lie_tangent_enabled(use_lie_tangent):
            optim.step()

        scheduler.step()
    cf = th.Difference(r_opt.rotate(p_src), p_tgt, th.ScaleCostWeight(1.0))
    loss = cf.error().norm()
    print(
    "iter {}: loss is {:.10f}, determinant is {:.10f}".format(
        num_iters, loss.item(), torch.det(r_opt.tensor[0])
    ))
    print("r gt: ",so3_gt, "r opt: ", r_opt)

if __name__ == "__main__":
    so3_gt= th.SO3.rand(1)
    r_opt= th.SO3.rand(1)
    print("=========================================================")
    print("PyTorch Optimization on the Euclidean Space")
    test_opt_mSO3(so3_gt,r_opt,use_lie_tangent=False)
    print("\n")

    print("=========================================================")
    print("PyTorch Optimization on the Lie Group Tangent Space (Ours)")
    test_opt_mSO3(so3_gt,r_opt,use_lie_tangent=True)
    print("\n")

FanWu-fan avatar Aug 08 '23 07:08 FanWu-fan

Hi @FanWu-fan. The problem is that you are importing SO3 from th instead of lieF; these are different objects. th.SO3 is a Variable subclass, which wraps a tensor so that it can be used as a variable inside theseus optimizers. On the other hand, lieF.SO3 is a namespace of tensor-tensor operations for SO3 data.

I'm not sure that I understand your second question, but, in general, you don't need to call optim.step() to use a TheseusLayer. The code in se2_inverse.py is trying to illustrate how to minimize a torch loss with Lie group parameters, outside the context of NLLS optimization. It is also an old example, superseded by our torchlie package; we should probably remove it at this point.

A good rule of thumb is:

  1. If you need to model an optimization/aux variable in one of our optimizers, use a subclass of th.LieGroup.
  2. If you need to do operations on regular tensors representing Lie group data, use torchlie (e.g., lieF).
  3. If you have a torch optimizer with Lie group parameters, also use torchlie.

Combinations of all three are possible and typical. The example above combines 1 and 2: it shows an SO3 variable that uses torchlie ops inside an autodiff cost function. se2_inverse.py is an example of case 3, although this can be done more easily with torchlie now.

luisenp avatar Aug 08 '23 12:08 luisenp

I'm late to the party, but looking at the solution that @cxlcl suggested, I don't think you're optimizing the same objective, though I may be wrong. By stacking all your 3D points into a single vector with a single objective component, you're moving the sum over your points inside the 2-norm of the NLS objective (rather than optimizing a sum of 2-norms over all points). I don't know whether Theseus has a way of vectorizing an objective such as 3D point-cloud fitting without adding each objective component separately (each with a 3-dim error function).
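
A quick way to check this numerically, as a sketch in plain PyTorch with made-up residuals (not Theseus code): for an unweighted squared-error cost, the squared norm of the stacked residual equals the sum of the per-point squared norms, so the two formulations agree there. They differ once the norm is not squared, which is the situation to watch for if something like a robust loss is applied per cost function.

```python
import torch

torch.manual_seed(0)
residuals = torch.randn(20, 3)  # made-up per-point residuals, one row per 3D point

# Squared norms: one stacked 60-dim residual vs. twenty 3-dim residuals.
stacked_sq = residuals.reshape(-1).norm() ** 2
per_point_sq = (residuals.norm(dim=1) ** 2).sum()

# Unsquared norms: moving the sum inside or outside the norm changes the value.
stacked_norm = residuals.reshape(-1).norm()
per_point_norm_sum = residuals.norm(dim=1).sum()

print(torch.allclose(stacked_sq, per_point_sq, rtol=1e-4))
print(stacked_norm.item(), per_point_norm_sum.item())
```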

byfron avatar Apr 05 '24 11:04 byfron