
Sample code for integration with PyTorch?

Open hiroaki-santo opened this issue 2 years ago • 14 comments

I tried to use Mitsuba 3 as a kind of rendering layer in a PyTorch pipeline. For converting between PyTorch and Dr.Jit tensors, I referred to:

  • https://github.com/mitsuba-renderer/drjit/pull/37
  • https://github.com/mitsuba-renderer/drjit/blob/master/tests/python/test_pytorch.py

However, it seems that this code is not complete. Is there any sample code for integration with PyTorch?

Thank you.

hiroaki-santo avatar Aug 03 '22 04:08 hiroaki-santo

Hi @hiroaki-santo, I was able to do the conversion simply with torch.tensor(dr_tensor) a few months ago (without AD). The other direction should also work (like mi_type(torch_tensor)), given matching dimensions.

When copying a Dr.Jit tensor to PyTorch, we need to make sure the tensor is evaluated before the memory copy executes, e.g.:

 dr.eval(g_vec)
 dr.sync_thread()
 g_torch = torch.tensor(g_vec, device=torch.device('cuda'))
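
The reverse direction mentioned above would then look something like this (an untested sketch; the torch tensor's shape must match the target Dr.Jit type):

 import torch
 import drjit as dr

 t = torch.zeros(3, 100, device='cuda')
 v = dr.cuda.Array3f(t)  # construct the Dr.Jit array directly from the torch tensor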

But I am not aware of this PR. There might be a new way to do this. Hope this helps.

ziyi-zhang avatar Aug 03 '22 07:08 ziyi-zhang

I have modified the code from this branch to work with CUDA, though I have not added the tests: code.

DoeringChristian avatar Aug 03 '22 07:08 DoeringChristian

Hi @ziyi-zhang, thank you for your quick response. I didn't know I could use torch.tensor(). However (as you noted in your edit, thanks), it does not work with AD.

Hi @DoeringChristian, thank you for your code! That's very helpful.

I have tested the module with Mitsuba 3.0.1 and Dr.Jit 0.2.1:

  1. When I used simple math computations in drjit and torch, backpropagation worked.
  2. I encountered some errors when I used it with the Mitsuba 3 renderer.

My code is:

import mitsuba as mi

mi.set_variant("cuda_ad_rgb")

import torch
import drjit as dr

# from_torch / to_torch come from the experimental drjit/torch.py linked above
from drjit.torch import from_torch, to_torch

device = "cuda:0"
key = "red.reflectance.value"

scene = mi.load_dict(mi.cornell_box())
params = mi.traverse(scene)

image_ref = mi.render(scene, seed=0, spp=512)
image_ref_torch = image_ref.torch().to(device)  # target image

# learnable variable in torch
red_color = torch.ones(size=(1, 3)).to(torch.float32).to(device)
red_color.requires_grad = True

# convert to drjit and set to scene param
params[key] = from_torch(dr.cuda.ad.Array3f, red_color)
params.update()

# torch optimizer
opt = torch.optim.Adam([red_color])
for it in range(50):
    opt.zero_grad()

    # render in drjit
    rendered = mi.render(scene, params, spp=4)

    # convert drjit to torch
    dr.eval(rendered)
    dr.sync_thread()
    rendered_torch = to_torch(rendered)

    # loss in torch
    loss = torch.sum((rendered_torch - image_ref_torch) ** 2)
    loss.backward()  # ERROR!
    opt.step()

and got the error:

Critical Dr.Jit compiler failure: jit_optix_compile(): optixModuleGetCompilationState() indicates that the compilation did not complete succesfully. The module's compilation state is: 0x2363
Aborted (core dumped)

I would appreciate any comments or help.

Thank you!

hiroaki-santo avatar Aug 04 '22 01:08 hiroaki-santo

It seems that I got this error at: https://github.com/DoeringChristian/drjit/blob/8d6b6cda7c84b85a4f8255494e0fdae4875f2a8c/drjit/torch.py#L28

hiroaki-santo avatar Aug 05 '22 02:08 hiroaki-santo

Sorry for responding so late. It seems that the module fails to compile for OptiX: 0x2363 is the error code for OPTIX_MODULE_COMPILE_STATE_FAILED according to NVIDIA's documentation. I don't know why this is happening, and I haven't figured out yet how to get the compile output from Dr.Jit. It does work for operations that don't need to compile OptiX modules, for example:

import torch
import torch.nn as nn
import mitsuba as mi
import drjit as dr
from drjit.torch import to_torch  # from the experimental branch above

mi.set_variant("cuda_ad_rgb")

fc = nn.Linear(10, 1).cuda()

dropt = mi.ad.SGD(lr=0.1)
topt = torch.optim.SGD(fc.parameters(), lr=0.1)

a = dr.arange(mi.Float, 10)

dropt['a'] = a

for i in range(10):
    a = dropt['a']
    b = mi.Float(1.)
    
    c = a * b
    
    d = to_torch(c)
    e = fc(d)
    
    topt.zero_grad()
    e.backward()
    topt.step()
    
    dropt.step()

It also worked for me when using the LLVM backend. Sorry if this does not help you directly, but it would be great if somebody more familiar with the inner workings of Dr.Jit could look into this.
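
For reference, switching to the LLVM backend is just a matter of selecting an LLVM variant at the top of the script (assuming Mitsuba was built with one enabled):

import mitsuba as mi

# selects the LLVM backend instead of "cuda_ad_rgb"; no OptiX modules are compiled
mi.set_variant("llvm_ad_rgb")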

DoeringChristian avatar Aug 09 '22 10:08 DoeringChristian

@DoeringChristian, thank you for your reply. I confirm that your code with MLPs works without any errors in my environment (after fixing SDG → SGD and opt → dropt).

I guess the computations in mi.render() cause the error. I'm not familiar with OptiX and am not sure I can figure out the cause. Any help would be appreciated!

hiroaki-santo avatar Aug 09 '22 16:08 hiroaki-santo

Thanks for reporting those errors, I will get this experimental branch to work with LLVM and CUDA, and then investigate the Optix crash.

Speierers avatar Aug 17 '22 06:08 Speierers

By the way, the evaluation and synchronization shouldn't be necessary, as Dr.Jit will already do this internally when converting a Dr.Jit array to a tensor (e.g. numpy.ndarray or torch.Tensor).
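
Concretely, the conversion step from the earlier snippet should then reduce to something like this (assuming rendered is a Mitsuba tensor, as in the code above):

rendered_torch = rendered.torch()  # evaluation and synchronization happen internally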

Speierers avatar Aug 17 '22 08:08 Speierers

@Speierers, thank you very much for looking into this issue!

I would like to try the from_to_torch branch. Is this branch compatible with the latest Mitsuba 3? I compiled the master version (e4cfa92218c0e2081bfc05be009659cd654caf36) of Mitsuba 3 with:

git pullall          # https://mitsuba.readthedocs.io/en/latest/src/developer_guide/compiling.html#sec-compiling
cd ext/drjit && git checkout from_to_torch          # no error without this

However, I got an import error for drjit.torch during compilation:

Traceback (most recent call last):
  File "/root/mitsuba3/resources/generate_stub_files.py", line 297, in <module>
    import mitsuba as mi        
  File "/root/mitsuba3/build/python/mitsuba/__init__.py", line 8, in <module>
    import drjit as dr          
  File "/root/mitsuba3/build/python/drjit/__init__.py", line 45, in <module>
    import drjit.torch as torch # noqa 
ModuleNotFoundError: No module named 'drjit.torch'   
[1129/1130] Building CXX object src/integrators/CMakeFiles/volpathmis.dir/volpathmis.cpp.o   
ninja: build stopped: subcommand failed. 

hiroaki-santo avatar Aug 18 '22 02:08 hiroaki-santo

@hiroaki-santo, for it to work with Mitsuba 3 you need to change the following in src/python/CMakeLists.txt:

set(DRJIT_PYTHON_FILES
    __init__.py const.py detail.py generic.py
-    matrix.py router.py traits.py tensor.py
+    matrix.py router.py traits.py tensor.py torch.py
  )

Speierers avatar Aug 18 '22 08:08 Speierers

Thank you for your help. I can compile it now.

hiroaki-santo avatar Aug 18 '22 08:08 hiroaki-santo

    

(quoting @DoeringChristian's earlier comment and example)

Why do we need the dropt step? I tried removing it and it still works; the loss still goes down.


I guess this might be useful when we also want the Mitsuba parameters to be updated.

zhaoguangyuan123 avatar Aug 27 '22 11:08 zhaoguangyuan123

I guess the sample code demonstrates both torch → drjit and drjit → torch conversions in one example. The OptiX errors occurred in my environment when I used mi.render().

hiroaki-santo avatar Aug 27 '22 13:08 hiroaki-santo

Just to add: I also got the same error when I tried to combine PyTorch and Mitsuba under the 'cuda_ad_rgb' variant. This bug does not appear when I use an 'llvm_ad' variant on macOS.

Critical Dr.Jit compiler failure: jit_optix_compile(): optixModuleGetCompilationState() indicates that the compilation did not complete succesfully. The module's compilation state is: 0x2363

zhaoguangyuan123 avatar Aug 31 '22 12:08 zhaoguangyuan123

I apologize for not following up on this issue earlier. The new dr.wrap_ad() function appears to have resolved the issue, and this tutorial provided exactly what I was looking for. Thank you, and I will close this issue.
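
For anyone finding this thread later, the pattern from that tutorial looks roughly like the following (a sketch only; scene, params, and key refer to my snippet above, not the exact tutorial code):

import drjit as dr
import mitsuba as mi

# The decorated function runs in Dr.Jit but accepts and returns torch tensors,
# so torch's autograd can backpropagate through the Mitsuba render.
@dr.wrap_ad(source='torch', target='drjit')
def render_scene(color, spp=4):
    params[key] = mi.Color3f(color)
    params.update()
    return mi.render(scene, params, spp=spp)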

hiroaki-santo avatar Jan 30 '23 14:01 hiroaki-santo

Could someone explain why "module 'drjit' has no attribute 'from_torch'"?

linxxcad avatar Nov 27 '23 12:11 linxxcad

This discussion is outdated; that functionality has since been renamed.

Here's most likely what you're looking for: https://drjit.readthedocs.io/en/latest/reference.html#drjit.wrap_ad

We even have a tutorial using it in Mitsuba: https://mitsuba.readthedocs.io/en/latest/src/inverse_rendering/pytorch_mitsuba_interoperability.html

njroussel avatar Nov 28 '23 07:11 njroussel