[BUG] Gradient compilation error for nested functions with custom gradient
Bug Description
Hi, when I compute gradients of a kernel that calls a function, which in turn calls a function with a custom gradient, I get a compilation error. Calling the custom-gradient function directly from the kernel works as expected. The example below is probably clearer than my explanation. I could not find in the documentation whether this is a bug or a known limitation of Warp.
Snippet:
```python
import numpy as np
import warp as wp

wp.init()


@wp.func
def custom_length(v: wp.vec3f) -> wp.float32:
    return wp.length(v)


@wp.func_grad(custom_length)
def adj_custom_length(v: wp.vec3f, adj_out: wp.float32):
    wp.adjoint[v] += wp.normalize(v) * adj_out


@wp.func
def nested_length(v: wp.vec3f) -> wp.float32:
    return custom_length(v)


@wp.kernel
def test_kernel(vectors: wp.array(dtype=wp.vec3f), lengths: wp.array(dtype=wp.float32)):
    idx = wp.tid()
    # This leads to a compiler error when computing the adjoint
    lengths[idx] = nested_length(vectors[idx])
    # A direct call works as expected, however:
    # lengths[idx] = custom_length(vectors[idx])


vecs = wp.array(np.array([[0., 2., 0.], [3., 0., 0.], [0., 0., -1.0]], dtype=np.float32), dtype=wp.vec3f)
lengths = wp.zeros(vecs.size, dtype=wp.float32)
wp.launch(test_kernel, dim=vecs.size, inputs=[vecs], outputs=[lengths])
print(lengths.numpy())

lengths_grad = wp.array(np.array([1.0, 1.0, 1.0], dtype=np.float32), dtype=wp.float32)
vecs_grad = wp.zeros_like(vecs)
wp.launch(test_kernel, dim=vecs.size, inputs=[vecs], outputs=[lengths], adj_outputs=[lengths_grad], adj_inputs=[vecs_grad], adjoint=True)
print(vecs_grad.numpy())
```
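For reference when checking a fix: the custom adjoint above encodes the standard derivative of the Euclidean norm, ∂‖v‖/∂v = v/‖v‖. With `lengths_grad` set to ones, the expected `vecs_grad` is therefore each input vector normalized. The NumPy-only sketch below does not use Warp, and its helper names are illustrative. It assumes an upstream gradient of 1 per element, as in the repro, computes those expected values, and cross-checks them with central finite differences.

```python
import numpy as np

# Input vectors from the repro (float64 here for accurate finite differences).
vecs = np.array([[0.0, 2.0, 0.0], [3.0, 0.0, 0.0], [0.0, 0.0, -1.0]])
# Upstream gradient, assumed to be 1 per element as with lengths_grad in the repro.
adj_out = np.ones(len(vecs))


def expected_adjoint(v, adj):
    # Mirrors adj_custom_length: adjoint[v] += normalize(v) * adj_out
    return v / np.linalg.norm(v, axis=1, keepdims=True) * adj[:, None]


def finite_difference_grad(v, eps=1e-6):
    # Central differences of |v| with respect to each component of each vector.
    grad = np.zeros_like(v)
    for i in range(v.shape[0]):
        for j in range(v.shape[1]):
            step = np.zeros(v.shape[1])
            step[j] = eps
            grad[i, j] = (np.linalg.norm(v[i] + step) - np.linalg.norm(v[i] - step)) / (2 * eps)
    return grad


expected = expected_adjoint(vecs, adj_out)
print(expected)
assert np.allclose(expected, finite_difference_grad(vecs), atol=1e-6)
```

Once the nested call compiles, `vecs_grad.numpy()` from the repro should match `expected` up to float32 precision.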
Error message:
```text
Warp NVRTC compilation error 6: NVRTC_ERROR_COMPILATION (/builds/omniverse/warp/warp/native/warp.cu:3622)
nvrtc: warning: Architectures prior to '<compute/sm>_75' are deprecated and may be removed in a future release
wp___main___28c6ade.cu(80): error: identifier "adj_custom_length_0" is undefined
adj_custom_length_0(var_v, adj_v, adj_0);
^
wp___main___28c6ade.cu(66): warning #550-D: variable "var_0" was set but never used
wp::float32 var_0;
^
Remark: The warnings can be suppressed with "-diag-suppress <warning-number>"
wp___main___28c6ade.cu(88): warning #177-D: function "adj_custom_length_0" was declared but never referenced
static CUDA_CALLABLE void adj_custom_length_0(
^
1 error detected in the compilation of "wp___main___28c6ade.cu".
Module __main__ 28c6ade load on device 'cuda:0' took 272.84 ms (error)
```
System Information
Ubuntu 22.04, Python 3.11.11, warp-lang 1.8.0 and 1.9.0
Thanks for reporting this and providing a simple repro!
Hi @rbregier, apologies for the late reply. A fix for this issue was just merged into the main branch in commit 6695e68716751cfe82384b0ea3ff4371e56529df and will be part of the v1.10.1 release.