output NaN when fine-tuning on my own dataset

Open · XinWu98 opened this issue on Jul 17, 2021 · 5 comments

Hi, thanks for releasing your work! I have some problems when fine-tuning on LLFF and on my own dataset.

  1. I wonder why you set spheric_poses=True for LLFF. Besides, when I change it to False (i.e., use NDC coordinates), train_mvs_nerf_finetuning_pl.py gives a poor initial rendering (shown in the figure below) and immediately reports the error below. It seems that the networks output NaN during training, and clipping the gradients (to avoid gradient explosion) brings no improvement. Do you have any idea how to solve this? Could there be a numerical error in the code?

[attached image 00000000_00: the poor initial rendering]

  2. I tried fine-tuning on my own dataset, which is sparsely sampled from a real scene and has a more complex trajectory than LLFF. It reports the same error as below at an early stage of fine-tuning. If this is not caused by a numerical error, does it mean your method is unsuitable for complex posed images of real scenes? In my understanding, though, such scenes should simply take as long to train as NeRF in this situation, rather than produce NaN during training, right?

  3. Do you have any advice on how to choose source views? For example, should they be very close neighbors or uniformly distributed around the scene? How much co-visibility between source views is appropriate for your method?
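For reference, the trace below was produced with PyTorch's anomaly detection enabled, roughly like this:

```python
import torch

# Enabled before trainer.fit(system) so autograd records the forward op that
# later produces a NaN gradient -- here PowBackward0, i.e. the "** 2" in img2mse.
torch.autograd.set_detect_anomaly(True)
```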

[W python_anomaly_mode.cpp:104] Warning: Error detected in PowBackward0. Traceback of forward call that caused the error:
  File "train_mvs_nerf_finetuning_pl.py", line 309, in <module>
    trainer.fit(system)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 460, in fit
    self._run(model)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 758, in _run
    self.dispatch()
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 799, in dispatch
    self.accelerator.start_training(self)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 96, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 144, in start_training
    self._results = trainer.run_stage()
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 809, in run_stage
    return self.run_train()
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 871, in run_train
    self.train_loop.run_training_epoch()
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 499, in run_training_epoch
    batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 738, in run_training_batch
    self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 434, in optimizer_step
    model_ref.optimizer_step(
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py", line 1403, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py", line 214, in step
    self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py", line 134, in __optimizer_step
    trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 329, in optimizer_step
    self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 336, in run_optimizer_step
    self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 193, in optimizer_step
    optimizer.step(closure=lambda_closure, **kwargs)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/torch/optim/lr_scheduler.py", line 67, in wrapper
    return wrapped(*args, **kwargs)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/torch/optim/adam.py", line 66, in step
    loss = closure()
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 732, in train_step_and_backward_closure
    result = self.training_step_and_backward(
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 823, in training_step_and_backward
    result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 290, in training_step
    training_step_output = self.trainer.accelerator.training_step(args)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 204, in training_step
    return self.training_type_plugin.training_step(*args)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 155, in training_step
    return self.lightning_module.training_step(*args, **kwargs)
  File "train_mvs_nerf_finetuning_pl.py", line 165, in training_step
    img_loss = img2mse(rgbs, rgbs_target)
  File "/home/wuxin/OriginDoc/PycharmProjects/mvsnerf-main/utils.py", line 10, in <lambda>
    img2mse = lambda x, y : torch.mean((x - y) ** 2)
 (function _print_stack)
Traceback (most recent call last):
  File "train_mvs_nerf_finetuning_pl.py", line 309, in <module>
    trainer.fit(system)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 460, in fit
    self._run(model)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 758, in _run
    self.dispatch()
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 799, in dispatch
    self.accelerator.start_training(self)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 96, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 144, in start_training
    self._results = trainer.run_stage()
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 809, in run_stage
    return self.run_train()
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 871, in run_train
    self.train_loop.run_training_epoch()
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 499, in run_training_epoch
    batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 738, in run_training_batch
    self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 434, in optimizer_step
    model_ref.optimizer_step(
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py", line 1403, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py", line 214, in step
    self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py", line 134, in __optimizer_step
    trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 329, in optimizer_step
    self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 336, in run_optimizer_step
    self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 193, in optimizer_step
    optimizer.step(closure=lambda_closure, **kwargs)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/torch/optim/lr_scheduler.py", line 67, in wrapper
    return wrapped(*args, **kwargs)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/torch/optim/adam.py", line 66, in step
    loss = closure()
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 732, in train_step_and_backward_closure
    result = self.training_step_and_backward(
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 836, in training_step_and_backward
    self.backward(result, optimizer, opt_idx)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 869, in backward
    result.closure_loss = self.trainer.accelerator.backward(
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 308, in backward
    output = self.precision_plugin.backward(
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 79, in backward
    model.backward(closure_loss, optimizer, opt_idx)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py", line 1275, in backward
    loss.backward(*args, **kwargs)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 130, in backward
    Variable._execution_engine.run_backward(
RuntimeError: Function 'PowBackward0' returned nan values in its 0th output.
Exception ignored in: <function tqdm.__del__ at 0x7fa58a1585e0>
Traceback (most recent call last):
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/tqdm/std.py", line 1122, in __del__
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/tqdm/std.py", line 1335, in close
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/tqdm/std.py", line 1514, in display
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/tqdm/std.py", line 1125, in __repr__
  File "/home/wuxin/anaconda3/lib/python3.8/site-packages/tqdm/std.py", line 1475, in format_dict
TypeError: cannot unpack non-iterable NoneType object

XinWu98 · Jul 17 '21

Hi XinWu, 1) currently the code does not support spheric_poses=False: the near-far bounds must align with the cost volume construction, and since the cost volume is built in real-world coordinates, you can't normalize the near-far bounds. 2) Are you using a new loader or something? NaN values are generally caused by bugs. 3) Very close neighboring views perform better. Thanks.

apchenstu · Jul 19 '21

Thanks for your reply! I will try the sampling strategy mentioned in 3). As for 2), I wrote a new Dataset class referring to LLFFDataset in llff.py and replaced its images, poses, and depth bounds with my own data (a skeleton is sketched below). I will keep checking for bugs or dirty data. However, my dataset is sampled from a video that records a real indoor scene from random viewpoints, rather than facing forward like LLFF. Many training views may not be covered by the cost volume built from 3 images, so some 3D coordinates cannot find a corresponding volume feature. Did you run experiments in such a situation? Do you think it may lead to the collapse of fine-tuning (e.g., NaN values)?
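Roughly, the loader looks like this (a simplified skeleton; the real class name and parsing code are specific to my capture):

```python
from torch.utils.data import Dataset

class MyVideoDataset(Dataset):
    """Simplified skeleton of my loader, modeled on LLFFDataset in llff.py.

    Only the structure matters: images, camera-to-world poses, and per-view
    near/far depth bounds are swapped for values parsed from my own video.
    """
    def __init__(self, root_dir, split='train'):
        self.split = split
        # dataset-specific parsing of frames, poses, and depth bounds
        self.imgs, self.poses, self.bounds = self._load_scene(root_dir)

    def _load_scene(self, root_dir):
        raise NotImplementedError  # parsing code omitted here

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        return {'rgbs': self.imgs[idx],        # (3, H, W) image tensor
                'c2w': self.poses[idx],        # (3, 4) camera-to-world pose
                'near_far': self.bounds[idx]}  # (2,) depth bounds for this view
```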

XinWu98 · Jul 19 '21

Hi, I met the same NaN problem as you... Have you found a solution?

Lemon-XQ · Sep 24 '21

@Lemon-XQ and @XinWu98 I might have found the problem. I experienced the same error as you when I trained a generalized MVSNeRF on my own dataset, and I found the following line to be the root of the problem:

https://github.com/apchenstu/mvsnerf/blob/1fdf6487389d0872dade614b3cea61f7b099406e/utils.py#L129

What happens here is that the sampling points are transformed into the reference camera's NDC coordinate system. However, when there are large angles between the source and target cameras, it can happen (at random) that some sampling points have z == 0 in the reference camera coordinate system. This causes a division by zero: the NDC coordinate goes to infinity, F.grid_sample then produces NaN raw outputs, and everything goes to hell.

As a workaround, I added a few lines just before https://github.com/apchenstu/mvsnerf/blob/1fdf6487389d0872dade614b3cea61f7b099406e/utils.py#L124 that clamp the divisor away from zero.
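Something along these lines (a sketch of the guard; the exact tensor names in utils.py may differ):

```python
import torch

def safe_perspective_divide(point_samples_pixel: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Clamp |z| >= eps before the x/z, y/z divide, so sampling points lying in
    the reference camera plane cannot produce inf/NaN NDC coordinates."""
    z = point_samples_pixel[..., -1:]                 # reference-frame depth, used as the divisor
    safe_z = torch.where(z.abs() < eps,
                         torch.full_like(z, eps), z)  # note: maps near-zero negative z to +eps
    out = point_samples_pixel.clone()
    out[..., :2] = out[..., :2] / safe_z
    return out
```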

This should prevent the division by zero.

malteprinzler · Jun 30 '22

btw, if you use your own dataset, make sure to change the defaults for near and far in https://github.com/apchenstu/mvsnerf/blob/1fdf6487389d0872dade614b3cea61f7b099406e/utils.py#L112

they are not automatically adjusted to your dataset's depth range!
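For example, something like this (hypothetical names; the arguments before near/far depend on the call site in your code):

```python
# Derive near/far from your own loader's depth bounds instead of relying on the
# defaults in utils.py; "scene_bounds" is a hypothetical (N_views, 2) array of
# per-view near/far values produced by your dataset class.
near = float(scene_bounds.min()) * 0.9   # small margin in front of the closest depth
far = float(scene_bounds.max()) * 1.1    # small margin behind the farthest depth

pts_ndc = get_ndc_coordinate(w2c_ref, intrinsic_ref, point_samples, inv_scale,
                             near=near, far=far)
```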

malteprinzler · Jun 30 '22