A silent issue with @dr.wrap_ad()
Summary
I was following the steps of this awesome tutorial, but the kernel kept crashing at the step where v = from_differential(M, u)
is computed. The crash happens before any code inside from_differential() runs, which leads me to think the issue could be in the wrapper. No warning or error message appears before the crash, so I am not sure whether this is a bug or whether I am misusing the functions.
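For context, here is roughly what that tutorial step looks like. This is only a sketch: it assumes the tutorial uses the standalone largesteps package (compute_matrix, to_differential, from_differential) wrapped with @dr.wrap_ad, and the tetrahedron below is made-up placeholder mesh data.

import torch
import drjit as dr
import mitsuba as mi
mi.set_variant("cuda_ad_rgb")

from largesteps.geometry import compute_matrix
from largesteps.parameterize import to_differential, from_differential

# Placeholder mesh: a single tetrahedron (hypothetical values, just for shape).
v = torch.tensor([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], device="cuda")
f = torch.tensor([[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]], device="cuda")
M = compute_matrix(v, f, lambda_=10.0)  # regularized system matrix from largesteps

@dr.wrap_ad(source="drjit", target="torch")
def from_diff(u):
    # M is captured from the enclosing scope; only u crosses the AD boundary.
    return from_differential(M, u)

u = mi.TensorXf(to_differential(M, v).cpu().numpy())
v_out = from_diff(u)  # the reported crash happens at a call like this one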
I tried to set up a small reproducer to pin down where the bug happens. The issue only occurs when the TensorXf passed to a decorated function was constructed from a TensorXf that was itself the output of another decorated function.
The following script crashes with the same silent error at out = mult2(tensor):
import torch
import mitsuba as mi
import drjit as dr

mi.set_variant("cuda_ad_rgb")
mi.set_log_level(mi.LogLevel.Trace)

@dr.wrap_ad(source="drjit", target="torch")
def mult1(y):
    x = torch.tensor([1]).cuda()
    return y * x

@dr.wrap_ad(source="drjit", target="torch")
def mult2(y):
    x = torch.tensor([1]).cuda()
    return y * x

tensor = dr.zeros(mi.TensorXf, shape=(3, 1))
tensor = mult1(tensor)                       # output of a decorated function
tensor = tensor.array                        # flatten to the underlying array
tensor = mi.TensorXf(tensor, shape=(3, 1))   # rebuild a TensorXf from it
out = mult2(tensor)                          # kernel crashes here, silently
print(out)
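For contrast, a minimal control case: per the description above, the crash presumably requires rebuilding the tensor from the .array of a wrapped output, so passing the output along directly should be unaffected.

tensor = dr.zeros(mi.TensorXf, shape=(3, 1))
out = mult2(mult1(tensor))  # no .array round trip; presumably runs fine
print(out)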
System configuration
OS: Windows 10
CPU: 5800X
GPU: 3080 Ti
Python version: 3.8.10
PyTorch version: 2.0.0
Dr.Jit version: 0.4.1
Mitsuba version: 3.2.1
I proposed a workaround here.
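(For readers without access to the link: one hypothetical mitigation, assuming the crash stems from AD state lingering on the rebuilt tensor, would be to detach and evaluate before re-wrapping. This is a guess, not necessarily the workaround linked above.)

tensor = dr.zeros(mi.TensorXf, shape=(3, 1))
tensor = mult1(tensor)
# Detach from the AD graph and force evaluation before rebuilding (hypothetical fix).
tensor = mi.TensorXf(dr.detach(tensor.array), shape=(3, 1))
dr.eval(tensor)
out = mult2(tensor)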
Hi @h-OUS-e
Sorry for the delay; I'm catching up on all these issues.
I'm unable to reproduce this problem with the script you provided. I've tried both the most recent version on master
and v3.2.1 from PyPI. I find this rather surprising; I would also expect some kind of log or error message for such a crash, even just a default Python exit message.
Are you using a custom Mitsuba build?
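A quick way to confirm which builds are actually being loaded (this uses only standard module attributes, nothing project-specific):

import mitsuba as mi
import drjit as dr
print(mi.__version__, mi.__file__)  # version and install path of Mitsuba
print(dr.__version__, dr.__file__)  # version and install path of Dr.Jit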