Slow training with large patch_size=64 (nerfacto; poster)

Open sangminkim-99 opened this issue 2 years ago • 5 comments

Describe the bug The training procedure does not go well with nerfacto on the poster dataset when I set the patch_size=64. It works with smaller patch_size like 8 and works fine with the PixelSampler(patch_size=1).

To Reproduce Steps to reproduce the behavior:

  1. Set patch_size=64.
  2. Set train_num_rays_per_batch to 4096 or larger, so that at least one full 64x64 patch (64 × 64 = 4096 rays) fits in a batch.
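
As a rough check on step 2, the sketch below counts how many whole patches fit in one batch. It assumes a patch sampler that draws `num_rays // patch_size**2` square patches per batch; the helper name `patches_per_batch` is hypothetical and is not part of nerfstudio.

```python
def patches_per_batch(num_rays: int, patch_size: int) -> int:
    """Whole patch_size x patch_size patches that fit in num_rays rays.

    Assumed rule: the sampler rounds down and drops any leftover rays.
    """
    return num_rays // (patch_size * patch_size)


# Compare the three patch sizes from the report at a 4096-ray batch.
for patch_size in (1, 8, 64):
    print(patch_size, patches_per_batch(4096, patch_size))
```

Under this assumption, a batch smaller than 4096 rays would hold no complete 64x64 patch at all, which is why the ray count has to be raised.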

Expected behavior As NeRF does not care about the order of input rays, it should be okay with a large patch size. To the best of my knowledge, instant-ngp even uses a whole image at a time. Do you have any idea what causes this behavior?

Screenshots

| patch_size | Screenshot |
|---|---|
| 1 | Screenshot from 2023-07-25 19-18-46 |
| 8 | Screenshot from 2023-07-25 19-16-45 |
| 64 | Screenshot from 2023-07-25 19-14-57 |

sangminkim-99 avatar Jul 25 '23 10:07 sangminkim-99

Hi! What you're noticing is something pretty fundamental about the way NeRFs optimize: as you increase the patch size, the number of "effective" rays drops sharply, since there's a lot of redundancy among the samples within a patch. For example, a patch size of 8 divides your effective batch size by 8 × 8 = 64, so you can think of that configuration as training with 4096/64 = 64 rays per batch. This is really low! The blurriness you're seeing comes from training with so few effective rays, which hurts convergence and stability. Some things that might help you:

  1. Try increasing the number of gradient accumulation steps to bring the effective batch size back up to 4096 (e.g., for patch size 8, use 64 gradient accumulation steps). This will slow convergence in terms of PSNR per unit time, but it will probably resolve the blurriness you're seeing after training. You probably don't have to go as high as 64 accumulation steps; I would play around with the parameters until you find a good balance of speed, patch size, and quality.
  2. Try increasing the batch size a lot (this is probably hard memory-wise).
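
The arithmetic in suggestion 1 can be sketched as follows. This is a toy illustration, not nerfstudio code: it treats every patch as one "effective" ray, as in the heuristic above, and uses a 1-D least-squares model to show that averaging gradients over several equal-sized micro-batches reproduces the gradient of one larger batch. All function names are hypothetical.

```python
import math


def effective_rays(num_rays: int, patch_size: int) -> int:
    """Heuristic from the comment above: each patch counts as one effective ray."""
    return num_rays // (patch_size * patch_size)


def accumulation_steps(target_rays: int, num_rays: int, patch_size: int) -> int:
    """Micro-batches needed to reach target_rays effective rays.

    Assumes num_rays >= patch_size**2, so at least one patch fits per batch.
    """
    return math.ceil(target_rays / effective_rays(num_rays, patch_size))


def grad_mse(w, xs, ys):
    """Gradient of the mean squared error of the toy model y = w * x."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, steps):
    """Average the gradients of `steps` equal-sized micro-batches."""
    size = len(xs) // steps
    return sum(
        grad_mse(w, xs[i * size:(i + 1) * size], ys[i * size:(i + 1) * size])
        for i in range(steps)
    ) / steps


print(effective_rays(4096, 8))            # effective rays per batch
print(accumulation_steps(4096, 4096, 8))  # steps needed to get back to 4096

# Toy data: the accumulated gradient matches the full-batch gradient.
xs = [float(i) for i in range(1, 9)]
ys = [3.0 * x for x in xs]
print(grad_mse(0.5, xs, ys), accumulated_grad(0.5, xs, ys, steps=4))
```

Accumulation keeps the memory per step constant but multiplies the time per optimizer update, which is the speed trade-off described in suggestion 1.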

kerrj avatar Aug 01 '23 16:08 kerrj

Hi @kerrj, thanks for the explanation and the solutions you provided.

I understand the reason behind the slow convergence with larger patch sizes and the effectiveness of the gradient accumulation method.

However, I'm still curious about the impact of ray sampling strategy on the optimization process. As far as I know, patch sampling is a way of re-ordering the sampled rays. In the extreme case, it would be equivalent to using one-image-at-a-time, where the patch size is maximum.

I've noticed that torch-ngp and instant-ngp use this strategy effectively, which makes me wonder if there might be some other underlying issues causing the problem with nerfacto on the poster dataset.

I'd appreciate any further insights you can provide on this matter.

Thank you!

sangminkim-99 avatar Aug 01 '23 17:08 sangminkim-99

Hi @sangminkim-99, I also notice the same strange training result with a large patch size config. But I'm curious whether the same behavior appears in torch-ngp (I didn't find a patch size setting in instant-ngp). Have you run any experiments on torch-ngp? I'd appreciate any shared experiment results. I am not sure if it's because of mistakes in my own code :/

lbh666 avatar Jan 21 '24 07:01 lbh666

Hi @lbh666,

Upon further review, I've realized that I misunderstood the sampling process in torch-ngp. Specifically, in torch-ngp, rays are sampled from a single image, but it's important to note that these samples originate from random pixel locations within that image.

I appreciate @kerrj's solution, which seems to be a solid starting point. Additionally, you might consider incorporating various loss functions designed for patches, such as SSIM and LPIPS. These functions can provide different gradient values for neighboring pixels, potentially mitigating the issue of chaotic gradient accumulation. At least in TensoRF, I found that SSIM loss can improve the reconstruction quality with 8x8 patches.
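
To make the distinction between the two strategies concrete, here is a minimal sketch that draws rays from a single image in two ways: as independent random pixels, and as the same number of rays grouped into contiguous square patches. It uses only the standard library and is not the sampler code of torch-ngp or nerfstudio; the function names and the 512x512 image size are illustrative assumptions.

```python
import random


def sample_random_pixels(height, width, num_rays, rng):
    """Independent uniform pixel locations (row, col), drawn with replacement."""
    return [(rng.randrange(height), rng.randrange(width)) for _ in range(num_rays)]


def sample_patch_pixels(height, width, num_rays, patch_size, rng):
    """Spend the same ray budget on contiguous patch_size x patch_size patches."""
    pixels = []
    for _ in range(num_rays // (patch_size * patch_size)):
        # Pick a top-left corner so the whole patch stays inside the image.
        top = rng.randrange(height - patch_size + 1)
        left = rng.randrange(width - patch_size + 1)
        pixels.extend(
            (top + dy, left + dx)
            for dy in range(patch_size)
            for dx in range(patch_size)
        )
    return pixels


def distinct_rows(pixels):
    """Number of different image rows touched by a set of pixels."""
    return len({row for row, _ in pixels})


rng = random.Random(0)
random_px = sample_random_pixels(512, 512, 4096, rng)
patch_px = sample_patch_pixels(512, 512, 4096, 64, rng)
print(len(random_px), distinct_rows(random_px))
print(len(patch_px), distinct_rows(patch_px))
```

Both calls return 4096 rays, but the single 64x64 patch covers only 64 image rows, while the random sample spreads across most of the image. That difference in coverage is what separates per-image random sampling from large-patch sampling.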

sangminkim-99 avatar Jan 22 '24 00:01 sangminkim-99

Oh, thanks for your timely and brilliant advice. I will try it later : )

lbh666 avatar Jan 22 '24 07:01 lbh666