pytorch3d icon indicating copy to clipboard operation
pytorch3d copied to clipboard

Faster vertex normals

Open wpalfi opened this issue 3 years ago • 1 comment

Meshes._compute_vertex_normals() calculates the same face normals three times. Here is an example that is 30% faster:

from pytorch3d.utils import ico_sphere
import torch
from timeit import timeit

#original code from https://github.com/facebookresearch/pytorch3d/blob/v0.4.0/pytorch3d/structures/meshes.py#L805
def compute_vertex_normals_original(meshes):
    """Reference vertex normals: one cross product per face corner.

    For each face, the cross product of the two edges leaving a corner is
    added to the vertex at that corner, and the per-vertex sums are then
    normalized. The three cross products of one face are mathematically the
    same vector (they can differ only by floating-point rounding), so this
    version does redundant work; it is kept as the baseline for comparison.

    Args:
        meshes: object exposing ``faces_packed()`` and ``verts_packed()``.
            Presumably an (F, 3) integer index tensor and a (V, 3) float
            tensor -- confirm against the caller.

    Returns:
        Tensor shaped like ``verts_packed()`` holding the normalized sums.
        Rows whose norm is below ``eps`` are divided by ``eps`` instead, so
        vertices not referenced by any face stay at zero.
    """
    faces = meshes.faces_packed()
    verts = meshes.verts_packed()
    accumulated = torch.zeros_like(verts)
    corners = verts[faces]  # corners[f, c] = position of corner c of face f

    # (corner, first edge target, second edge target), listed in the order
    # in which the contributions are accumulated.
    for here, first, second in ((1, 2, 0), (2, 0, 1), (0, 1, 2)):
        corner_normal = torch.cross(
            corners[:, first] - corners[:, here],
            corners[:, second] - corners[:, here],
            dim=1,
        )
        accumulated = accumulated.index_add(0, faces[:, here], corner_normal)

    return torch.nn.functional.normalize(accumulated, eps=1e-6, dim=1)

def compute_vertex_normals_wpalfi(meshes):
    """Vertex normals with the face normal computed only once per face.

    A single cross product per face (taken at corner 1) is added to all
    three vertices of that face, and the per-vertex sums are then normalized.

    Args:
        meshes: object exposing ``faces_packed()`` and ``verts_packed()``.
            Presumably an (F, 3) integer index tensor and a (V, 3) float
            tensor -- confirm against the caller.

    Returns:
        Tensor shaped like ``verts_packed()`` holding the normalized sums.
        Rows whose norm is below ``eps`` are divided by ``eps`` instead, so
        vertices not referenced by any face stay at zero.
    """
    faces = meshes.faces_packed()
    verts = meshes.verts_packed()
    accumulated = torch.zeros_like(verts)
    corners = verts[faces]  # corners[f, c] = position of corner c of face f

    # Both edge vectors start at corner 1 of each face.
    pivot = corners[:, 1]
    face_normals = torch.cross(
        corners[:, 2] - pivot,
        corners[:, 0] - pivot,
        dim=1,
    )

    # Add the same face normal, in place, to each of the face's vertices.
    for corner in range(3):
        accumulated.index_add_(0, faces[:, corner], face_normals)

    return torch.nn.functional.normalize(accumulated, eps=1e-6, dim=1)

# Benchmark driver: build a level-7 icosphere on the 'cuda' device, check that
# both implementations agree, then time each of them.
mesh = ico_sphere(7, 'cuda')

normals_original, normals_wpalfi = (
    compute_vertex_normals_original(mesh),
    compute_vertex_normals_wpalfi(mesh),
)

# Largest absolute component-wise difference between the two results.
maxdiff = torch.max(torch.abs(normals_wpalfi - normals_original)).item()
print(f"{maxdiff=}")


def _total_seconds(compute_normals):
    """Return the total time in seconds for 1000 calls on ``mesh``.

    Each call reduces the result with ``.sum().item()``, so a Python scalar
    is produced before the next iteration starts.
    """
    return timeit(lambda: compute_normals(mesh).sum().item(), number=1000)


# Total seconds over 1000 calls is numerically the mean milliseconds per
# call, which is why the values are printed with an "ms" label.
to = _total_seconds(compute_vertex_normals_original)
tw = _total_seconds(compute_vertex_normals_wpalfi)
print(f"original {to:.02f} ms")
print(f"wpalfi {tw:.02f} ms")

Output:

maxdiff=1.7881393432617188e-07
original 3.42 ms
wpalfi 2.37 ms

wpalfi avatar Jun 29 '21 15:06 wpalfi

This issue is stale because it has been open for 30 days with no activity. Remove the stale label or add a comment, or this issue will be closed in 5 days.

github-actions[bot] avatar Jul 30 '21 05:07 github-actions[bot]

Finally got some time to implement this :) Thank you for pointing it out!

kjchalup avatar Aug 24 '22 05:08 kjchalup