mitsuba3

Incorrect custom BSDF behavior with torch

Open: william122742 opened this issue 2 years ago (5 comments)

Summary

In a custom BSDF, calling .torch() does not return the correct surface-interaction data under the CUDA variant (cuda_ad_rgb).

System configuration

System information:

OS: Ubuntu 22.04.1 LTS
CPU: AMD Ryzen 9 5900X 12-Core Processor
GPU: NVIDIA GeForce RTX 3090 Ti
Python: 3.8.13 (default, Oct 21 2022, 23:50:54) [GCC 11.2.0]
NVidia driver: 515.65.01
CUDA: 11.7.99
LLVM: 0.0.0

Dr.Jit: 0.3.2
Mitsuba: 3.1.1
Is custom build? False
Compiled with: GNU 10.2.1
Variants: scalar_rgb scalar_spectral cuda_ad_rgb llvm_ad_rgb

Description

I am trying to write a custom BSDF that passes the sampled direction wo and the surface-interaction data (uv, wi) to an MLP written in PyTorch, which outputs the BSDF value. Converting uv, wi and wo to PyTorch tensors with .torch() works fine in the scalar variant, but under the CUDA variant the results are incorrect.

As a simpler example, shading the surface diffusely by its UV coordinates (uv = si.uv; reflectance = mitsuba.Color3f(uv[0], uv[1], 1)) gives the correct rendering:

[image: expected rendering, surface shaded by its UV coordinates]

However, if I first convert the UV coordinates to a torch tensor and back (uv = si.uv.torch(); uv = mitsuba.Point2f(uv[..., 0], uv[..., 1])), the BSDF always sees uv = (0, 0):

[image: incorrect rendering, uv = (0, 0) everywhere]

Steps to reproduce

import torch
import drjit as dr
import mitsuba
mitsuba.set_variant('cuda_ad_rgb')
import matplotlib.pyplot as plt

# diffuse shader with reflectance given by surface uv
class MyBSDF(mitsuba.BSDF):
    def __init__(self, props):
        mitsuba.BSDF.__init__(self, props)
        reflection_flags   = mitsuba.BSDFFlags.SpatiallyVarying|mitsuba.BSDFFlags.DiffuseReflection|mitsuba.BSDFFlags.FrontSide | mitsuba.BSDFFlags.BackSide
        self.m_components  = [reflection_flags]
        self.m_flags = reflection_flags

    def sample(self, ctx, si, sample1, sample2, active=True):
        # diffuse sampling
        theta_o = dr.acos(dr.sqrt(sample2[0]))
        phi_o = 2*dr.pi*sample2[1]
        sin_theta_o,cos_theta_o = dr.sincos(theta_o)
        sin_phi_o,cos_phi_o = dr.sincos(phi_o)
        wo = mitsuba.Vector3f(sin_theta_o*cos_phi_o,sin_theta_o*sin_phi_o,cos_theta_o)

        pdf = dr.clamp(mitsuba.Frame3f.cos_theta(wo),1e-5,1)*1/dr.pi
        bs = mitsuba.BSDFSample3f()
        bs.pdf = pdf
        bs.sampled_component = mitsuba.UInt32(0)
        bs.sampled_type = mitsuba.UInt32(+mitsuba.BSDFFlags.DiffuseReflection)
        bs.wo = wo
        bs.eta = 1.0
        uv = si.uv
        # convert to torch tensor then back
        uv = uv.torch().reshape(-1,2)
        uv = mitsuba.Point2f(uv[:,0],uv[:,1])
        value = mitsuba.Color3f(uv[0],uv[1],1.0)
        return (bs,value)

    def eval(self, ctx, si, wo, active=True):
        uv = si.uv
        # convert to torch tensor then back
        uv = uv.torch().reshape(-1,2)
        uv = mitsuba.Point2f(uv[:,0],uv[:,1])
        f = mitsuba.Color3f(uv[0],uv[1],1.0)
        f = f * 1.0/dr.pi * mitsuba.Frame3f.cos_theta(wo)
        return f

    def pdf(self, ctx, si, wo, active=True):
        pdf = dr.clamp(mitsuba.Frame3f.cos_theta(wo),1e-5,1)*1/dr.pi
        return pdf

    def eval_pdf(self, ctx, si, wo, active=True):
        f = self.eval(ctx,si,wo,active)
        pdf = self.pdf(ctx,si,wo,active)
        return f,pdf

    def to_string(self):
        return 'MyBSDF'

mitsuba.register_bsdf("mybsdf", lambda props: MyBSDF(props))


# create simple scene
scene = mitsuba.load_dict({
    'type': 'scene',
    'integrator': {
        'type': 'direct',
    },
    'sensor': {
        'type': 'perspective',
        'fov_axis': 'smaller',
        'fov': 17.5,
        'to_world': mitsuba.ScalarTransform4f.look_at(
            origin=[80,-80,50],
            target=[0, 0, 10],
            up=[-1, 1, 4]
        ),
        'sampler': {
            'type': 'independent',
            'sample_count': 16
        },
        'film': {
            'banner': False,
            'type': 'hdrfilm',
            'width': 640,
            'height': 540,
        }
    },
    'shape1': {
        'type': 'rectangle',
        'flip_normals': True,
        'to_world': mitsuba.ScalarTransform4f.translate([20,-20,50]).scale([4,4,1]),
        'emitter': {
            'type': 'area',
            'radiance': 125.0
        },
        'bsdf': {
            'type': 'diffuse',
            'reflectance': {
                'type': 'rgb',
                'value': 0.0
            },
        }
    },
    'shape2':
    {
        'type': 'rectangle',
        'to_world':  mitsuba.ScalarTransform4f.translate([0, 0, 5]).scale(10),
        'bsdf': {
            'type': 'mybsdf'
        }
    }
})

img = mitsuba.render(scene).torch()
plt.imshow(img.cpu().pow(1/2.2).clamp(0,1))

william122742, Dec 16 '22

Running the code you provided produces the correct image (i.e., the first one you shared) on my end. Could you also share your PyTorch version?

bathal1, Dec 19 '22

PyTorch 1.13.0 (build py3.8_cuda11.7_cudnn8.5.0_0)

william122742, Dec 20 '22

I found that setting dr.set_flag(dr.JitFlag.VCallRecord, False) somehow fixes it, but the code then uses a very large amount of GPU memory (4.9 GB).

william122742, Jan 14 '23

I was able to reproduce the issue on my end. It seems to happen only with PyTorch >= 1.13.0. While we look into this, a possible workaround is to downgrade torch to 1.12.1.

bathal1, Mar 01 '23

I don't think this should ever have worked. If it did on older versions of PyTorch, that may have been a happy coincidence.

Disabling virtual function call recording with dr.set_flag(dr.JitFlag.VCallRecord, False) is a hard requirement here, because any call to .torch() triggers an evaluation of the variable it is called on, and variables should not be evaluated inside recorded virtual function calls.

@bathal1, happy to hear back if you figure anything else out; I might have missed something.

njroussel, Mar 01 '23