torch-ngp
Backpropagating through raymarching
Hi @ashawkey, I wish to backpropagate gradients to the camera poses, but this results in an error in _march_rays_train(), which does not have a backward pass. To make this work, do I need to write the backward pass in raymarching.py together with the corresponding backward function, or is there an easier way to reach that goal?
I am also interested in learning how to enable gradient flow so that rays_o and rays_d have requires_grad = True and the backward pass can be used to optimize the camera pose, as in iNeRF or, more recently, NVIDIA's paper on parallel inversion of NeRF.
@alex3dfan Hi, yes, you have to implement the backward for raymarching, so the gradient from xyzs and dirs can propagate to rays_o and rays_d, which finally get to your trainable camera poses.
I haven't been able to implement and test it recently, but you may check ngp_pl, where they implemented it.
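For reference, the gradient rule this implies can be written out. This is a sketch under stated assumptions, not a description of either repo's kernels: each sample on ray $r$ is taken to be $\mathbf{x}_i = \mathbf{o}_r + t_i \mathbf{d}_r$, its view direction $\mathbf{v}_i$ is taken to equal $\mathbf{d}_r$, and the sample distances $t_i$ are treated as constants. The symbols $t_i$, $\mathbf{v}_i$ and $S_r$ (the set of samples on ray $r$) are notation introduced here.

$$
\frac{\partial L}{\partial \mathbf{o}_r} = \sum_{i \in S_r} \frac{\partial L}{\partial \mathbf{x}_i},
\qquad
\frac{\partial L}{\partial \mathbf{d}_r} = \sum_{i \in S_r} \left( t_i \, \frac{\partial L}{\partial \mathbf{x}_i} + \frac{\partial L}{\partial \mathbf{v}_i} \right)
$$

Treating $t_i$ as constant ignores any dependence of the sample positions on the ray itself (for example through nears and fars), so this should be read as an approximation.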
Hi @ashawkey,
I am writing the backward pass for raymarching following ngp_pl. I am wondering whether deltas[:, 1] in your repo corresponds to ts in the ngp_pl repo. I quickly compared the two CUDA implementations and they look like the same thing to me; can you help confirm this? If so, I think it is pretty straightforward to implement the backward pass.
Thanks for this open-source contribution!
Best, Shengyu
This is my current implementation:
import torch
from torch.autograd import Function
from torch_scatter import segment_csr
from einops import rearrange

# `_backend` is the compiled raymarching extension, imported as in torch-ngp's raymarching.py

class _march_rays_train(Function):
    @staticmethod
    def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, step_counter=None, mean_count=-1,
                perturb=False, align=-1, force_all_rays=False, dt_gamma=0, max_steps=1024):
        ''' march rays to generate points (forward only)
        Args:
            rays_o/d: float, [N, 3]
            bound: float, scalar
            density_bitfield: uint8: [C * H * H * H // 8]
            C: int
            H: int
            nears/fars: float, [N]
            step_counter: int32, (2), used to count the actual number of generated points.
            mean_count: int32, estimated mean steps to accelerate training. (but will randomly drop rays if the actual point count exceeded this threshold.)
            perturb: bool
            align: int, pad output so its size is dividable by align, set to -1 to disable.
            force_all_rays: bool, ignore step_counter and mean_count, always calculate all rays. Useful if rendering the whole image, instead of some rays.
            dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance)
            max_steps: int, max number of sampled points along each ray, also affect min_stepsize.
        Returns:
            xyzs: float, [M, 3], all generated points' coords. (all rays concated, need to use `rays` to extract points belonging to each ray)
            dirs: float, [M, 3], all generated points' view dirs.
            deltas: float, [M, 2], all generated points' deltas. (first for RGB, second for Depth)
            rays: int32, [N, 3], all rays' (index, point_offset, point_count), e.g., xyzs[rays[i, 1]:rays[i, 1] + rays[i, 2]] --> points belonging to rays[i, 0]
        '''

        if not rays_o.is_cuda:
            rays_o = rays_o.cuda()
        if not rays_d.is_cuda:
            rays_d = rays_d.cuda()
        if not density_bitfield.is_cuda:
            density_bitfield = density_bitfield.cuda()

        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)
        density_bitfield = density_bitfield.contiguous()

        N = rays_o.shape[0] # num rays
        M = N * max_steps # init max points number in total

        # running average based on previous epoch (mimic `measured_batch_size_before_compaction` in instant-ngp)
        # It estimate the max points number to enable faster training, but will lead to random ignored rays if underestimated.
        if not force_all_rays and mean_count > 0:
            if align > 0:
                mean_count += align - mean_count % align
            M = mean_count

        xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
        dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
        deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device)
        rays = torch.empty(N, 3, dtype=torch.int32, device=rays_o.device) # id, offset, num_steps

        if step_counter is None:
            step_counter = torch.zeros(2, dtype=torch.int32, device=rays_o.device) # point counter, ray counter

        if perturb:
            noises = torch.rand(N, dtype=rays_o.dtype, device=rays_o.device)
        else:
            noises = torch.zeros(N, dtype=rays_o.dtype, device=rays_o.device)

        _backend.march_rays_train(rays_o, rays_d, density_bitfield, bound, dt_gamma, max_steps, N, C, H, M, nears, fars,
                                  xyzs, dirs, deltas, rays, step_counter,
                                  noises) # m is the actually used points number

        # print(step_counter, M)

        # only used at the first (few) epochs.
        if force_all_rays or mean_count <= 0:
            m = step_counter[0].item() # D2H copy
            if align > 0:
                m += align - m % align
            xyzs = xyzs[:m]
            dirs = dirs[:m]
            deltas = deltas[:m]

            # torch.cuda.empty_cache()

        # keep per-ray (index, offset, count) and the per-point values used as `ts` in backward
        ctx.save_for_backward(rays.long(), deltas[:, 1])

        return xyzs, dirs, deltas, rays

    # we follow the implementation of ngp_pl
    @staticmethod
    def backward(ctx, dL_dxyzs, dL_ddirs,
                 dL_ddeltas, dL_drays_a):
        rays_a, ts = ctx.saved_tensors
        # CSR-style segment boundaries: every ray's start offset, plus the end of the last ray
        segments = torch.cat([rays_a[:, 1], rays_a[-1:, 1] + rays_a[-1:, 2]])
        dL_drays_o = segment_csr(dL_dxyzs, segments)
        dL_drays_d = segment_csr(dL_dxyzs * rearrange(ts, 'n -> n 1') + dL_ddirs, segments)

        # one gradient per forward input: rays_o, rays_d, then None for the other 13 arguments
        return dL_drays_o, dL_drays_d, None, None, None, None, None, None, None, None, None, None, None, None, None
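One way to sanity-check the rule used in this backward without the CUDA kernel is to compare it against autograd on a toy ray set. The sketch below is only that, and rests on assumptions: the forward is re-expressed as xyzs = rays_o + ts * rays_d with dirs = rays_d, ts is taken to be the true distance of each sample along its ray (whether deltas[:, 1] holds that value is the question raised above), each ray's samples are stored contiguously in ray order, and index_add_ is swapped in for segment_csr so that torch_scatter is not needed. The toy sizes and the names ray_id, g_xyzs and g_dirs are made up for illustration.

```python
import torch

torch.manual_seed(0)

counts = torch.tensor([2, 3, 1])  # samples per ray (hypothetical toy values)
N, M = counts.numel(), int(counts.sum())
# Owning ray of each sample, assuming samples are stored contiguously per ray.
ray_id = torch.repeat_interleave(torch.arange(N), counts)

rays_o = torch.randn(N, 3, dtype=torch.float64, requires_grad=True)
rays_d = torch.randn(N, 3, dtype=torch.float64, requires_grad=True)
ts = torch.rand(M, dtype=torch.float64)  # sample distances, treated as constants

# Differentiable stand-in for what the marcher outputs.
xyzs = rays_o[ray_id] + ts[:, None] * rays_d[ray_id]
dirs = rays_d[ray_id]

# Arbitrary upstream gradients standing in for the rest of the pipeline.
g_xyzs = torch.randn(M, 3, dtype=torch.float64)
g_dirs = torch.randn(M, 3, dtype=torch.float64)
auto_o, auto_d = torch.autograd.grad([xyzs, dirs], [rays_o, rays_d], [g_xyzs, g_dirs])

# Manual rule: per-ray sums, mirroring the two segment_csr calls.
manual_o = torch.zeros(N, 3, dtype=torch.float64).index_add_(0, ray_id, g_xyzs)
manual_d = torch.zeros(N, 3, dtype=torch.float64).index_add_(
    0, ray_id, g_xyzs * ts[:, None] + g_dirs
)

print(torch.allclose(auto_o, manual_o), torch.allclose(auto_d, manual_d))
```

If the two comparisons agree on the toy data, the remaining things to verify on the real kernel are the assumptions themselves: what deltas[:, 1] actually stores, and whether the offsets in rays are sorted the way the segment boundaries require.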
@ShengyuH Hi Shengyu, I wonder whether the gradient works properly for camera optimization with the code you provided above? I am trying to do something similar, but I still encounter NotImplementedError: You must implement either the backward or vjp method for your custom autograd.Function to use it with backward mode AD.
I'm not sure if the issue is with this function. Looking forward to your reply!
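That message comes from the default backward of torch.autograd.Function, so it points at some custom Function on the path from the loss to the poses that has no backward of its own. This may be a different Function from the one above, for example another op that also receives rays_o and rays_d, or the unmodified class still being the one that is called. A toy reproduction, with hypothetical class names that are not part of torch-ngp:

```python
import torch
from torch.autograd import Function


class ScaleNoBackward(Function):
    """Hypothetical toy op that defines forward only."""

    @staticmethod
    def forward(ctx, x):
        return x * 2.0


class ScaleWithBackward(Function):
    """Same toy op with backward defined as a static method."""

    @staticmethod
    def forward(ctx, x):
        return x * 2.0

    @staticmethod
    def backward(ctx, grad_out):
        # d(2x)/dx = 2; return one gradient per forward input.
        return grad_out * 2.0


def try_backward(fn):
    """Run forward and backward through fn; return the grad or the error name."""
    x = torch.ones(3, requires_grad=True)
    try:
        fn.apply(x).sum().backward()
        return x.grad
    except NotImplementedError as err:
        return type(err).__name__


print(try_backward(ScaleNoBackward))
print(try_backward(ScaleWithBackward))
```

Running each custom Function in the pipeline through a check like this, with inputs that require grad, is one way to narrow down which of them triggers the error.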