pytorch3d

SoftRasterizer enhancement to learn 3D from RGB only

Open • monniert opened this issue 2 years ago • 2 comments

🚀 Feature

Add the layered blending function from https://arxiv.org/abs/2204.10310 so that SoftRasterizer can be trained with an RGB loss only.

Motivation

SoftRasterizer does not work without silhouette supervision (see #359, #507, #839, #840, #1004). Appendix A of the paper analyses why and presents a simple modification to SoftRas based on a new layered aggregation function. TL;DR: the original softmax-like blending function annihilates the gradients with respect to the opacity maps, and therefore with respect to the vertex positions, because the opacity maps appear in both the numerator and the denominator of the softmax. The resulting Layered SoftRasterizer can successfully learn from an RGB loss only.
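To make the TL;DR concrete, here is a simplified single-face illustration (my own, not the derivation from Appendix A). For a pixel covered by exactly one face, it mirrors, as I understand it, the softmax aggregation in PyTorch3D's `softmax_rgb_blend` (default `gamma=1e-4`, background weight clamped below at `eps=1e-10`), and compares it with plain "over" compositing of that face onto the background, which is an assumed stand-in for the layered aggregation. The function names and the depth value `z_inv=0.5` are hypothetical.

```python
import math


def softmax_color_one_face(D, z_inv, c_face, c_bg, gamma=1e-4, eps=1e-10):
    """Softmax-style (SoftRas) color of a pixel covered by a single face.

    D     : soft opacity of the face at this pixel, in (0, 1]
    z_inv : normalized inverse depth of the face, in [0, 1] (larger = closer)
    The background weight `delta` is clamped below by `eps` (assumed to
    mirror PyTorch3D's softmax_rgb_blend).
    """
    # With one face, z_inv is also the per-pixel maximum, so the face's depth
    # factor exp((z_inv - z_max) / gamma) is exactly 1 and its weight is just D.
    delta = max(math.exp((eps - z_inv) / gamma), eps)
    return (D * c_face + delta * c_bg) / (D + delta)


def softmax_grad_wrt_D(D, z_inv, c_face, c_bg, gamma=1e-4, eps=1e-10):
    # Quotient rule on (D*c_f + delta*c_b) / (D + delta):
    # d/dD = delta * (c_f - c_b) / (D + delta)^2, which vanishes as delta -> 0.
    delta = max(math.exp((eps - z_inv) / gamma), eps)
    return delta * (c_face - c_bg) / (D + delta) ** 2


def layered_color_one_face(D, c_face, c_bg):
    # "Over" compositing of one face layer on top of the background.
    return D * c_face + (1.0 - D) * c_bg


def layered_grad_wrt_D(D, c_face, c_bg):
    # Linear in D, so the gradient is c_f - c_b whatever the value of D.
    return c_face - c_bg


# One grayscale pixel: white face (1.0) over a black background (0.0).
for D in (0.1, 0.5, 0.9):
    print(f"D={D}: softmax color={softmax_color_one_face(D, 0.5, 1.0, 0.0):.6f} "
          f"grad={softmax_grad_wrt_D(D, 0.5, 1.0, 0.0):.1e} | "
          f"layered color={layered_color_one_face(D, 1.0, 0.0):.2f} "
          f"grad={layered_grad_wrt_D(D, 1.0, 0.0):.1f}")
```

Because `D` appears in both the numerator and the denominator, the softmax color stays at about 1.0 for every opacity and its gradient is tiny, so an RGB loss cannot move the face boundary; the composited color tracks `D` directly.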

Pitch

A version of this layered blending function is implemented here. Like the original softmax_rgb_blend, this layered_rgb_blend function could be placed in pytorch3d/renderer/blending.py and selected by the different shaders through a dedicated argument.
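A minimal sketch of what such a layered aggregation could look like, assuming it amounts to standard front-to-back "over" compositing of the K faces stored per pixel. This is my illustration, not the implementation linked above, and `layered_rgb_blend_sketch` is a hypothetical name. Unlike PyTorch3D's blend functions, it takes raw tensors (shaped like the `pix_to_face` and `dists` fields of a `Fragments` object) so it can be tried without a renderer.

```python
import torch


def layered_rgb_blend_sketch(colors, pix_to_face, dists, sigma=1e-4,
                             background=(1.0, 1.0, 1.0)):
    """Front-to-back "over" compositing of the K faces stored per pixel.

    colors      : (N, H, W, K, 3) shaded color of each face candidate
    pix_to_face : (N, H, W, K) face index per slot, -1 for empty slots
    dists       : (N, H, W, K) signed 2D pixel-to-face distance (negative inside)
    Assumes the K slots are sorted front to back, as in PyTorch3D's Fragments.
    Returns (N, H, W, 4) RGBA.
    """
    K = pix_to_face.shape[-1]
    mask = (pix_to_face >= 0).to(colors.dtype)
    # Soft opacity of each face layer (SoftRas-style sigmoid); empty slots get 0.
    alpha = torch.sigmoid(-dists / sigma) * mask
    # Transmittance T_k = prod_{l<k} (1 - alpha_l); T_0 = 1, T_K reaches the background.
    ones = torch.ones_like(alpha[..., :1])
    trans = torch.cumprod(torch.cat([ones, 1.0 - alpha], dim=-1), dim=-1)  # (N, H, W, K+1)
    weights = trans[..., :K] * alpha  # (N, H, W, K), no softmax normalization
    bg = torch.as_tensor(background, dtype=colors.dtype, device=colors.device)
    rgb = (weights[..., None] * colors).sum(dim=-2) + trans[..., K:] * bg
    coverage = 1.0 - trans[..., K:]  # alpha channel, (N, H, W, 1)
    return torch.cat([rgb, coverage], dim=-1)


# One pixel, two half-transparent faces (red in front of green) over blue.
colors = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]).view(1, 1, 1, 2, 3)
pix_to_face = torch.tensor([0, 1]).view(1, 1, 1, 2)
dists = torch.zeros(1, 1, 1, 2, requires_grad=True)  # pixel on both edges: alpha = 0.5
out = layered_rgb_blend_sketch(colors, pix_to_face, dists, background=(0.0, 0.0, 1.0))
print(out)  # RGBA = (0.5, 0.25, 0.25, 0.75)
out[..., 1].sum().backward()  # green channel depends on both opacities
print(dists.grad)  # both faces receive a non-zero gradient
```

The design point is that each opacity enters the output through its own weight and through the transmittance of the layers behind it, with no normalizing denominator that could cancel it out.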

Using the resulting shader in the fit_textured_mesh.ipynb tutorial makes it possible to learn 3D meshes from RGB only. This zip file contains the ipynb and HTML versions of the tutorial, showing that:

1. SoftRas with a silhouette loss works (silhouette.png)
2. SoftRas with silhouette + RGB losses works (silhouette_rgb.png)
3. SoftRas with an RGB loss only diverges (rgb.png)
4. Layered SoftRas with an RGB loss only works (rgb_layered.png)

monniert, Jul 26 '22

Hello @monniert, thank you for the amazing work! Could you elaborate on how to use the layered_rgb_blend function? I am currently trying to perform camera pose estimation with only the RGB channels, and I noticed that the optimization does not converge.

ykzzyk, Aug 31 '23

Hi @ykzzyk, sure: have a look at this file for an example of how to use the layered_rgb_blend function.

Specifically, you should define a new PyTorch3D shader class (called LayeredShader) and plug it into the standard MeshRenderer class. I also noticed that pose estimation does not converge with the original SoftRas, and in my case this fix makes the pose optimization work.

monniert, Sep 02 '23