mitsuba3
Incorrect custom BSDF behavior with torch
Summary
In a custom BSDF, calling .torch() does not read the correct surface intersection information under the cuda variant.
System configuration
System information:
OS: Ubuntu 22.04.1 LTS
CPU: AMD Ryzen 9 5900X 12-Core Processor
GPU: NVIDIA GeForce RTX 3090 Ti
Python: 3.8.13 (default, Oct 21 2022, 23:50:54) [GCC 11.2.0]
NVidia driver: 515.65.01
CUDA: 11.7.99
LLVM: 0.0.0
Dr.Jit: 0.3.2
Mitsuba: 3.1.1
Is custom build? False
Compiled with: GNU 10.2.1
Variants: scalar_rgb scalar_spectral cuda_ad_rgb llvm_ad_rgb
Description
I am trying to write a custom BSDF that passes the sampled direction wo and the surface interaction data (uv, wi) to an MLP written in PyTorch, which outputs the BSDF value. Converting uv, wo and wi to PyTorch tensors with .torch() works fine in scalar mode, but the behavior in cuda mode appears to be incorrect.
As a simpler example, suppose I shade the surface diffusely using its uv coordinates: uv = si.uv; reflectance = mitsuba.Color3f(uv[0], uv[1], 1). The correct rendering looks like this:
However, if I first convert uv to a torch tensor and then back, i.e. uv = si.uv.torch(); uv = mitsuba.Point2f(uv[..., 0], uv[..., 1]), the BSDF always sees uv = (0, 0):
Steps to reproduce
import torch
import drjit as dr
import mitsuba
mitsuba.set_variant('cuda_ad_rgb')
import matplotlib.pyplot as plt

# diffuse shader with reflectance given by surface uv
class MyBSDF(mitsuba.BSDF):
    def __init__(self, props):
        mitsuba.BSDF.__init__(self, props)
        reflection_flags = (mitsuba.BSDFFlags.SpatiallyVarying | mitsuba.BSDFFlags.DiffuseReflection
                            | mitsuba.BSDFFlags.FrontSide | mitsuba.BSDFFlags.BackSide)
        self.m_components = [reflection_flags]
        self.m_flags = reflection_flags

    def sample(self, ctx, si, sample1, sample2, active=True):
        # diffuse (cosine-weighted) sampling
        theta_o = dr.acos(dr.sqrt(sample2[0]))
        phi_o = 2 * dr.pi * sample2[1]
        sin_theta_o, cos_theta_o = dr.sincos(theta_o)
        sin_phi_o, cos_phi_o = dr.sincos(phi_o)
        wo = mitsuba.Vector3f(sin_theta_o * cos_phi_o, sin_theta_o * sin_phi_o, cos_theta_o)
        pdf = dr.clamp(mitsuba.Frame3f.cos_theta(wo), 1e-5, 1) * 1 / dr.pi

        bs = mitsuba.BSDFSample3f()
        bs.pdf = pdf
        bs.sampled_component = mitsuba.UInt32(0)
        bs.sampled_type = mitsuba.UInt32(+mitsuba.BSDFFlags.DiffuseReflection)
        bs.wo = wo
        bs.eta = 1.0

        uv = si.uv
        # convert to torch tensor then back
        uv = uv.torch().reshape(-1, 2)
        uv = mitsuba.Point2f(uv[:, 0], uv[:, 1])
        value = mitsuba.Color3f(uv[0], uv[1], 1.0)
        return (bs, value)

    def eval(self, ctx, si, wo, active=True):
        uv = si.uv
        # convert to torch tensor then back
        uv = uv.torch().reshape(-1, 2)
        uv = mitsuba.Point2f(uv[:, 0], uv[:, 1])
        f = mitsuba.Color3f(uv[0], uv[1], 1.0)
        f = f * 1.0 / dr.pi * mitsuba.Frame3f.cos_theta(wo)
        return f

    def pdf(self, ctx, si, wo, active=True):
        pdf = dr.clamp(mitsuba.Frame3f.cos_theta(wo), 1e-5, 1) * 1 / dr.pi
        return pdf

    def eval_pdf(self, ctx, si, wo, active=True):
        f = self.eval(ctx, si, wo, active)
        pdf = self.pdf(ctx, si, wo, active)
        return f, pdf

    def to_string(self):
        return 'MyBSDF'

mitsuba.register_bsdf("mybsdf", lambda props: MyBSDF(props))

# create simple scene
scene = mitsuba.load_dict({
    'type': 'scene',
    'integrator': {
        'type': 'direct',
    },
    'sensor': {
        'type': 'perspective',
        'fov_axis': 'smaller',
        'fov': 17.5,
        'to_world': mitsuba.ScalarTransform4f.look_at(
            origin=[80, -80, 50],
            target=[0, 0, 10],
            up=[-1, 1, 4]
        ),
        'sampler': {
            'type': 'independent',
            'sample_count': 16
        },
        'film': {
            'banner': False,
            'type': 'hdrfilm',
            'width': 640,
            'height': 540,
        }
    },
    'shape1': {
        'type': 'rectangle',
        'flip_normals': True,
        'to_world': mitsuba.ScalarTransform4f.translate([20, -20, 50]).scale([4, 4, 1]),
        'emitter': {
            'type': 'area',
            'radiance': 125.0
        },
        'bsdf': {
            'type': 'diffuse',
            'reflectance': {
                'type': 'rgb',
                'value': 0.0
            },
        }
    },
    'shape2': {
        'type': 'rectangle',
        'to_world': mitsuba.ScalarTransform4f.translate([0, 0, 5]).scale(10),
        'bsdf': {
            'type': 'mybsdf'
        }
    }
})

img = mitsuba.render(scene).torch()
plt.imshow(img.cpu().pow(1 / 2.2).clamp(0, 1))
Running the code you provided produces the correct image (i.e. the first one you shared) on my end. Could you please also provide your PyTorch version?
pytorch 1.13.0 py3.8_cuda11.7_cudnn8.5.0_0
I found that it can be fixed by setting dr.set_flag(dr.JitFlag.VCallRecord, False), but then the code uses a very large amount of GPU memory (4.9 GB).
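If the flag only needs to be off while the scene is rendered, it can be switched off for just that call and restored afterwards, so the rest of the program keeps Dr.Jit's default behavior. The sketch below is a generic helper (`temporarily_set` is a hypothetical name, not a Dr.Jit API) that takes getter and setter callables, so it can be exercised without a GPU. The Dr.Jit usage in the trailing comment assumes your Dr.Jit version exposes `dr.flag()` alongside `dr.set_flag()`. Note that scoping the flag this way does not by itself reduce memory use during the render.

```python
import contextlib


@contextlib.contextmanager
def temporarily_set(get_flag, set_flag, value):
    """Set a global flag to `value` for the duration of a `with` block and
    restore the previous value afterwards, even if the block raises."""
    previous = get_flag()
    set_flag(value)
    try:
        yield previous
    finally:
        set_flag(previous)


# Hypothetical usage with Dr.Jit/Mitsuba (assumes dr.flag() exists in your version):
#   flag = dr.JitFlag.VCallRecord
#   with temporarily_set(lambda: dr.flag(flag),
#                        lambda v: dr.set_flag(flag, v), False):
#       img = mitsuba.render(scene)
```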
I was able to reproduce the issue on my end. It seems to happen only with PyTorch >= 1.13.0. While we look into this, a possible workaround is to downgrade torch to 1.12.1.
I don't think this should ever have worked. Even if it did on older versions of PyTorch, it may have been a happy coincidence.
Disabling recorded virtual function calls with dr.set_flag(dr.JitFlag.VCallRecord, False) is a hard requirement here, because any call to .torch() triggers an evaluation of the variable it is called on, and variables should not be evaluated inside recorded virtual function calls.
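To make this point concrete, here is a toy model in plain Python. It is not Dr.Jit's actual implementation, and every class and function name in it is hypothetical. In recorded mode the BSDF body is traced once with placeholder inputs, so there is no concrete data to hand to PyTorch; in non-recorded mode (VCallRecord disabled) the inputs are evaluated first, so materializing them works.

```python
class Traced:
    """Toy stand-in for a variable captured while a call is being recorded:
    it describes a computation but holds no concrete values yet."""

    def __init__(self, expr):
        self.expr = expr

    def __mul__(self, k):
        return Traced(f"({self.expr} * {k})")

    def to_list(self):
        # Materializing needs real data, which does not exist during recording.
        raise RuntimeError(f"cannot evaluate traced variable {self.expr!r}")


class Concrete:
    """Toy stand-in for an array that has already been evaluated into memory."""

    def __init__(self, values):
        self.values = list(values)

    def __mul__(self, k):
        return Concrete(v * k for v in self.values)

    def to_list(self):
        return list(self.values)


def shade(u, materialize):
    """Toy BSDF body; materialize=True mimics calling .torch() inside eval()."""
    out = u * 0.5
    if materialize:
        out = out.to_list()
    return out


def call_recorded(fn, materialize):
    # Recorded mode: trace the body once with a placeholder input.
    return fn(Traced("u"), materialize)


def call_wavefront(fn, values, materialize):
    # Non-recorded mode: inputs are evaluated first, then the body runs on them.
    return fn(Concrete(values), materialize)
```

In the toy, the recorded case fails loudly; in the report above, the equivalent situation instead silently produced uv = (0, 0). Keeping evaluated inputs in memory for each call is a plausible, but unconfirmed, reason for the higher memory use reported with VCallRecord disabled.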
Happy to hear back from @bathal1 if you figure anything else out. I may have overlooked something.