
CUDA Raymarcher Impl

Open d4l3k opened this issue 3 years ago • 3 comments

🚀 Feature

Add a CUDA implementation for common raymarcher methods such as EmissionAbsorption, depth, etc.

I'm thinking about writing this myself, but wanted to make a tracking issue for it in case other folks are interested.

Motivation

Currently, rendering via raymarching requires a large grid_sample input, which consumes quite a bit of memory on the forward pass. For higher-resolution image renders with numerous camera viewpoints, this ends up using a huge amount of memory for just a single batch: a model with 15 camera viewpoints hits 24 GB of VRAM at batch size one.

If there were a CUDA implementation of the raymarcher, we could skip the sampling grid and instead compute the samples on the fly while computing the EmissionAbsorption cumprod.
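For reference, a minimal sketch of the emission-absorption weighting that such a fused kernel would have to reproduce, written in plain Python for a single ray (the function name and the use of a constant step size `delta` are illustrative assumptions, not pytorch3d API):

```python
import math

# Sketch of emission-absorption (EA) compositing along one ray.
# A fused CUDA kernel would evaluate density/color samples on the fly
# instead of reading a precomputed grid_sample output of the whole grid.

def ea_composite(densities, colors, delta=1.0):
    """Composite per-sample (density, color) pairs along a ray.

    alpha[i]  = 1 - exp(-density[i] * delta)
    weight[i] = alpha[i] * prod_{j < i} (1 - alpha[j])
    Returns (rendered_color, total_opacity).
    """
    transmittance = 1.0   # running prod of (1 - alpha) over earlier samples
    color_out = 0.0
    opacity = 0.0
    for density, color in zip(densities, colors):
        alpha = 1.0 - math.exp(-density * delta)
        weight = alpha * transmittance
        color_out += weight * color
        opacity += weight
        transmittance *= 1.0 - alpha
    return color_out, opacity
```

The cumprod over `(1 - alpha)` is the part that currently forces the full sampled grid to be materialized; done in a single kernel loop like this, only the running transmittance needs to be kept per ray.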

Pitch

https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/renderer/implicit/renderer.py#L168-L180

Implement the VolumeSampler and raymarchers in CUDA as a single fused PyTorch op.

Alternatively, we could implement the VolumeSampler and the raymarchers as two separate ops, which might provide a bit more flexibility at the trade-off of moderate memory usage.

This is pretty important when training directly from video, since you need multiple camera viewpoints to train. For any later projections we do need the full h x w output, so just sampling a smaller set of points won't work either.

d4l3k avatar Sep 11 '22 01:09 d4l3k

One other interesting opportunity with a custom raymarcher is that it could support subvoxels, potentially by making the sampling aware of per-voxel sizes/offsets.
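To make the subvoxel idea concrete, here is a hypothetical 1-D sketch where each voxel carries its own offset and scale; the function name, the adjustment rule, and the treatment of empty space are all illustrative assumptions, not pytorch3d API:

```python
# Hypothetical "subvoxel"-aware sampling in 1-D: voxel i's content is
# centered at i + 0.5 + offsets[i] and spans scales[i] units, so content
# can shrink or shift inside its cell instead of filling it uniformly.

def sample_subvoxel(values, offsets, scales, x):
    """Look up the value at continuous coordinate x in [0, len(values))."""
    n = len(values)
    i = min(max(int(x), 0), n - 1)           # containing voxel cell
    center = i + 0.5 + offsets[i]            # shifted subvoxel center
    half = 0.5 * scales[i]                   # shrunken half-width
    if center - half <= x <= center + half:  # inside the subvoxel?
        return values[i]
    return 0.0                               # empty space around it
```

A sampling-aware raymarcher could apply this kind of per-voxel adjustment while stepping along the ray, which grid_sample (uniform cells only) cannot express.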

d4l3k avatar Sep 11 '22 23:09 d4l3k

Chatted with @bottler @kjchalup about this -- sounds like there's interest in it and it seems useful, so I'm going to take a stab at implementing it. How it fits in with the rest of pytorch3d is still undecided.

d4l3k avatar Sep 12 '22 19:09 d4l3k

I figured out a better way of handling this on the memory side: you can compute a partial graph for each camera and thus prune the leaves of the graph.

grid = self.model(batch)
graph_grid = grid            # keep a handle attached to the model's graph
grid = grid.detach()         # cut the graph so each render is independent
grid.requires_grad = True    # per-camera gradients accumulate into grid.grad

for cam in cameras:
    loss = render_loss(grid, cam)
    # accumulates into grid.grad; the render graph is freed each iteration
    loss.backward()

# backpropagate the accumulated gradient through the model once
graph_grid.backward(grid.grad)

This is, AFAIK, equivalent to calling backward on the sum of the individual losses, but it's much more memory efficient. I was able to cut the memory usage from 24 GB to 13 GB, since the render-graph memory usage is now independent of the number of camera views rendered :)
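The equivalence claim follows from linearity of differentiation: d(sum_i L_i)/dg = sum_i dL_i/dg. A plain-Python stand-in for the snippet above, using made-up scalar losses L_i(g) = (g - t_i)^2 in place of render_loss:

```python
# Why accumulating per-camera gradients equals one backward on the summed
# loss: for L_i(g) = (g - t_i)**2, dL_i/dg = 2*(g - t_i), and the gradient
# of sum_i L_i is the sum of the per-loss gradients.

def loss_grad(g, t):
    return 2.0 * (g - t)  # derivative of (g - t)**2 w.r.t. g

g = 0.3
targets = [0.0, 1.0, 2.5]  # one hypothetical "camera" target each

# Per-camera backward passes, accumulating into a single .grad buffer
accumulated = 0.0
for t in targets:
    accumulated += loss_grad(g, t)

# Single backward on the summed loss: d/dg sum_i (g - t_i)**2
summed = 2.0 * (len(targets) * g - sum(targets))

assert abs(accumulated - summed) < 1e-12
```

The memory win comes from the loop structure, not the math: each per-camera graph is freed after its backward, while the summed-loss version keeps all of them alive at once.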

d4l3k avatar Sep 15 '22 07:09 d4l3k