mi.load_dict produces a segmentation fault in a PyTorch backward staticmethod
ISSUES:
- Have to call
  dr.set_flag(dr.JitFlag.LoopRecord, False)
  dr.set_flag(dr.JitFlag.VCallRecord, False)
  mi.set_variant('cuda_ad_rgb')
  again in the backward pass, otherwise Mitsuba reports that no variant is set.
- Calling
  sampler = mi.load_dict({'type': 'independent'})
  during the backward pass causes a segmentation fault.

Full code to reproduce:
import drjit as dr
import mitsuba as mi
from drjit.cuda import Float as FloatC, Matrix4f as Matrix4fC, Array3f as Vector3fC
from drjit.cuda.ad import Int as IntD, Float as FloatD, Matrix4f as Matrix4fD, Array3f as Vector3fD, Array2f as Vector2fD, Array1f as Vector1fD
import torch

dr.set_flag(dr.JitFlag.LoopRecord, False)
dr.set_flag(dr.JitFlag.VCallRecord, False)
mi.set_variant('cuda_ad_rgb')

class RenderFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, dummy):
        return dummy

    @staticmethod
    def backward(ctx, grad_out, *args):
        # Have to call these again, otherwise Mitsuba reports that no variant is set
        dr.set_flag(dr.JitFlag.LoopRecord, False)
        dr.set_flag(dr.JitFlag.VCallRecord, False)
        mi.set_variant('cuda_ad_rgb')
        with dr.suspend_grad():
            print("111 render")
            sampler = mi.load_dict({'type': 'independent'})  # segfaults here
            print("111")
        return grad_out  # pass the incoming gradient through for the dummy input

class Renderer(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, dummy):
        image = RenderFunction.apply(dummy)
        return image

dummy_torch = torch.tensor([0.0], device='cuda', dtype=torch.float32, requires_grad=True)
psdr_render = Renderer()
opt_map = psdr_render.forward(dummy_torch)
opt_map.backward()
exit()
Just a note that the segmentation fault happens at
Thread::thread()->set_file_resolver(new FileResolver(*fs_backup));
inside the load_dict function, whenever it is called.
Hi @andyyankai
I'm not familiar with the PyTorch execution model. My impression is that it executes this code on another thread, not the main Python thread. This would explain why you need to set the variant again (it's thread-local information).
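For illustration, the thread-local behaviour can be checked with a small script along these lines (a hypothetical check, not part of the original report; 'scalar_rgb' is just a stand-in variant):

import threading
import mitsuba as mi

mi.set_variant('scalar_rgb')   # any available variant works for this check
print(mi.variant())            # prints 'scalar_rgb' on the main thread

def worker():
    # If the variant is indeed per-thread, this likely prints None,
    # because set_variant() was only called on the main thread.
    print(mi.variant())

t = threading.Thread(target=worker)
t.start()
t.join()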
Re-importing mitsuba in the backward() call might just do the trick. Or maybe PyTorch has some configuration which doesn't swap threads?
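A minimal sketch of that suggestion applied to the backward() above (untested; it assumes backward() runs on a worker thread whose Mitsuba/Dr.Jit thread-local state has not been initialized):

    @staticmethod
    def backward(ctx, grad_out, *args):
        # Re-import and re-configure Mitsuba and Dr.Jit on this (worker) thread,
        # since the variant set on the main thread is thread-local.
        import drjit as dr
        import mitsuba as mi
        mi.set_variant('cuda_ad_rgb')
        dr.set_flag(dr.JitFlag.LoopRecord, False)
        dr.set_flag(dr.JitFlag.VCallRecord, False)
        with dr.suspend_grad():
            sampler = mi.load_dict({'type': 'independent'})
        return grad_out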