Any plan to support `torch.func`?
I am trying to compute a full Jacobian using `jacrev` or `jacfwd` from `torch.func`. Part of the loss function uses `_PointFaceDistance`. Out of the box, pytorch3d does not support `torch.func`. The closest references I have found so far are https://github.com/facebookresearch/pytorch3d/issues/1636 and https://github.com/facebookresearch/pytorch3d/issues/1533.
The problems I am having are:

- With `jacrev`, it throws `RuntimeError: Cannot access data pointer of Tensor that doesn't have storage` at the line of `_C.point_face_dist_backward`.
- I followed https://github.com/facebookresearch/pytorch3d/issues/1533 and implemented `vmap` for `_PointFaceDistance`, but `jacrev` still hits the line of `_C.point_face_dist_backward` first, so same error. In #1533 the error happened in the forward path, so it is actually the same error with a different cause.
- With `jacfwd`, it throws an exception about a missing `jvp` method.
- I don't quite know how to implement `jvp`, as I have not found much info on what exactly should go into this function (a toy sketch of what I think belongs there is below).
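From what I can piece together from the `torch.autograd.Function` docs, `jvp` receives one tangent per input and must return one tangent per output, with anything it needs saved in `setup_context`. Here is an untested toy stand-in (a per-point squared norm, nothing like the real point-to-face derivative, which is the part I am missing) showing where I think the pieces go:

```python
import torch
from torch.autograd import Function


class _ToySqNorm(Function):
    """Toy stand-in for _PointFaceDistance: d[p] = ||points[p]||^2, (P, 3) -> (P,)."""

    @staticmethod
    def forward(points):
        return (points * points).sum(dim=-1)

    @staticmethod
    def setup_context(ctx, inputs, output):
        (points,) = inputs
        ctx.save_for_backward(points)
        ctx.save_for_forward(points)  # makes `points` available to jvp as well

    @staticmethod
    def backward(ctx, grad_dists):
        (points,) = ctx.saved_tensors
        return 2.0 * points * grad_dists.unsqueeze(-1)

    @staticmethod
    def jvp(ctx, t_points):
        # one tangent per input comes in, one tangent per output goes out
        (points,) = ctx.saved_tensors
        return 2.0 * (points * t_points).sum(dim=-1)


x = torch.randn(5, 3)
# forward-mode directional derivative; jacfwd is built on top of this plus vmap
_, d_dot = torch.func.jvp(_ToySqNorm.apply, (x,), (torch.ones_like(x),))
```

For the real op, `jacfwd` would additionally `vmap` over the tangents, so a custom `vmap` staticmethod (as in #1533) is still needed on top of this.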
The only working method is to call `torch.autograd.functional.jacobian(..., vectorize=False)`, which is very slow. When I turn on `vectorize=True`, it runs into the same issues as above.
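For reference, this is the shape of the slow fallback, with a hypothetical `toy_loss` standing in for the real loss:

```python
import torch


def toy_loss(x):
    # stand-in for the loss that calls _PointFaceDistance internally
    return (x * x).sum(dim=-1)


x = torch.randn(5, 3)
# computes the Jacobian one output element at a time; every backward call sees an
# ordinary (non-batched) grad tensor, so custom ops that need data pointers still work
J = torch.autograd.functional.jacobian(toy_loss, x, vectorize=False)  # shape (5, 5, 3)
```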
My questions are:
- Is there a plan to officially support `torch.func`? If I can get some guidance from the `pytorch3d` team, I am happy to collaborate on this.
- Any idea how to make this work? Any workarounds?
Thanks!
We aren't planning `torch.func` support. It seems to me that the method in #1533 should work fine for `_PointFaceDistance` - feel free to post the code you have and maybe we can figure out what's wrong.
@bottler Thanks for the reply!
The error message "RuntimeError: Cannot access data pointer of Tensor that doesn't have storage" from #1533 happened in the forward pass, at `idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, norm, K, version)`. I don't know how #1533 is computing the Jacobian. In my case, neither forward mode nor backward mode works, with the following errors:
- With backward mode `jacrev`, the same error happens when `_C.point_face_dist_backward` is called inside `backward`.
- With forward mode `jacfwd`, it complains about a missing `jvp` method.
I did follow #1533 and modified the class with `setup_context` etc. I also added `vmap`, but the `vmap` method is never called before hitting the two errors above.
@bottler I wrote a toy example that follows this pytorch3d tutorial with the following modifications:
- Only use `_PointFaceDistance` in the objective. Because I only care whether we can compute the Jacobian, it does not matter if the optimization actually runs.
- Added `vmap` to `_PointFaceDistance` and added `setup_context`.
- Used `theseus`.
Here is the code. It is self-contained and will download `dolphin.obj` following the pytorch3d tutorial. Sorry that the code is a bit long because it includes the `_PointFaceDistance` updates.
```python
import os
import urllib.request

import einops
import theseus as th
import torch
from pytorch3d import _C
from pytorch3d.io import load_obj
from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere
from torch.autograd import Function
from torch.autograd.function import once_differentiable

_DEFAULT_MIN_TRIANGLE_AREA: float = 5e-3

# PointFaceDistance
class _PointFaceDistance(Function):
    """
    Torch autograd Function wrapper PointFaceDistance Cuda implementation
    """

    generate_vmap_rule = False

    @staticmethod
    def forward(
        # ctx,
        points,
        points_first_idx,
        tris,
        tris_first_idx,
        max_points,
        min_triangle_area=_DEFAULT_MIN_TRIANGLE_AREA,
    ):
        """
        Args:
            ctx: Context object used to calculate gradients.
            points: FloatTensor of shape `(P, 3)`
            points_first_idx: LongTensor of shape `(N,)` indicating the first point
                index in each example in the batch
            tris: FloatTensor of shape `(T, 3, 3)` of triangular faces. The `t`-th
                triangular face is spanned by `(tris[t, 0], tris[t, 1], tris[t, 2])`
            tris_first_idx: LongTensor of shape `(N,)` indicating the first face
                index in each example in the batch
            max_points: Scalar equal to maximum number of points in the batch
            min_triangle_area: (float, defaulted) Triangles of area less than this
                will be treated as points/lines.
        Returns:
            dists: FloatTensor of shape `(P,)`, where `dists[p]` is the squared
                euclidean distance of `p`-th point to the closest triangular face
                in the corresponding example in the batch
            idxs: LongTensor of shape `(P,)` indicating the closest triangular face
                in the corresponding example in the batch.

            `dists[p]` is
            `d(points[p], tris[idxs[p], 0], tris[idxs[p], 1], tris[idxs[p], 2])`
            where `d(u, v0, v1, v2)` is the distance of point `u` from the triangular
            face `(v0, v1, v2)`
        """
        dists, idxs = _C.point_face_dist_forward(
            points,
            points_first_idx,
            tris,
            tris_first_idx,
            max_points,
            min_triangle_area,
        )
        # ctx.save_for_backward(points, tris, idxs)
        # ctx.min_triangle_area = min_triangle_area
        return dists, idxs

    @staticmethod
    def setup_context(ctx, inputs, output):
        (
            points,
            points_first_idx,
            tris,
            tris_first_idx,
            max_tris,
            min_triangle_area,
        ) = inputs
        dists, idxs = output
        ctx.save_for_backward(points, tris, idxs)
        ctx.min_triangle_area = min_triangle_area
        ctx.dists = dists
        ctx.idxs = idxs
        ctx.inputs = inputs

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_dists, grad_idxs):
        grad_dists = grad_dists.contiguous()
        points, tris, idxs = ctx.saved_tensors
        min_triangle_area = ctx.min_triangle_area
        grad_points, grad_tris = _C.point_face_dist_backward(
            points, tris, idxs, grad_dists, min_triangle_area
        )
        return grad_points, None, grad_tris, None, None, None

    @staticmethod
    def vmap(
        info,
        in_dims,
        points,
        points_first_idx,
        tris,
        tris_first_idx,
        max_points,
        min_triangle_area=_DEFAULT_MIN_TRIANGLE_AREA,
    ):
        (
            points_bdim,
            points_first_idx_bdim,
            tris_bdim,
            tris_first_idx_bdim,
            _,
            _,
        ) = in_dims
        points_V, points_P, points_C = points.shape
        points = einops.rearrange(points, "V P C -> (V P) C")
        tris = einops.rearrange(tris, "V T A B -> (V T) A B")
        dists, idx = _PointFaceDistance.forward(
            points,
            points_first_idx,
            tris,
            tris_first_idx,
            max_points,
            min_triangle_area,
        )
        dists = einops.rearrange(dists, "(V P) -> V P", V=points_V)
        idx = einops.rearrange(idx, "(V P) -> V P", V=points_V)
        return (dists, idx), (0, 0)

point_face_distance = _PointFaceDistance.apply


def point_to_mesh_distance(points, mesh_v, mesh_f):
    scale_fac = 100.0  # see explanation above
    # packed representation for pointclouds
    points = points * scale_fac  # (P, 3)
    points_first_idx = torch.zeros([1])
    max_points = points.shape[0]
    # packed representation for faces
    verts_packed = mesh_v * scale_fac
    faces_packed = mesh_f
    tris = verts_packed[faces_packed.to(torch.int)]
    tris_first_idx = torch.zeros([1])
    point_to_face, _ = point_face_distance(
        points.to(torch.float32),
        points_first_idx.to(torch.long),
        tris.to(torch.float32),
        tris_first_idx.to(torch.long),
        max_points,
    )
    point_to_face = point_to_face / (scale_fac**2)
    return torch.sqrt(point_to_face)

device = "cpu"
target_obj_path = "dolphin.obj"
if not os.path.exists(target_obj_path):
# Reference: https://pytorch3d.org/tutorials/deform_source_mesh_to_target_mesh
src_url = (
"https://dl.fbaipublicfiles.com/pytorch3d/data/dolphin/dolphin.obj"
)
print(f"Downloading from {src_url}")
urllib.request.urlretrieve(
src_url,
"dolphin.obj",
)
verts, faces, aux = load_obj(target_obj_path)
faces_idx = faces.verts_idx.to(device)
verts = verts.to(device)
center = verts.mean(0)
verts = verts - center
scale = max(verts.abs().max(0)[0])
verts = verts / scale
target_mesh = Meshes(verts=[verts], faces=[faces_idx])
src_mesh = ico_sphere(4, device)
deform_verts = th.Vector(
tensor=src_mesh.verts_packed().reshape(1, -1),
name="deform_v",
)
target_v = th.Variable(verts.reshape(1, -1), name="target_v")
faces_idx = faces_idx.to(torch.float32)
target_f = th.Variable(faces_idx.reshape(1, -1), name="target_f")
def error_fn(optim_vars, aux_vars):
(verts,) = optim_vars
target_v, target_f = aux_vars
p2m = point_to_mesh_distance(
verts.tensor.reshape(-1, 3).to(torch.float32),
mesh_v=target_v.tensor.reshape(-1, 3).to(torch.float32),
mesh_f=target_f.tensor.reshape(-1, 3),
).to(torch.float64)
return p2m.unsqueeze(0)
optim_vars = (deform_verts,)
aux_vars = target_v, target_f
cost_function = th.AutoDiffCostFunction(
optim_vars,
error_fn,
deform_verts.shape[1] / 3,
aux_vars=aux_vars,
name="l2",
)
# grad_points, grad_tris = _C.point_face_dist_backward(
# RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
cost_function.jacobians()
When `cost_function.jacobians()` is called, it throws an exception. Full error below:
```
Traceback (most recent call last):
File "/Users/sonny/jac_theseus.py", line 227, in <module>
cost_function.jacobians()
File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/theseus/core/cost_function.py", line 355, in jacobians
jacobians_full = self._compute_autograd_jacobian_vmap(
File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/theseus/core/cost_function.py", line 341, in _compute_autograd_jacobian_vmap
return vmap(jacrev(jac_fn, argnums=0))(optim_tensors, aux_tensors)
File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/apis.py", line 188, in wrapped
return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 278, in vmap_impl
return _flat_vmap(
File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 44, in fn
return f(*args, **kwargs)
File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 391, in _flat_vmap
batched_outputs = func(*batched_inputs, **kwargs)
File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 609, in wrapper_fn
flat_jacobians_per_input = compute_jacobian_stacked()
File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 540, in compute_jacobian_stacked
chunked_result = vmap(vjp_fn)(basis)
File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/apis.py", line 188, in wrapped
return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 278, in vmap_impl
return _flat_vmap(
File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 44, in fn
return f(*args, **kwargs)
File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 391, in _flat_vmap
batched_outputs = func(*batched_inputs, **kwargs)
File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 336, in wrapper
result = _autograd_grad(flat_primals_out, flat_diff_primals, flat_cotangents,
File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 124, in _autograd_grad
grad_inputs = torch.autograd.grad(diff_outputs, inputs, grad_outputs,
File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/autograd/__init__.py", line 411, in grad
result = Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/autograd/function.py", line 289, in apply
return user_fn(self, *args)
File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/autograd_function.py", line 123, in backward
result = autograd_function.backward(ctx, *grads)
File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/autograd/function.py", line 570, in wrapper
outputs = fn(ctx, *args)
File "/Users/sonny/jac_theseus.py", line 96, in backward
grad_points, grad_tris = _C.point_face_dist_backward(
RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
```
I don't know how to start debugging this, as `_C.point_face_dist_backward` is implemented in CUDA/CPU code. If you have any pointers, please let me know. Thanks a lot!
Just an intuition based on the traceback, especially here:
File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/theseus/core/cost_function.py", line 341, in _compute_autograd_jacobian_vmap return vmap(jacrev(jac_fn, argnums=0))(optim_tensors, aux_tensors)
I think `jac_fn` will be called with `vmap` in the same way the forward path is. So the problem might be in the vmapping of `jac_fn`, which you might need to specify in the same way as you did for the forward path.
@TimoRST Thanks for the input! I had not thought of that and will look into it.
Just one follow-up: `jac_fn` is either the `cost_function` (a `th.AutoDiffCostFunction`) or the `error_fn` I wrote to compute the loss values. Do you know how to add `vmap` to those functions? They are not `torch.autograd.Function`s.
Thanks!
I don't know that. I would first debug into that function to see whether it really is called with the batched tensors that cause the error. After verifying that, you might just make an autodiff function out of it?
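A rough, untested sketch of such a probe, dropped into your `backward` right before the `_C.point_face_dist_backward` call (`torch._C._functorch` is an internal module, so no guarantees about these helpers):

```python
# temporary probe for _PointFaceDistance.backward, placed just before the _C call
import torch


def probe(name, t):
    # True in the last column means t is a functorch BatchedTensor, i.e. vmap is active on it
    print(name, tuple(t.shape), torch._C._functorch.is_batchedtensor(t))


# probe("grad_dists", grad_dists)
# probe("points", points)
# probe("idxs", idxs)
```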
@TimoRST Thanks. I will try it out.
Can I ask what your use case was for using knn and theseus? Was it also in the context of an optimization that needed a full Jacobian?
Thanks!
I wanted to implement something like WGICP (https://arxiv.org/abs/2209.09777), but I couldn't scale it because my graphics card was too small, so I didn't get comparable results. I didn't use the Jacobian, so it was enough to ensure correct vmapping in the forward path.
@TimoRST Thanks for the details. Really appreciate the help!
@bottler could you take a look at the code I posted above? I am going to follow @TimoRST's suggestion and take a look at the error function. Meanwhile, if you find anything in my code, please let me know. Thanks!
Here are some findings. Conclusion: `_C.point_face_dist_backward` cannot take a `grad_dists` that is a 2D `BatchedTensor`, which is what `vmap`/`jacrev` creates.
- When calling `torch.autograd.functional.jacobian`, it computes the Jacobian row by row. `grad_dists` is a 1D vector and everything works fine.
- When calling `jacrev`, it internally uses `vmap`, which makes `grad_dists` a batched version, i.e. a 2D `BatchedTensor`. That makes sense, because the vectorized version pushes the for loop into C++ code. But when `_C.point_face_dist_backward` is called with this kind of 2D `grad_dists`, it throws `RuntimeError: Cannot access data pointer of Tensor that doesn't have storage` (a standalone repro is below).
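The same failure mode can be reproduced without pytorch3d: calling `.data_ptr()` on a tensor while it is being `vmap`ped fails the same way, because a `BatchedTensor` has no storage (small sketch, assuming the Python call hits the same check as the C++ kernel):

```python
import torch


def probe(x):
    # under vmap, x is a BatchedTensor: it has no storage, hence no data pointer
    try:
        x.data_ptr()
    except RuntimeError as e:
        print(e)  # "Cannot access data pointer of Tensor that doesn't have storage"
    return x.sum()


torch.func.vmap(probe)(torch.randn(3, 4))
```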
Here is a hacked version to verify my point, though the math is probably wrong. When using `jacrev`, it "un-vmaps" `v_grad_dists` and uses a for loop to compute the backward row by row. The `_C` call itself works fine, but when it returns, pytorch3d complains that the returned `grad_points` has the wrong shape.
```python
@staticmethod
@once_differentiable
def backward(ctx, grad_dists, grad_idxs):
    grad_dists = grad_dists.contiguous()
    points, tris, idxs = ctx.saved_tensors
    min_triangle_area = ctx.min_triangle_area
    # https://discuss.pytorch.org/t/save-batchedtensor-to-a-pickle-file/170561/4
    v_points = torch._C._functorch.get_unwrapped(points)
    v_tris = torch._C._functorch.get_unwrapped(tris)
    v_idxs = torch._C._functorch.get_unwrapped(idxs)
    v_grad_dists = torch._C._functorch.get_unwrapped(grad_dists)
    grad_points = []
    grad_tris = None
    for v_grad_dists_v in v_grad_dists:
        v_grad_points, v_grad_tris = _C.point_face_dist_backward(
            v_points, v_tris, v_idxs, v_grad_dists_v, min_triangle_area
        )
        grad_points.append(v_grad_points)
        if grad_tris is not None:
            grad_tris = grad_tris + v_grad_tris
        else:
            grad_tris = v_grad_tris
    grad_points = torch.cat(grad_points, dim=1)
    return grad_points, None, grad_tris, None, None, None
```
@bottler would you be able to confirm that what I said above is correct? And if so, does it mean that changing the internal code of `_C.point_face_dist_backward` is the only option?
Thanks!
I'm afraid I don't know enough about functorch.