pytorch3d
Faster vertex normals
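`Meshes._compute_vertex_normals()` computes the same face normal three times: once per corner, from cross products that are cyclic relabellings of the same edge vectors and are therefore mathematically identical. Computing the face normal once and adding it to all three corner vertices gives the same result, up to floating-point rounding. In the example below this takes about 30% less time: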
from pytorch3d.utils import ico_sphere
import torch
from timeit import timeit
# Original code from https://github.com/facebookresearch/pytorch3d/blob/v0.4.0/pytorch3d/structures/meshes.py#L805
# Kept unchanged so it serves as the reference implementation in the benchmark below.
def compute_vertex_normals_original(meshes):
    """Reference (pytorch3d v0.4.0) computation of per-vertex normals.

    For every face, an edge cross product is scatter-added onto each of the
    face's three vertices with ``index_add``. The summed vectors are then
    L2-normalized per vertex.

    The three cross products below are mathematically the same vector. Each
    is the face's edge cross product taken around a different corner, i.e. a
    cyclic relabelling of the corners, so the face normal is effectively
    computed three times. In floating point the three results can differ in
    the last bits.

    Args:
        meshes: object exposing ``faces_packed()`` and ``verts_packed()``
            (a pytorch3d ``Meshes`` in the script below). Presumably verts
            are (V, 3) floats and faces are (F, 3) vertex indices. TODO confirm.

    Returns:
        A tensor shaped like ``verts_packed()`` holding the normalized
        accumulated vectors. ``eps=1e-6`` guards the division, so vertices
        that accumulate a (near-)zero vector stay (near-)zero.
    """
    faces_packed = meshes.faces_packed()
    verts_packed = meshes.verts_packed()
    verts_normals = torch.zeros_like(verts_packed)
    # Corner positions per face: vertices_faces[:, k] is corner k of every face.
    vertices_faces = verts_packed[faces_packed]
    # Accumulate onto corner 1, using edges that start at corner 1.
    verts_normals = verts_normals.index_add(
        0,
        faces_packed[:, 1],
        torch.cross(
            vertices_faces[:, 2] - vertices_faces[:, 1],
            vertices_faces[:, 0] - vertices_faces[:, 1],
            dim=1,
        ),
    )
    # Accumulate onto corner 2.
    verts_normals = verts_normals.index_add(
        0,
        faces_packed[:, 2],
        torch.cross(
            vertices_faces[:, 0] - vertices_faces[:, 2],
            vertices_faces[:, 1] - vertices_faces[:, 2],
            dim=1,
        ), # same vector as the cross product above (corners relabelled cyclically)
    )
    # Accumulate onto corner 0.
    verts_normals = verts_normals.index_add(
        0,
        faces_packed[:, 0],
        torch.cross(
            vertices_faces[:, 1] - vertices_faces[:, 0],
            vertices_faces[:, 2] - vertices_faces[:, 0],
            dim=1,
        ), # again the same vector, computed a third time
    )
    return torch.nn.functional.normalize(
        verts_normals, eps=1e-6, dim=1
    )
def compute_vertex_normals_wpalfi(meshes):
    """Per-vertex normals that compute each face's cross product only once.

    A single edge cross product per face is scatter-added onto all three of
    the face's vertices, and the sums are then L2-normalized per vertex.
    Since the cross product's length is twice the triangle's area, larger
    faces contribute more to their vertices' normals.

    Args:
        meshes: object exposing ``faces_packed()`` and ``verts_packed()``
            (a pytorch3d ``Meshes`` in the script below). Presumably verts
            are (V, 3) floats and faces are (F, 3) vertex indices. TODO confirm.

    Returns:
        A tensor shaped like ``verts_packed()``. Vertices whose accumulated
        vector is (near-)zero, e.g. unreferenced vertices, stay (near-)zero
        thanks to ``eps=1e-6``.
    """
    faces = meshes.faces_packed()
    verts = meshes.verts_packed()

    # tri[:, k] holds the position of corner k of every face.
    tri = verts[faces]
    edge_a = tri[:, 2] - tri[:, 1]
    edge_b = tri[:, 0] - tri[:, 1]
    face_normals = torch.cross(edge_a, edge_b, dim=1)

    accum = torch.zeros_like(verts)
    # Add the same face normal to each corner, in corner order 0, 1, 2.
    for corner in range(3):
        accum.index_add_(0, faces[:, corner], face_normals)

    return torch.nn.functional.normalize(accum, eps=1e-6, dim=1)
# Benchmark: check that both implementations agree, then time them.
# Fall back to the CPU so the script also runs on machines without a GPU
# (a hard-coded 'cuda' device fails there).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
mesh = ico_sphere(7, device)

# These first calls also serve as a warm-up: any one-time device setup is
# not charged to whichever implementation happens to be timed first.
normals_original = compute_vertex_normals_original(mesh)
normals_wpalfi = compute_vertex_normals_wpalfi(mesh)
maxdiff = (normals_wpalfi - normals_original).abs().max().item()
print(f"{maxdiff=}")

n_runs = 1000
# `.sum().item()` copies a scalar back to the host, which waits for any
# asynchronous GPU work, so that work is included in the measured time.
# timeit returns the total seconds for n_runs calls; convert that to
# milliseconds per call so the printed unit is correct for any n_runs.
to = timeit(
    lambda: compute_vertex_normals_original(mesh).sum().item(), number=n_runs
) / n_runs * 1000
tw = timeit(
    lambda: compute_vertex_normals_wpalfi(mesh).sum().item(), number=n_runs
) / n_runs * 1000
print(f"original {to:.02f} ms")
print(f"wpalfi {tw:.02f} ms")
Output (timings are the average milliseconds per call over 1000 calls):
maxdiff=1.7881393432617188e-07
original 3.42 ms
wpalfi 2.37 ms
This issue is stale because it has been open for 30 days with no activity. Remove the stale label or comment, or this will be closed in 5 days.
Finally got some time to implement this :) Thank you for pointing it out!