taichi-nerfs
Autodiff runtime error
I want to change the kernel function used in the ray-marching step because I need gradients with respect to `xyzs`. I reimplemented it following the volume-rendering autograd function, but I get a runtime error.
`
class RayMarchingRenderer(torch.nn.Module):
    """Autograd wrapper around the Taichi ray-marching training kernel.

    ``forward`` runs ``raymarching_train_kernel`` to fill per-sample buffers
    for a batch of rays and trims them to the number of samples the kernel
    actually produced. ``backward`` calls the same kernel's Taichi ``.grad``
    counterpart to propagate the incoming per-sample gradients back to
    ``rays_o`` and ``rays_d``.
    """

    def __init__(self):
        super(RayMarchingRenderer, self).__init__()
        self._raymarching_rendering_kernel = raymarching_train_kernel

        class _module_function(torch.autograd.Function):

            @staticmethod
            def forward(
                ctx,
                rays_o,
                rays_d,
                hits_t,
                density_bitfield,
                cascades,
                scale,
                exp_step_factor,
                grid_size,
                max_samples
            ):
                """Run the ray-marching kernel and return its trimmed outputs.

                Returns:
                    ``(rays_a, xyzs, dirs, deltas, ts, total_samples)``.
                    ``rays_a`` is an int32 tensor of shape
                    ``(rays_o.shape[0], 3)``. The per-sample tensors are cut
                    to their first ``total_samples`` rows, where
                    ``total_samples`` is ``counter[0]`` after the kernel
                    has run.
                """
                # Per-ray random values. They are saved so the grad kernel
                # receives exactly the same noise as the forward kernel.
                noise = torch.rand_like(rays_o[:, 0])
                counter = torch.zeros(
                    2,
                    device=rays_o.device,
                    dtype=torch.int32
                )
                rays_a = torch.empty(
                    rays_o.shape[0], 3,
                    device=rays_o.device,
                    dtype=torch.int32,
                )
                # Per-sample buffers are sized for max_samples per ray and
                # trimmed after the kernel reports how many it wrote.
                xyzs = torch.empty(
                    rays_o.shape[0] * max_samples, 3,
                    device=rays_o.device,
                    dtype=torch_type,
                    requires_grad=True
                )
                dirs = torch.empty(
                    rays_o.shape[0] * max_samples, 3,
                    device=rays_o.device,
                    dtype=torch_type,
                    requires_grad=True
                )
                deltas = torch.empty(
                    rays_o.shape[0] * max_samples,
                    device=rays_o.device,
                    dtype=torch_type,
                )
                ts = torch.empty(
                    rays_o.shape[0] * max_samples,
                    device=rays_o.device,
                    dtype=torch_type,
                )
                raymarching_train_kernel(
                    rays_o,
                    rays_d,
                    hits_t,
                    density_bitfield,
                    noise,
                    counter,
                    rays_a,
                    xyzs,
                    dirs,
                    deltas,
                    ts,
                    cascades, grid_size, scale,
                    exp_step_factor, max_samples
                )
                # total samples for all rays
                total_samples = counter[0]
                # remove redundant output
                xyzs = xyzs[:total_samples]
                dirs = dirs[:total_samples]
                deltas = deltas[:total_samples]
                ts = ts[:total_samples]
                ctx.save_for_backward(
                    rays_o,
                    rays_d,
                    hits_t,
                    density_bitfield,
                    noise,
                    counter,
                    rays_a,
                    xyzs,
                    dirs,
                    deltas,
                    ts,
                )
                ctx.cascades = cascades
                ctx.grid_size = grid_size
                ctx.scale = scale
                ctx.exp_step_factor = exp_step_factor
                ctx.max_samples = max_samples
                return rays_a, xyzs, dirs, deltas, ts, total_samples

            @staticmethod
            def backward(
                ctx,
                dL_drays_a,
                dL_dxyzs,
                dL_ddirs,
                dL_ddeltas,
                dL_dts,
                dL_dtotal_samples
            ):
                """Propagate per-sample gradients back through the kernel.

                Returns exactly one entry per ``forward`` input (nine in
                total): the gradients for ``rays_o`` and ``rays_d``, and
                ``None`` for every other input. The incoming gradients for
                the integer outputs ``rays_a`` and ``total_samples`` are
                ignored.
                """
                (
                    rays_o,
                    rays_d,
                    hits_t,
                    density_bitfield,
                    noise,
                    counter,
                    rays_a,
                    xyzs,
                    dirs,
                    deltas,
                    ts,
                ) = ctx.saved_tensors

                def _alias_with_grad(tensor, grad):
                    # Return a new tensor object that shares *tensor*'s storage
                    # and carries *grad* (or zeros) as its .grad buffer.
                    # Assigning .grad to the alias leaves the caller's tensor
                    # untouched. A saved leaf input unpacks to the caller's
                    # own object, so setting rays_o.grad directly and then
                    # returning it would make autograd accumulate the same
                    # gradient twice.
                    alias = tensor.detach().requires_grad_(True)
                    alias.grad = (
                        torch.zeros_like(tensor) if grad is None
                        else grad.contiguous()
                    )
                    return alias

                # Inputs start from zero so the grad kernel can accumulate into
                # them; outputs carry the incoming upstream gradients.
                rays_o_g = _alias_with_grad(rays_o, None)
                rays_d_g = _alias_with_grad(rays_d, None)
                xyzs_g = _alias_with_grad(xyzs, dL_dxyzs)
                dirs_g = _alias_with_grad(dirs, dL_ddirs)
                deltas_g = _alias_with_grad(deltas, dL_ddeltas)
                ts_g = _alias_with_grad(ts, dL_dts)

                # NOTE(review): this relies on Taichi being able to
                # differentiate raymarching_train_kernel. The kernel body is
                # not visible here, so it is unverified whether it satisfies
                # Taichi's autodiff restrictions. It is also unverified how
                # the kernel treats the already-filled ``counter`` and the
                # trimmed per-sample buffers. If the grad kernel itself
                # fails, an analytic backward written in PyTorch is the
                # alternative.
                self._raymarching_rendering_kernel.grad(
                    rays_o_g,
                    rays_d_g,
                    hits_t,
                    density_bitfield,
                    noise,
                    counter,
                    rays_a,
                    xyzs_g,
                    dirs_g,
                    deltas_g,
                    ts_g,
                    ctx.cascades, ctx.grid_size, ctx.scale,
                    ctx.exp_step_factor, ctx.max_samples
                )
                # autograd requires one gradient per forward() input.
                return (
                    rays_o_g.grad,  # rays_o
                    rays_d_g.grad,  # rays_d
                    None,  # hits_t
                    None,  # density_bitfield
                    None,  # cascades
                    None,  # scale
                    None,  # exp_step_factor
                    None,  # grid_size
                    None,  # max_samples
                )

        self._module_function = _module_function.apply

    def forward(
        self,
        rays_o,
        rays_d,
        hits_t,
        density_bitfield,
        cascades,
        scale,
        exp_step_factor,
        grid_size,
        max_samples
    ):
        """Ray-march a batch of rays through the autograd-wrapped kernel.

        ``rays_o``, ``rays_d`` and ``hits_t`` are made contiguous before being
        handed to the kernel. Returns the same 6-tuple as the inner
        ``_module_function.forward``.
        """
        return self._module_function(
            rays_o.contiguous(),
            rays_d.contiguous(),
            hits_t.contiguous(),
            density_bitfield,
            cascades,
            scale,
            exp_step_factor,
            grid_size,
            max_samples
        )
`