mitsuba3
Silent crash with @dr.wrap_ad()
Summary
I was following the steps of this awesome tutorial, but the kernel kept crashing at the step where v = from_differential(M, u) is computed. The crash happens before any code inside from_differential() runs, which leads me to think there could be an issue with the wrapper. No warning or error message appears before the crash. I am not sure whether this is a bug or whether I am not using the functions properly.
I tried to set up a small reproducer to pin down where the problem happens. The issue only occurs if I pass a TensorXf to a decorated function when that TensorXf was built from the output of another decorated function.
The following code also crashes with the same silent error, at out = mult2(tensor):
import torch
import mitsuba as mi
import drjit as dr

mi.set_variant("cuda_ad_rgb")
mi.set_log_level(mi.LogLevel.Trace)

# Two identical Dr.Jit -> PyTorch wrapped functions that multiply the input by a CUDA tensor
@dr.wrap_ad(source="drjit", target="torch")
def mult1(y):
    x = torch.tensor([1]).cuda()
    return y*x

@dr.wrap_ad(source="drjit", target="torch")
def mult2(y):
    x = torch.tensor([1]).cuda()
    return y*x

tensor = dr.zeros(mi.TensorXf, shape=(3, 1))
tensor = mult1(tensor)  # the first wrapped call succeeds
# Rebuild a TensorXf from the flat array of the first call's output
tensor = tensor.array
tensor = mi.TensorXf(tensor, shape=(3, 1))
out = mult2(tensor)  # the process crashes silently here
print(out)
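Because the crash produces no error message, one way to get at least a Python-level stack trace is to run the reproducer above as a plain script in a fresh interpreter with Python's standard faulthandler module enabled, outside the interactive kernel. The helper below is a minimal sketch of that idea; run_with_faulthandler is a hypothetical name, and it can only report crashes that faulthandler traps (for example segmentation faults, aborts, or access violations on Windows), not a process that simply exits without an error.

```python
import subprocess
import sys


def run_with_faulthandler(source: str, timeout: float = 300.0):
    """Run Python `source` in a fresh interpreter with faulthandler enabled.

    If the child dies on a fatal fault inside native code, faulthandler
    writes the Python traceback of every thread to stderr before the
    process exits, so the crash is no longer silent.
    Returns (returncode, stdout, stderr); raises subprocess.TimeoutExpired
    if the child runs longer than `timeout` seconds.
    """
    proc = subprocess.run(
        [sys.executable, "-X", "faulthandler", "-c", source],
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    return proc.returncode, proc.stdout, proc.stderr
```

For example, passing the reproducer's source to run_with_faulthandler and printing the returned stderr should show which Python line was executing when the process died, provided the fault is one faulthandler catches.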
System configuration
OS: Windows 10
CPU: 5800X
GPU: 3080 Ti
Python version: 3.8.10
PyTorch version: 2.0.0
Dr.Jit version: 0.4.1
Mitsuba version: 3.2.1
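The software versions listed above can also be gathered programmatically, which makes it easier to compare setups across reports. A minimal sketch using only the standard library; collect_versions and format_versions are hypothetical helper names, the PyPI package names mitsuba, drjit and torch are assumed, and CPU and GPU models are not covered and still need to be filled in by hand.

```python
import platform
from importlib import metadata


def collect_versions(packages=("mitsuba", "drjit", "torch")):
    """Return a dict with the OS, the Python version and installed package versions."""
    info = {
        "OS": platform.platform(),
        "Python": platform.python_version(),
    }
    for name in packages:
        try:
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Package is not installed in the current environment
            info[name] = "not installed"
    return info


def format_versions(info):
    """Format the dict as 'key: value' lines, one entry per line."""
    return "\n".join(f"{key}: {value}" for key, value in info.items())
```

Printing format_versions(collect_versions()) produces one line per entry in the same "name: version" style as the list above.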