theseus
theseus copied to clipboard
AutoDiff not working for SE3 optimization with batched input 3D points
Hi, nice work and thanks for making the repository public!
As a test, I was trying to optimize a matrix variable T=[A|t], given N corresponding 3D points: p_src, p_tgt, with p_tgt = A@p_src + t
. As there is only one matrix variable corresponding to all points, the transformation from p_src and building the cost function are done in a batch. That is, there is only one cost function built for all the points. AutoDiffCostFunction
is used to build the cost function. The optimization won't work due to an error in Jacobian calculation. However, when each of the observation is added separately to the objective (ie. the one cost function for one observation), the error goes away.
The error message for the batched version:
File "/root/theseus/theseus/theseus_layer.py", line 88, in forward
vars, info = _forward(
File "/root/theseus/theseus/theseus_layer.py", line 148, in _forward
info = optimizer.optimize(**optimizer_kwargs)
File "/root/theseus/theseus/optimizer/optimizer.py", line 43, in optimize
return self._optimize_impl(**kwargs)
File "/root/theseus/theseus/optimizer/nonlinear/nonlinear_optimizer.py", line 346, in _optimize_impl
info = self._init_info(
File "/root/theseus/theseus/optimizer/nonlinear/nonlinear_optimizer.py", line 121, in _init_info
last_err = self.objective.error_squared_norm() / 2
File "/root/theseus/theseus/core/objective.py", line 388, in error_squared_norm
self.error(input_tensors=input_tensors, also_update=also_update) ** 2
File "/root/theseus/theseus/core/objective.py", line 375, in error
[cf.weighted_error() for cf in self._get_iterator()], dim=1
File "/root/theseus/theseus/core/objective.py", line 518, in _get_iterator
self.update_vectorization_if_needed()
File "/root/theseus/theseus/core/objective.py", line 510, in update_vectorization_if_needed
self._vectorization_run()
File "/root/theseus/theseus/core/vectorizer.py", line 334, in _vectorize
self._handle_singleton_wrapper(schema, cost_fn_wrappers)
File "/root/theseus/theseus/core/vectorizer.py", line 304, in _handle_singleton_wrapper
) = wrapper.cost_fn.weighted_jacobians_error()
File "/root/theseus/theseus/core/cost_function.py", line 63, in weighted_jacobians_error
jacobian, err = self.jacobians()
File "/root/theseus/theseus/core/cost_function.py", line 218, in jacobians
jacobians_full = [jac[aux_idx, :, aux_idx, :] for jac in jacobians_raw]
File "/root/theseus/theseus/core/cost_function.py", line 218, in <listcomp>
jacobians_full = [jac[aux_idx, :, aux_idx, :] for jac in jacobians_raw]
IndexError: index 1 is out of bounds for dimension 1 with size 1
The code to reproduce the error:
#test function
def test_opt_mSE3(Npts=16,):
import torch
import theseus as th
from theseus.geometry.point_types import Point3
from theseus.geometry.se3 import SE3
import pytorch3d.transforms as py3d_trans
init_T = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]).unsqueeze(0).float()
cam_rot = py3d_trans.random_quaternions(1).reshape(1,4)
cam_tr = torch.tensor([1.,1.,1.]).reshape(1,3)
cam_T = torch.cat((cam_tr, cam_rot), dim=1)
se3_gt= SE3(cam_T, name='pose_gt')
pose= SE3(tensor=init_T, name='pose')
p_src = torch.rand((Npts,3))
p_tgt = se3_gt.transform_from(p_src).tensor
#build cost function
def error_fn(optim_vars, aux_vars):
p_src, p_tgt = aux_vars
pose = optim_vars[0]
p_trans = pose.transform_from(p_src)
diff = (p_trans - p_tgt).tensor
return diff
cost_function = th.AutoDiffCostFunction(
[pose],
error_fn,
3,
aux_vars=[Point3(p_src, name='p_src'),
Point3(p_tgt, name='p_tgt')],
)
#add cost function to objective
objective = th.Objective()
objective.add(cost_function)
theseus_inputs = {
'p_src':Point3(p_src, name='p_src'),
'p_tgt':Point3(p_tgt, name='p_tgt'),
'pose':pose,
}
objective.update(theseus_inputs)
print(f"initial error: {objective.error_squared_norm().sum()}")
#create optimizer and begin optimization
optimizer = th.GaussNewton(
objective,
max_iterations=100,
step_size=0.1,)
th_layer = th.TheseusLayer(optimizer)
res, _ = th_layer.forward(theseus_inputs,)
print(res)
if __name__ == "__main__":
test_opt_mSE3()
The code without batch (worked):
def test_opt_mSE3_singles(Npts=100, use_auto_diff=True):
#add one point at a time
import torch
import theseus as th
from theseus.geometry.point_types import Point3
from theseus.geometry.se3 import SE3
import pytorch3d.transforms as py3d_trans
init_T = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]).unsqueeze(0).float()
import pytorch3d.transforms as py3d_trans
cam_rot = py3d_trans.random_quaternions(1).reshape(1,4)
cam_tr = torch.tensor([1.,1.,1.]).reshape(1,3) #torch.rand((1, 3), ) * 2 + torch.tensor([-1, -1, -1], )
cam_T = torch.cat((cam_tr, cam_rot), dim=1) # Bx7
se3_gt= SE3(cam_T, name='pose_gt')
pose= SE3(tensor=init_T, name='pose')
objective = th.Objective()
if use_auto_diff:
def error_fn(optim_vars, aux_vars):
p_src, p_tgt = aux_vars
pose = optim_vars[0]
p_trans = pose.transform_from(p_src)
diff = (p_trans - p_tgt).tensor
return diff
theseus_inputs = {}
for i in range(Npts): #one cost function for one point
p_src = torch.rand((1,3))
p_tgt = se3_gt.transform_from(p_src).tensor
cost_function = th.AutoDiffCostFunction(
[pose],
error_fn,
3,
aux_vars=[Point3(p_src, name=f'p_src_{i}'),
Point3(p_tgt, name=f'p_tgt_{i}')],
)
theseus_inputs.update({f'p_src_{i}': p_src, f'p_tgt_{i}': p_tgt})
objective.add(cost_function)
else:
raise NotImplementedError()
optimizer = th.GaussNewton(
objective,
max_iterations=20,
step_size=1,
)
th_layer = th.TheseusLayer(optimizer)
print('\nth_layer')
print(th_layer)
print('opt..')
res, _ = th_layer.forward(theseus_inputs,)
Just figured it out. It turns out that the objective class views the optimized variable to be of multiple batches (objective.batch_size=Npts
in this case), while we only have one batch optimized variable.
The trick is to view all the points in the input as a single-batch vector. Below is the modified code for optimizing a rigid transformation T, st. p_tgt = T p_src
def test_opt_mSE3(Npts=20, use_auto_diff=True):
import theseus as th
import torch
from theseus.geometry.se3 import SE3
import pytorch3d.transforms as py3d_trans
#create data
init_T = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]).unsqueeze(0).float()
cam_rot = py3d_trans.random_quaternions(1).reshape(1,4)
cam_tr = torch.tensor([1.,1.,1.]).reshape(1,3)
cam_T = torch.cat((cam_tr, cam_rot), dim=1)
se3_gt= SE3(cam_T, name='pose_gt')
#pose_6d
pose_t, pose_axis_angle = torch.zeros(1,3), torch.zeros(1,3)
pose = torch.cat((pose_t, pose_axis_angle), dim=1)#6D pose
p_src = torch.rand((Npts,3))
p_tgt = se3_gt.transform_from(p_src).tensor
p_src_vec = p_src.reshape(1,-1)
p_tgt_vec = p_tgt.reshape(1,-1)
#build cost function
if use_auto_diff:
def error_fn(optim_vars, aux_vars, ):
p_src, p_tgt = aux_vars
pose_6d= optim_vars[0]
rot = py3d_trans.axis_angle_to_matrix(pose_6d[:, 3:])#N,3,3
t = pose_6d[:,:3]
p_trans = rot @ p_src.tensor.reshape(-1,3,1) + t.unsqueeze(-1)
diff = p_trans.reshape(1,-1) - p_tgt.tensor
return diff
cost_function = th.AutoDiffCostFunction(
[th.Vector(tensor=pose, name='pose', )],
error_fn,
3*Npts, # view all points in the batch as a single vector
aux_vars=[th.Vector(tensor=p_src_vec, name='p_src'),
th.Vector(tensor=p_tgt_vec, name='p_tgt')],
)
else:
raise NotImplementedError()
#add cost function to objective
objective = th.Objective()
objective.add(cost_function)
theseus_inputs = {
'p_src':th.Vector(tensor=p_src_vec, name='p_src'),
'p_tgt':th.Vector(tensor=p_tgt_vec, name='p_tgt'),
}
objective.update(theseus_inputs)
print(f"initial error: {objective.error_squared_norm().sum()}")
#create optimizer and begin optimization
optimizer = th.GaussNewton(
objective,
max_iterations=10,
step_size=1,
)
th_layer = th.TheseusLayer(optimizer)
res, _ = th_layer.forward(theseus_inputs,)
print(f"final error: {objective.error_squared_norm().sum()}")
print('\n===est. pose===')
rot = py3d_trans.axis_angle_to_matrix(res['pose'][:, 3:]).squeeze(0)
T_est = torch.cat((rot, res['pose'][:, :3].squeeze(0).reshape(3,1)), dim=1)
print(T_est)
print('\n===gt pose===')
print(se3_gt.tensor)
if __name__ == "__main__":
test_opt_mSE3()
Ah glad you worked out a fix @cxlcl and thanks for posting it so others can benefit!
Hi @mhmukadam, Can we have batch version transform_from() on points (shape=(batch, N, 3)) for Lie group elements in Theseus? I think the example provided by @cxlcl didn't benefit from the differentiability of Lie groups that is provided by Theseus.
Hi @MickShen7558, transform_from()
already supports batched points (cc @fantaosha). Is the use case different from the current implementation?
Hi @mhmukadam, I think the current implementation only supports transformation where B Lie group elements operate on B points. What if I want B Lie group elements to transform B point clouds where each point cloud contains N points?
Hi @MickShen7558. We were discussing this a bit and @fantaosha agreed to add support for this pretty soon. The case where this is supported for torch.Tensor
type inputs should be easy to add. Would that work for your purposes? Adding this feature for Point3
input types is a bit trickier, since this class requires data to be have exactly 2 dimensions (batch and size 3 data).
@mhmukadam @luisenp This feature (the ability to apply Lie group transformations to pointclouds, and differentiate through it) is super critical for my application. Is there any progress on this issue?
Hi @richardrl. As a matter of fact, we are about to merge this functionality in #513, as part of our new labs.lie package. Below is a simple example that uses this functionality in an optimization problem (you would need to check out the branch linked above).
Let me know if you have any questions.
import theseus as th
import torch
import theseus.labs.lie.functional as lieF
r1 = lieF.SO3.rand(1) # this is going to be our target Rotation
p = torch.randn(1, 10, 3) # batch_size x num_points x dim
p_r1 = lieF.SO3.transform_from(r1, p) # # batch_size x num_points x dim
# Use in Theseus to recover r1 via optimization, starting from a random rotation
r_opt = th.SO3(tensor=lieF.SO3.rand(1), name="r") # start with random R
p_cloud = th.Variable(p, "p_cloud")
p_cloud_target = th.Variable(p_r1, "p_cloud_target")
# The error to optimize is a transform_from between R_opt and the original point cloud
def err_fn(optim_vars, aux_vars):
r_ = optim_vars[0]
p_, p_target = aux_vars
p_r_ = lieF.SO3.transform_from(r_.tensor, p_.tensor)
return (
(p_r_ - p_target.tensor).abs().sum(dim=1)
) # batch_size x dim (sums point cloud dim)
obj = th.Objective()
obj.add(
th.AutoDiffCostFunction(
(r_opt,), err_fn, 3, aux_vars=(p_cloud, p_cloud_target), autograd_mode="dense"
)
)
layer = th.TheseusLayer(th.LevenbergMarquardt(obj))
layer.forward(optimizer_kwargs={"verbose": True, "damping": 0.01})
print((r_opt.tensor - r1).abs().max()) # Check that R now matches original
cc @MickShen7558 @cxlcl @fantaosha
Thanks for your example, However, when I tried this, I got an error as follow:
AttributeError: 'Point3' object has no attribute 'clone'
Here is the code:
import theseus as th
import torch
r1 = th.SO3.rand(1) # this is going to be our target Rotation
p = torch.randn(10, 3) # batch_size x num_points x dim
p_r1 = th.SO3.rotate(r1, p) # # batch_size x num_points x dim
# Use in Theseus to recover r1 via optimization, starting from a random rotation
r_opt = th.SO3(tensor=th.SO3.rand(1), name="r") # start with random R
p_cloud = th.Variable(p, "p_cloud")
p_cloud_target = th.Variable(p_r1, "p_cloud_target")
# The error to optimize is a transform_from between R_opt and the original point cloud
def err_fn(optim_vars, aux_vars):
r_ = optim_vars[0]
p_, p_target = aux_vars
p_r_ = th.SO3.transform_from(r_.tensor, p_.tensor)
return (
(p_r_ - p_target.tensor).abs().sum(dim=1)
) # batch_size x dim (sums point cloud dim)
obj = th.Objective()
obj.add(
th.AutoDiffCostFunction(
(r_opt,), err_fn, 3, aux_vars=(p_cloud, p_cloud_target), autograd_mode="dense"
)
)
layer = th.TheseusLayer(th.LevenbergMarquardt(obj))
layer.forward(optimizer_kwargs={"verbose": True, "damping": 0.01})
print((r_opt.tensor - r1).abs().max()) # Check that R now matches original
I also have a question: In your an other example se2_inverse.py, there is
# Activiate the Lie group update
with th.set_lie_tangent_enabled(use_lie_tangent):
optim.step()
Do I need this to insure the r_opt in SO3 ?
Hi @FanWu-fan, can you post a full stack trace for the error?
Regarding the set_lie_tangent_enabled
, that context makes it so that when doing optim.step()
, instead of simply adding the gradient tensor to the Lie group tensor, it will first project the gradient to tangent space, and then do a retract operation.
@luisenp, Here is the code and error:
import theseus as th
import torch
r1 = th.SO3.rand(1) # this is going to be our target Rotation
p = torch.randn(10, 3) # batch_size x num_points x dim
p_r1 = th.SO3.rotate(r1, p) # # batch_size x num_points x dim
# Use in Theseus to recover r1 via optimization, starting from a random rotation
r_opt = th.SO3(tensor=th.SO3.rand(1), name="r") # start with random R
p_cloud = th.Variable(p, "p_cloud")
p_cloud_target = th.Variable(p_r1, "p_cloud_target")
# The error to optimize is a transform_from between R_opt and the original point cloud
def err_fn(optim_vars, aux_vars):
r_ = optim_vars[0]
p_, p_target = aux_vars
p_r_ = th.SO3.transform_from(r_.tensor, p_.tensor)
return (
(p_r_ - p_target.tensor).abs().sum(dim=1)
) # batch_size x dim (sums point cloud dim)
obj = th.Objective()
obj.add(
th.AutoDiffCostFunction(
(r_opt,), err_fn, 3, aux_vars=(p_cloud, p_cloud_target), autograd_mode="dense"
)
)
layer = th.TheseusLayer(th.LevenbergMarquardt(obj))
layer.forward(optimizer_kwargs={"verbose": True, "damping": 0.01})
print((r_opt.tensor - r1).abs().max()) # Check that R now matches original
# Traceback (most recent call last):
# File "/home/fan/Rotation_py/try2.py", line 29, in <module>
# layer = th.TheseusLayer(th.LevenbergMarquardt(obj))
# File "/home/fan/miniconda3/envs/t11/lib/python3.10/site-packages/theseus/theseus_layer.py", line 39, in __init__
# Vectorize(self.objective, empty_cuda_cache=empty_cuda_cache)
# File "/home/fan/miniconda3/envs/t11/lib/python3.10/site-packages/theseus/core/vectorizer.py", line 136, in __init__
# vectorized_cost_fn = base_cost_fn.copy(keep_variable_names=False)
# File "/home/fan/miniconda3/envs/t11/lib/python3.10/site-packages/theseus/core/cost_function.py", line 135, in copy
# super().copy(new_name=new_name, keep_variable_names=keep_variable_names),
# File "/home/fan/miniconda3/envs/t11/lib/python3.10/site-packages/theseus/core/theseus_function.py", line 95, in copy
# new_fn = self._copy_impl(new_name=new_name)
# File "/home/fan/miniconda3/envs/t11/lib/python3.10/site-packages/theseus/core/cost_function.py", line 403, in _copy_impl
# aux_vars=[v.copy() for v in self.aux_vars],
# File "/home/fan/miniconda3/envs/t11/lib/python3.10/site-packages/theseus/core/cost_function.py", line 403, in <listcomp>
# aux_vars=[v.copy() for v in self.aux_vars],
# File "/home/fan/miniconda3/envs/t11/lib/python3.10/site-packages/theseus/core/variable.py", line 31, in copy
# return Variable(self.tensor.clone(), name=new_name)
# AttributeError: 'Point3' object has no attribute 'clone'
According to your reply and the example se2_inverse.py , I need to explicit call the function optim.step()
. What if I want to use the th.TheseusLayer
? Here is the code to calculate the rotation matrix R, in which P2 = R@P1, and P2 is a vector observed by the sensor.
import torch
import theseus as th
from theseus.geometry.point_types import Point3
from theseus.geometry.se3 import SE3
import pytorch3d.transforms as py3d_trans
def test_opt_mSO3(so3_gt,r_opt,Npts=320,num_iters=1000,use_lie_tangent=False):
print("so3-gt shape: ",so3_gt.shape) # 1x3x3
p_src = torch.rand((Npts,3))
print("p_src shape: ", p_src.shape) # Nptsx3
p_tgt = so3_gt.rotate(p_src)
p_tgt_noise = torch.randn(Npts,3)
p_tgt.tensor = p_tgt.tensor + p_tgt_noise
print("p_tgt shape: ", p_tgt.shape) # Nptsx3
r_opt.tensor.requires_grad=True
#build cost function
optim = torch.optim.Adam([r_opt.tensor], lr=1)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optim, milestones=[1000, 3000], gamma=0.1
)
for i in range(num_iters):
optim.zero_grad()
cf = th.Difference(r_opt.rotate(p_src), p_tgt, th.ScaleCostWeight(1.0))
loss = cf.error().norm()
if i % 100 == 0:
print(
"iter {}: loss is {:.10f}, determinate is {:.10f}".format(
num_iters, loss.item(), torch.det(r_opt.tensor[0])
))
loss.backward()
# Activiate the Lie group update
with th.set_lie_tangent_enabled(use_lie_tangent):
optim.step()
scheduler.step()
cf = th.Difference(r_opt.rotate(p_src), p_tgt, th.ScaleCostWeight(1.0))
loss = cf.error().norm()
print(
"iter {}: loss is {:.10f}, determinate is {:.10f}".format(
num_iters, loss.item(), torch.det(r_opt.tensor[0])
))
print("r gt: ",so3_gt, "r opt: ", r_opt)
if __name__ == "__main__":
so3_gt= th.SO3.rand(1)
r_opt= th.SO3.rand(1)
print("=========================================================")
print("PyTorch Optimization on the Euclidean Space")
test_opt_mSO3(so3_gt,r_opt,use_lie_tangent=False)
print("\n")
print("=========================================================")
print("PyTorch Optimization on the Lie Group Tangent Space (Ours)")
test_opt_mSO3(so3_gt,r_opt,use_lie_tangent=True)
print("\n")
Hi @FanWu-fan. The problem is that you are importing SO3
from th
instead of lieF
; these are different objects. th.SO3
is a Variable
subclass, which wraps a tensor so that it can be used as a variable inside theseus optimizers. On the other hand, lieF.SO3
is a namespace of tensor-tensor operations for SO3 data.
I'm not sure that I understand your second question, but, in general, you don't need to call optim.step()
to use a TheseusLayer
. The code in se2_inverse.py
is trying to illustrate how to minimize a torch loss with lie group parameters, outside the context of NLLS optimization. It is also and old example, and superseded by our torchlie
package; we should probably remove it at this point.
A good rule of thumb is:
- If you need to model an optimization/aux variable in one of our optimizers, use a subclass of
th.LieGroup
. - If you need to do operations on regular tensors representing Lie group data, use
torchlie
(e.g.,lieF
). - If you have a torch optimizer with Lie group parameters, also use
torchlie
.
Combinations of all three are possible and typical. The example above is a combination of 1 and 2: it shows an SO3 variable that uses torchlie
ops inside an autodiff cost function. The se2_inverse
is an example of case 3, although this can be done more easily with torchlie
now.
I'm late to the party, but looking at the solution that @cxlcl suggested, I may be wrong but I don't think you're optimizing the same objective. By vectorizing all your 3D points in a single vector and having a single objective component, you're moving the sum over your points inside the 2-norm of the NLS (rather than optimizing a sum of 2-norms for all points). I don't know if Theseus has a way of vectorizing an objective function such as fitting 3d point-clouds, without the need of adding each objective component (each with a 3-dim error function).