
Is using a PyTorch network in mi.Loop allowed?

Open · brabbitdousha opened this issue 4 months ago • 1 comment

Summary

Hi, I am using PyTorch together with Mitsuba 3, and I need to run PyTorch network inference during rendering. Below is pseudocode of my workflow: I call trainer.eval() (which runs a PyTorch network internally) inside a mi.Loop. For example, to do neural importance sampling I need to insert a network into the rendering loop, and after rendering is finished I update the network with trainer.train().



import drjit as dr
import torch

import mitsuba as mi
mi.set_variant("cuda_ad_rgb")  # CUDA variant assumed from the system configuration below


class MyPathIntegrator(mi.SamplingIntegrator):
    """Simple path tracer with MIS + NEE."""
    
    def path_tracing(self,
               scene: mi.Scene,
               sampler: mi.Sampler,
               ray: mi.Ray3f,
               medium: mi.Medium = None, 
               active: bool = True
    ):
        # (Initialization of depth, cur_depth, L, β, η, prev_si, prev_bsdf_pdf,
        #  prev_bsdf_delta, pts, albedo, etc. is elided in this pseudocode.)

        # Register all loop state variables so Dr.Jit can record the loop
        loop = mi.Loop(name="Custom Path Tracer",
                       state=lambda: (sampler, ray, depth, cur_depth, L, β, η, active,
                                      prev_si, prev_bsdf_pdf, prev_bsdf_delta))

        loop.set_max_iterations(self.max_depth)

        while loop(active):

            # (Ray intersection, NEE, and BSDF sampling are elided; `pts` holds
            #  the current shading points.)

            # Normalize the shading points to [0, 1] using the scene bounding box
            # and run PyTorch network inference inside the recorded loop
            test_pts = (pts - scene.bbox().min) / (scene.bbox().max - scene.bbox().min)
            output = trainer.eval(test_pts)

            cur_depth += 1

        return (L, dr.neq(depth, 0), [pts, albedo, output])

    def sample(self,
               scene: mi.Scene,
               sampler: mi.Sampler,
               ray: mi.Ray3f,
               medium: mi.Medium = None, 
               active: bool = True
    ):
        # Delegate to the custom path tracer defined above
        (color, mask, aov) = self.path_tracing(scene, sampler, ray, medium, active)

        return (color, mask, aov)


def run_render(scene_path, spp, h, w, device):

    # Register new integrator

    mi.register_integrator("mypath", lambda props: MyPathIntegrator(props))

    scene = mi.load_file(scene_path)
    # Alternate between rendering (PyTorch inference inside mi.Loop) and training
    with dr.suspend_grad():
        for i in range(16):
            torch.cuda.synchronize()
            img, aov = mi.render(scene, spp=spp)
            torch.cuda.synchronize()

            # Update the PyTorch network after rendering is finished
            trainer.train()

However, after the network is updated in trainer.train(), the output of trainer.eval() inside mi.Loop does not change. If I disable recording with dr.set_flag(dr.JitFlag.VCallRecord, False) and dr.set_flag(dr.JitFlag.LoopRecord, False), everything works correctly, but it becomes much slower. So with these two flags enabled (the default), is using a PyTorch network inside mi.Loop simply not allowed? I only do network inference inside mi.Loop, and I am not using differentiable rendering.
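
For reference, a minimal sketch of the workaround mentioned above (same scene, spp, and trainer objects as in the pseudocode; the flags are simply set globally before rendering):

# Disabling loop / virtual-call recording: the network updates are then
# picked up correctly, but rendering becomes much slower.
dr.set_flag(dr.JitFlag.VCallRecord, False)
dr.set_flag(dr.JitFlag.LoopRecord, False)

with dr.suspend_grad():
    for i in range(16):
        img, aov = mi.render(scene, spp=spp)
        trainer.train()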

System configuration

System information:

OS: Windows
CPU: Intel i9-13900H
GPU: RTX 4060 Laptop
Python version: 3.9
CUDA version: 12.0
NVidia driver: 550.54.14

Dr.Jit version: 0.4.4
Mitsuba version: 3.5.0

brabbitdousha · Oct 06 '24 11:10