mitsuba3
Is using a PyTorch network inside mi.Loop allowed?
Summary
Hi, I am using PyTorch with Mitsuba 3, and I need to run PyTorch network inference during rendering. For example, neural importance sampling requires querying a network inside the rendering loop. In my workflow I call trainer.eval() (which wraps a PyTorch network) inside an mi.Loop, and after each rendering pass I update the network with trainer.train(). Here is pseudocode of the workflow:
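import drjit as dr
import mitsuba as mi
import torch

mi.set_variant("cuda_ad_rgb")  # variant assumed; not stated in the post

# `trainer` (wrapper around the PyTorch network) is defined elsewhere.


class MyPathIntegrator(mi.SamplingIntegrator):
    """Simple path tracer with MIS + NEE."""
    # self.max_depth is set in __init__ (omitted)

    def path_tracing(self,
                     scene: mi.Scene,
                     sampler: mi.Sampler,
                     ray: mi.Ray3f,
                     medium: mi.Medium = None,
                     active: bool = True):
        # ... initialization of depth, cur_depth, L, β, η, prev_si,
        #     prev_bsdf_pdf, prev_bsdf_delta, pts, albedo (omitted) ...
        loop = mi.Loop(name="Custom Path Tracer",
                       state=lambda: (sampler, ray, depth, cur_depth, L, β, η, active,
                                      prev_si, prev_bsdf_pdf, prev_bsdf_delta))
        loop.set_max_iterations(self.max_depth)

        while loop(active):
            # Normalize the current points to [0, 1]^3 using the scene bounding box
            test_pts = (pts - scene.bbox().min) / (scene.bbox().max - scene.bbox().min)
            # Network inference inside the recorded loop
            output = trainer.eval(test_pts)
            # ... (rest of the path-tracing loop body omitted) ...
            cur_depth += 1

        return (L, dr.neq(depth, 0), [pts, albedo, output])

    def sample(self,
               scene: mi.Scene,
               sampler: mi.Sampler,
               ray: mi.Ray3f,
               medium: mi.Medium = None,
               active: bool = True):
        # Delegate to the custom path tracer (loop state is configured there)
        (color, mask, aov) = self.path_tracing(scene, sampler, ray, medium, active)
        return (color, mask, aov)


def run_render(scene_path, spp, h, w, device):
    # Register the new integrator
    mi.register_integrator("mypath", lambda props: MyPathIntegrator(props))
    scene = mi.load_file(scene_path)

    # Alternate between rendering (network inference) and training
    with dr.suspend_grad():
        for i in range(16):
            torch.cuda.synchronize()
            # mi.render returns a single tensor; AOVs listed by aov_names()
            # are appended to it as extra channels
            img = mi.render(scene, spp=spp)
            torch.cuda.synchronize()
            trainer.train()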
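The first line of the loop body maps each point into the unit cube spanned by the scene bounding box before the point is passed to the network. A NumPy sketch of the same mapping, handy for checking network inputs on the PyTorch side, is below. The function name is hypothetical, and the guard for zero-extent (flat) axes is an addition that the original expression does not have.

```python
import numpy as np


def normalize_to_unit_cube(pts, bbox_min, bbox_max):
    """Map points inside [bbox_min, bbox_max] to [0, 1]^3, per axis."""
    pts = np.asarray(pts, dtype=np.float64)
    lo = np.asarray(bbox_min, dtype=np.float64)
    hi = np.asarray(bbox_max, dtype=np.float64)
    extent = hi - lo
    # Added guard (not in the original): avoid division by zero on flat axes
    extent = np.where(extent > 0, extent, 1.0)
    return (pts - lo) / extent
```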
However, after the network is updated in trainer.train(), the output of trainer.eval() inside mi.Loop does not change.
When I disable virtual function call recording and loop recording with dr.set_flag(dr.JitFlag.VCallRecord, False) and dr.set_flag(dr.JitFlag.LoopRecord, False), the results are correct, but rendering becomes much slower.
So, with these two flags enabled, is calling a PyTorch network inside mi.Loop not allowed? I only run network inference inside mi.Loop, and I am not using differentiable rendering.
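If recording has to stay disabled for now, one option is to switch the two flags off only around the render calls that query the network and restore the previous values afterwards, so other Dr.Jit code keeps using recorded loops and virtual calls. This does not make those renders faster; it only limits the change's scope. The context manager below is a hypothetical helper that takes the flag getter and setter as arguments so it can be exercised without Dr.Jit. In the commented usage, dr.set_flag and dr.JitFlag come from the post, while dr.flag is assumed to be the matching query function; check that it exists in your Dr.Jit version.

```python
from contextlib import contextmanager


@contextmanager
def flags_set(get_flag, set_flag, flags, value=False):
    """Temporarily set every flag in `flags` to `value`; restore the old
    values on exit, even if the body raises."""
    saved = [(f, get_flag(f)) for f in flags]
    try:
        for f in flags:
            set_flag(f, value)
        yield
    finally:
        # Restore in reverse order so overlapping changes unwind cleanly
        for f, old in reversed(saved):
            set_flag(f, old)


# Usage with Dr.Jit / Mitsuba (dr.flag is an assumption, see above):
# with flags_set(dr.flag, dr.set_flag,
#                [dr.JitFlag.LoopRecord, dr.JitFlag.VCallRecord]):
#     img = mi.render(scene, spp=spp)
```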
System configuration
System information:
OS: Windows
CPU: Intel i9-13900H
GPU: RTX 4060 Laptop
Python version: 3.9
CUDA version: 12.0
NVIDIA driver: 550.54.14
Dr.Jit version: 0.4.4
Mitsuba version: 3.5.0