zipnerf-pytorch

Pose refinement

Open ilbash opened this issue 2 years ago • 8 comments

Hello, @SuLvXiangXin! Your track_linearize module stops gradients from flowing through it. If I remove that, NaNs appear. How can I add a pose refinement module and avoid the NaNs?

ilbash, Jul 03 '23 21:07

Lol, the NaNs were solved by changing this line

weights = torch.erf(1 / torch.sqrt(8 * stds[..., None] ** 2 * self.encoder.grid_sizes ** 2))

to

weights = torch.erf(1 / (8 ** 0.5 * (stds[..., None] * self.encoder.grid_sizes).abs().clamp_min(EPS)))

So, finally we can propagate gradients through track_linearize!
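A small sketch of why the original line produces NaN gradients, assuming some samples have a std of exactly zero (the `stds` and `grid_sizes` values below are made up for illustration). Since sqrt(8 s^2 g^2) = sqrt(8) |s g|, both lines give the same weights whenever s g != 0. At s = 0, however, the original evaluates 1/0 = inf and erf(inf) = 1, and the backward pass multiplies the zero derivative of erf by the infinite derivative of 1/y, so 0 * inf = NaN. The rewritten form clamps the denominator away from zero, so every factor in the chain rule stays finite. `EPS = 1e-8` follows the value suggested at the end of this thread.

```python
import torch

EPS = 1e-8  # small positive floor for the denominator

def weights_original(stds, grid_sizes):
    # 1 / sqrt(8 s^2 g^2): at s = 0 this is 1 / 0 = inf, and the backward pass
    # multiplies erf'(inf) = 0 by d(1/y)/dy = -inf, which yields NaN
    return torch.erf(1 / torch.sqrt(8 * stds[..., None] ** 2 * grid_sizes ** 2))

def weights_fixed(stds, grid_sizes):
    # Same value when s * g != 0, because sqrt(8 s^2 g^2) = sqrt(8) * |s * g|,
    # but |s * g| is clamped to >= EPS so the reciprocal stays finite
    return torch.erf(1 / (8 ** 0.5 * (stds[..., None] * grid_sizes).abs().clamp_min(EPS)))

grid_sizes = torch.tensor([16.0, 64.0, 256.0])  # hypothetical grid resolutions

def grad_of(fn, stds):
    # Gradient of the summed weights with respect to the per-sample stds
    stds = stds.clone().requires_grad_(True)
    fn(stds, grid_sizes).sum().backward()
    return stds.grad

stds = torch.tensor([0.0, 0.01, 0.1])
g_orig = grad_of(weights_original, stds)
g_fixed = grad_of(weights_fixed, stds)
print(g_orig)   # first entry is nan
print(g_fixed)  # all entries finite
```

The clamp only changes anything where |s g| < EPS; there, `clamp_min` passes a zero gradient instead of an undefined one.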

ilbash, Jul 04 '23 09:07

Great! If you manage to add pose refinement, please let me know.

SuLvXiangXin, Jul 04 '23 14:07

@SuLvXiangXin I got pose refinement working using the mip-NeRF 360 scene contraction implemented here. Now I am trying to analyse your contraction function, which still produces NaNs in the gradients. Also, why do you divide by x_mag_sq here? And I did not understand this line.

If you have any explanation, I would be happy to see it!

torch.autograd.detect_anomaly() output with your contraction function:

 File "/home/user/nerfstudio/nerfstudio/field_components/spatial_distortions.py", line 111, in contract
    std_contracted = torch.where(mask[..., 0], cov, det_13[..., 0] * cov)
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

ilbash, Jul 06 '23 15:07

@Ilyabasharov The first line is the same as equation 12 in the Zip-NeRF paper. The second line is the same as equation 14, where $|\det J|$ (the determinant of the contraction's Jacobian) can be calculated explicitly from its eigenvalues.
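A sketch of where these terms come from, assuming the contraction is the mip-NeRF 360 one (this derivation is a reconstruction, not copied from the paper). Writing r = ||x||, the r > 1 branch can be rearranged so that r^2 (i.e. x_mag_sq) appears in the denominator. The Jacobian of that branch has one radial eigenvalue 1/r^2 and two tangential eigenvalues (2r - 1)/r^2, so its determinant has a closed form:

$$
\mathcal{C}(\mathbf{x}) =
\begin{cases}
\mathbf{x}, & r \le 1 \\
\left(2 - \frac{1}{r}\right)\frac{\mathbf{x}}{r} = \frac{2r - 1}{r^{2}}\,\mathbf{x}, & r > 1
\end{cases}
\qquad r = \lVert \mathbf{x} \rVert
$$

$$
\mathbf{J} = \frac{1}{r^{2}}\,\mathbf{u}\mathbf{u}^{\top} + \frac{2r - 1}{r^{2}}\left(\mathbf{I} - \mathbf{u}\mathbf{u}^{\top}\right),
\qquad \mathbf{u} = \frac{\mathbf{x}}{r}
$$

$$
|\det \mathbf{J}| = \frac{1}{r^{2}}\left(\frac{2r - 1}{r^{2}}\right)^{2} = \frac{(2r - 1)^{2}}{r^{6}},
\qquad
|\det \mathbf{J}|^{1/3} = \left(\frac{(2r - 1)^{1/3}}{r}\right)^{2}
$$

The last expression is what `det_13` computes in the nerfstudio snippet later in this thread, and it is consistent with the traceback above, where the std outside the unit ball is scaled by `det_13`.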

SuLvXiangXin, Jul 10 '23 02:07

@SuLvXiangXin I've solved the NaN gradients in the contract function; see my PR to nerfstudio: https://github.com/nerfstudio-project/nerfstudio/pull/2242. The idea is:

...
# prevent negative root computations
clamped_mag = mag.clamp_min(1.0)
det_13 = (torch.pow(2 * clamped_mag - 1, 1 / 3) / clamped_mag) ** 2
...

Thank you for your help!
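A minimal, self-contained sketch of why the clamp helps, assuming a mip-NeRF 360 contraction of the means and an isotropic std scaled by `det_13` outside the unit ball. `contract_gaussians` and `input_grads` are hypothetical helpers, not the nerfstudio or zipnerf-pytorch API. The key point is that `torch.where` still backpropagates into the branch it discards: that branch receives a zero gradient, but 0 * nan = nan. Here the discarded branch contains the cube root of 2r - 1, which is NaN for r < 0.5 (and has an infinite derivative at r = 0.5). Clamping r to at least 1 keeps that branch finite without changing the selected outputs, because points with r <= 1 take the identity branch anyway.

```python
import torch

def contract_gaussians(means, stds, clamp=True):
    """Hypothetical helper: contract means with the mip-NeRF 360 contraction and
    scale the isotropic std by |det J|^(1/3) = ((2r - 1)^(1/3) / r)^2 for r > 1."""
    mag = means.norm(dim=-1, keepdim=True)  # r = ||x||, shape (N, 1)
    inside = mag <= 1.0                     # these points are left unchanged
    # The discarded torch.where branch is still differentiated, so it must be
    # finite everywhere; clamping r >= 1 keeps 2r - 1 >= 1 and 1 / r <= 1.
    r = mag.clamp_min(1.0) if clamp else mag
    contracted = (2 * r - 1) / r ** 2 * means        # (2 - 1/r) * x / r
    det_13 = (torch.pow(2 * r - 1, 1 / 3) / r) ** 2  # |det J|^(1/3)
    new_means = torch.where(inside, means, contracted)
    new_stds = torch.where(inside[..., 0], stds, det_13[..., 0] * stds)
    return new_means, new_stds


def input_grads(means, stds, clamp):
    # Gradients of the summed outputs with respect to both inputs
    m = means.clone().requires_grad_(True)
    s = stds.clone().requires_grad_(True)
    out_m, out_s = contract_gaussians(m, s, clamp=clamp)
    (out_m.sum() + out_s.sum()).backward()
    return m.grad, s.grad


means = torch.tensor([[0.3, 0.0, 0.0], [2.0, 0.0, 0.0]])  # r = 0.3 and r = 2.0
stds = torch.tensor([0.1, 0.1])

print(input_grads(means, stds, clamp=False))  # row 0 gets nan: pow(-0.4, 1/3)
print(input_grads(means, stds, clamp=True))   # all finite
```

The forward outputs are identical with and without the clamp; only the backward pass differs. Wrapping the backward call in `torch.autograd.detect_anomaly()` should point at the `torch.where` line, as in the traceback earlier in this thread.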

ilbash, Jul 18 '23 21:07

Great!

SuLvXiangXin, Jul 19 '23 16:07

EPS = ?

datasciritwik, Sep 29 '23 20:09

@datasciritwik A small value, like 1e-8.

ilbash, Oct 02 '23 12:10