Fix inconsistent training results with RGBA/PNG images
Issue summary
The training relies on PIL to resize the input images and extracts the resized alpha to mask the rendered image during training. Since PIL pre-multiplies the resized RGB with the resized alpha, the training produces different Gaussian points depending on whether the input gets resized or not. Moreover, the alpha channel extracted from PIL is not perfectly binarized, causing floaters around the edges. The issue has come up in #1039, #1121, and #1114, where users trained with either PNG images or a dataset that stores masks in the 4th channel (preprocessed DTU, NeRF Synthetic).
The fix is self-contained in the `PILtoTorch` function. It checks whether the input is RGBA and manually masks the RGB channels with the alpha. The alpha channel is then discarded and the process continues as if the input were RGB, making the alpha multiplication step in the train script a no-op.
Details
In the current commit, here's how a ground truth RGBA image is treated during training:

- The loaded image is resized to the target resolution with `PIL.Image.Image.resize` in the `PILtoTorch` function.
- The RGB channels of the result are extracted and become `gt_image` in `train.py` (via `Camera.original_image`).
- The resized alpha channel is saved separately as `alpha_mask`. This mask is then multiplied with the rendered `image` in `train.py`, and the loss is computed on `gt_image` and the masked `image`.
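For context, here is a minimal sketch of that flow (illustrative only, simplified from the actual `PILtoTorch`/`train.py` code; the helper name below is made up):

```python
import numpy as np
import torch
from PIL import Image


def pil_to_torch_current(pil_image: Image.Image, resolution):
    """Roughly what PILtoTorch does today: resize with PIL, return a CHW float tensor."""
    resized = pil_image.resize(resolution)                       # PIL resize; see the RGBA note below
    tensor = torch.from_numpy(np.array(resized)).float() / 255.0
    return tensor.permute(2, 0, 1)                               # CHW; 4 channels for RGBA input

# Downstream (simplified), the channels are split and only the render is masked:
#   gt_image   = tensor[:3]        # becomes Camera.original_image
#   alpha_mask = tensor[3:4]       # kept separately on the Camera
#   loss = l1_loss(render * alpha_mask, gt_image)   # the GT itself is never masked here
```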
If the input RGBA is actually resized in `PILtoTorch` (i.e. the resolution parameter differs from the image's own resolution), PIL automatically multiplies the resized RGB with the resized alpha, as in the comparison below; a small standalone check follows the table:
| RGB before resize | RGB after resize |
|---|---|
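The check below uses only Pillow and NumPy: it resizes a constant-red RGBA image with a hard alpha edge and reports whether the red channel stays constant.

```python
import numpy as np
from PIL import Image

# Constant-red RGBA image with a hard alpha edge: opaque left half, transparent right half.
rgba = np.zeros((64, 64, 4), dtype=np.uint8)
rgba[..., 0] = 255
rgba[:, :32, 3] = 255
img = Image.fromarray(rgba, mode="RGBA")

resized = np.array(img.resize((32, 32)))
red = resized[..., 0][resized[..., 3] > 0]
print("red min/max where alpha > 0:", red.min(), red.max())
# If the RGB were left untouched by the resize, red would stay 255 everywhere;
# values below 255 near the alpha edge mean the RGB was multiplied with the alpha.
```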
This creates two different scenarios:
- If the image is not resized (no `-r` flag), the RGB ground truth is the original image without masking, and the saved `alpha_mask` is perfectly binarized.
- If the image is resized, the RGB ground truth is masked, but the saved `alpha_mask` is distorted along the edges.
Scenario 1: no resize
The Gaussian points are under tension during training: the rendered image is masked before being fed into the loss, while the ground truth is the original, unmasked image:
| GT | Render (Iter 7000) |
|---|---|
Scenario 2: RGBA is resized
The resized `alpha_mask` is not perfectly binarized along the edges due to interpolation. This imperfect mask is multiplied with the rendered image, causing floaters; a small standalone check follows the table:
| GT | Render -r 2 (Iter 7000) |
|---|---|
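A quick standalone check (Pillow and NumPy only) of how binary the alpha stays after a `-r 2`-style downscale:

```python
import numpy as np
from PIL import Image

# Binary alpha mask (an opaque square on a transparent background) resized by a factor of 2.
rgba = np.zeros((64, 64, 4), dtype=np.uint8)
rgba[16:48, 16:48, 3] = 255
img = Image.fromarray(rgba, mode="RGBA")

alpha = np.array(img.resize((32, 32)))[..., 3]
non_binary = np.logical_and(alpha > 0, alpha < 255)
print("fraction of non-binary alpha values after resize:", non_binary.mean())
# Interpolation produces intermediate values along the mask boundary, so the mask
# that gets multiplied with the rendered image is no longer strictly 0/1.
```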
The fix
To minimize the modification, when `PILtoTorch` encounters an RGBA input, we manually extract the RGB channels and mask them with the alpha, and let the input become this new masked RGB. The remaining logic stays as-is, and the alpha multiplication step in `train.py` becomes a no-op. A sketch of the change is shown below.
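A minimal sketch of the change (the RGBA branch is the new part; the rest mirrors the existing `PILtoTorch` approximately, so treat this as illustrative rather than the exact diff):

```python
import numpy as np
import torch
from PIL import Image


def PILtoTorch(pil_image, resolution):
    # New part: if the input is RGBA, mask the RGB channels with the original
    # (still binary) alpha, then continue as if the input were plain RGB.
    if pil_image.mode == "RGBA":
        rgba = np.array(pil_image).astype(np.float32) / 255.0
        rgb = rgba[..., :3] * rgba[..., 3:4]
        pil_image = Image.fromarray(np.rint(rgb * 255.0).astype(np.uint8), mode="RGB")

    # Existing logic, reproduced approximately: plain resize and CHW conversion.
    resized_image_PIL = pil_image.resize(resolution)
    resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
    if len(resized_image.shape) == 3:
        return resized_image.permute(2, 0, 1)
    return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
```

Since a 3-channel tensor is returned, the downstream code finds no alpha channel and the `alpha_mask` defaults to all ones, which is why the multiplication in `train.py` no longer has any effect.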
| Render (Iter 7000) | Render -r 2 (Iter 7000) |
|---|---|
Test environment
- Python `3.9`
- PyTorch `2.4.0`
- CUDA `12.4`
- MSVC `19.43`
Notes
`render.py` might need a fix to export the masked GT (rather than the original RGB) when running on a model trained at the original resolution (no `-r`).
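A hedged sketch of what that could look like, as a small helper applied before the GT is saved; the attribute names (`original_image`, `alpha_mask`) are assumed to match those used on the camera during training:

```python
from typing import Optional

import torch


def masked_gt(original_image: torch.Tensor, alpha_mask: Optional[torch.Tensor]) -> torch.Tensor:
    """Return the ground truth as the training loss effectively saw it."""
    gt = original_image[:3]
    if alpha_mask is not None:
        gt = gt * alpha_mask
    return gt
```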