
mi.load_dict produces a segmentation fault in PyTorch backward staticmethod

Open · andyyankai opened this issue 1 year ago · 2 comments

ISSUES:

  1. I have to call

     dr.set_flag(dr.JitFlag.LoopRecord, False)
     dr.set_flag(dr.JitFlag.VCallRecord, False)
     mi.set_variant('cuda_ad_rgb')

     again in the backward pass, otherwise Mitsuba complains that no variant is set.

  2. Calling sampler = mi.load_dict({'type': 'independent'}) during the backward pass causes a segmentation fault.

Full code to reproduce:

import drjit as dr
import mitsuba as mi
from drjit.cuda import Float as FloatC, Matrix4f as Matrix4fC, Array3f as Vector3fC
from drjit.cuda.ad import Int as IntD, Float as FloatD, Matrix4f as Matrix4fD, Array3f as Vector3fD, Array2f as Vector2fD, Array1f as Vector1fD
import torch

dr.set_flag(dr.JitFlag.LoopRecord, False)
dr.set_flag(dr.JitFlag.VCallRecord, False)
mi.set_variant('cuda_ad_rgb')
class RenderFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, dummy):
        return dummy

    @staticmethod
    def backward(ctx, grad_out, *args):
        # these flags and the variant have to be set again here, otherwise Mitsuba reports that no variant is set
        dr.set_flag(dr.JitFlag.LoopRecord, False)
        dr.set_flag(dr.JitFlag.VCallRecord, False)
        mi.set_variant('cuda_ad_rgb')

        with dr.suspend_grad():
            print("111 render")
            sampler = mi.load_dict({'type': 'independent'})  # segmentation fault happens here
            print("111")

class Renderer(torch.nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, dummy):
        image = RenderFunction.apply(dummy)
        return image

dummy_torch = torch.tensor([0.0], device='cuda', dtype=torch.float32, requires_grad=True)
psdr_render = Renderer()
opt_map = psdr_render(dummy_torch)
opt_map.backward()
exit()

andyyankai · Jul 08 '23 05:07

Just a note: the segmentation fault happens at Thread::thread()->set_file_resolver(new FileResolver(*fs_backup)); inside the load_dict function, whenever it is called in this setup.

andyyankai · Jul 08 '23 10:07

Hi @andyyankai

I'm not familiar with the PyTorch execution model. My impression is that it is executing this code on another thread, not the main Python thread. This would explain why you need to set the variant again (it's thread-local information).
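
As a quick illustration of what "thread-local" means here (a minimal sketch, not part of the original report; the exact behaviour of mi.variant() on a fresh thread is an assumption), the variant selected on the main thread is not automatically visible on other threads:

import threading
import mitsuba as mi

mi.set_variant('cuda_ad_rgb')        # selected on the main Python thread

def worker():
    # The variant selection is stored per thread, so a newly spawned thread
    # typically sees no variant until set_variant() is called again here.
    print('variant seen by worker thread:', mi.variant())
    mi.set_variant('cuda_ad_rgb')    # select it again for this thread
    print('after set_variant:', mi.variant())

t = threading.Thread(target=worker)
t.start()
t.join()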

Re-importing mitsuba inside the backward() call might just do the trick. Or maybe PyTorch has a configuration option that avoids switching threads?
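
A minimal sketch of that suggestion (untested; whether it actually avoids the crash inside load_dict is exactly what is in question here) could look like:

import torch

class RenderFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, dummy):
        return dummy

    @staticmethod
    def backward(ctx, grad_out):
        # backward() may run on a different thread than the one where
        # set_variant() was originally called, so re-import mitsuba and
        # select the variant again before using any Mitsuba functionality.
        import mitsuba as mi
        if mi.variant() is None:
            mi.set_variant('cuda_ad_rgb')

        sampler = mi.load_dict({'type': 'independent'})
        # ... actual gradient computation would go here ...
        return grad_out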

njroussel · Jul 13 '23 06:07