Unexpected result in `@dr.wrap`
Question
Hello!
I'm trying to use Dr.Jit's PyTorch interoperability, and I find that the example code in the Dr.Jit documentation (Interoperability → Differentiability) doesn't produce the expected result. Here is an example:
```python
import drjit as dr
import torch

dr.set_backend('cuda')
torch.set_default_device('cuda')

@dr.wrap(source="torch", target="drjit")
@dr.syntax
def pow2_wrong(n, x):
    i, n = dr.auto.ad.Int(0), dr.auto.ad.Int(n)
    while dr.hint(i < n, max_iterations=10):
        x *= x
        i += 1
    return x

n = torch.tensor([0, 1, 2, 3], dtype=torch.int32)
x = torch.tensor([4, 4, 4, 4], dtype=torch.float32, requires_grad=True)
y_wrong = pow2_wrong(n, x)
print("y_wrong:", y_wrong)
y_wrong.sum().backward()
print("x.grad:", x.grad)
```
This prints:
```
y_wrong: tensor([4.0000e+00, 1.6000e+01, 2.5600e+02, 6.5536e+04], device='cuda:0',
       grad_fn=<TorchWrapperBackward>)
x.grad: tensor([1., 1., 1., 1.], device='cuda:0')
```
But `x.grad` should be `tensor([1.0000e+00, 8.0000e+00, 2.5600e+02, 1.3107e+05], device='cuda:0')`.
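For reference, the loop computes y = x^(2^n), so the analytic gradient is dy/dx = 2^n · x^(2^n − 1). A quick pure-Python sanity check (no Dr.Jit or PyTorch needed; `pow2` and `pow2_grad` are just illustrative names) reproduces the expected values:

```python
def pow2(n, x):
    # Same repeated-squaring loop as pow2_wrong, in plain Python:
    # squaring n times yields x ** (2 ** n)
    for _ in range(n):
        x *= x
    return x

def pow2_grad(n, x):
    # Analytic derivative of x ** (2 ** n)
    p = 2 ** n
    return p * x ** (p - 1)

print([pow2(n, 4.0) for n in range(4)])       # [4.0, 16.0, 256.0, 65536.0]
print([pow2_grad(n, 4.0) for n in range(4)])  # [1.0, 8.0, 256.0, 131072.0]
```

The second line matches the gradient claimed above (1.3107e+05 = 131072), not the all-ones result from `pow2_wrong`.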
I can get the correct result by rewriting the function without the loop:
```python
@dr.wrap(source="torch", target="drjit")
@dr.syntax
def pow2_correct(n, x):
    power = dr.power(2, n)
    result = dr.power(x, power)
    return result

n = torch.tensor([0, 1, 2, 3], dtype=torch.int32)
x = torch.tensor([4, 4, 4, 4], dtype=torch.float32, requires_grad=True)
y_correct = pow2_correct(n, x)
print("y_correct:", y_correct)
y_correct.sum().backward()
print("x.grad_correct:", x.grad)
```
This prints:
```
y_correct: tensor([4.0000e+00, 1.6000e+01, 2.5600e+02, 6.5536e+04], device='cuda:0',
       grad_fn=<TorchWrapperBackward>)
x.grad_correct: tensor([1.0000e+00, 8.0000e+00, 2.5600e+02, 1.3107e+05], device='cuda:0')
```
Platform
- Windows 11
- Python 3.12.9
- Dr.Jit b656011
- Torch 2.6.0+cu124
- nvcc
```
$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Mar_28_02:30:10_Pacific_Daylight_Time_2024
Cuda compilation tools, release 12.4, V12.4.131
Build cuda_12.4.r12.4/compiler.34097967_0
```
I'm not sure whether this information is sufficient. Please let me know if you need anything else.
Loops currently have significant limitations with respect to differentiation: they can only be differentiated in forward mode, or in reverse mode when the loop is simple (https://drjit.readthedocs.io/en/latest/autodiff.html#differentiating-loops).
So this isn't an issue with `@dr.wrap` but rather with the looping construct. This will be addressed in a future version of Dr.Jit. The lack of an error message is unfortunate, and that should be fixed in the meantime.
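To illustrate why reverse mode through this loop is harder than forward mode: the backward pass needs the value of `x` *before* each squaring step, so a correct implementation has to record those intermediates. A minimal pure-Python sketch of manual reverse-mode differentiation through the squaring loop (hypothetical helper name, no Dr.Jit involved):

```python
def pow2_with_grad(n, x):
    # Forward pass: record the input of every squaring step,
    # since d(x*x)/dx = 2*x needs the pre-squaring value.
    tape = []
    for _ in range(n):
        tape.append(x)
        x = x * x

    # Reverse pass: replay the tape backwards, applying the
    # chain rule once per recorded iteration.
    grad = 1.0
    for x_i in reversed(tape):
        grad *= 2.0 * x_i

    return x, grad

# x = 4, n = 3: y = 4 ** 8 = 65536, dy/dx = 8 * 4 ** 7 = 131072
print(pow2_with_grad(3, 4.0))  # (65536.0, 131072.0)
```

The observed all-ones gradient is what you would get if the reverse pass skipped the tape entirely and differentiated only the identity part of the computation.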
Thanks for your quick reply!
> This will be addressed in a future version of Dr.Jit.
How about fixing the documentation first? It's truly confusing.