Forward-mode gradients with symbolic conditionals
Separating this from issue #253
The following code silently produces a result of 0 instead of the expected 1. If I remove the if statement, it works as expected. I tried adding dr.hint(..., exclude=[b]), but that raises the following error: RuntimeError: ad_traverse(): tried to forward-propagate derivatives across edge a1 -> a2, which lies outside of the current dr.isolate_grad() scope.
It's not entirely clear to me what the right pattern is for getting correct results here. I guess one option is to explicitly forward-propagate up to b right before the if statement is entered?
import drjit as dr
import mitsuba as mi
mi.set_variant('llvm_ad_rgb')
@dr.syntax(print_code=True)
def f():
    param = mi.Float(0.0)
    dr.enable_grad(param)
    dr.set_grad(param, 1.0)
    a = dr.linspace(mi.Float, 1, 2, 16) + param
    result = mi.Float(0.0)
    b = dr.gather(mi.Float, a, 3)
    # dr.forward_to(b) # One option is to explicitly propagate up to b before the conditional
    if b == b: # Always true
        result += dr.forward_to(b) # Fails silently
    # Doing the same without the if statement works as expected:
    # result += dr.forward_to(b)
    return result
result = f()
print(result)
Do you have thoughts on how you would like this to behave? (Following up on the comment here.)
I am honestly not quite sure. Thinking about it some more, this seems quite tricky to solve robustly.
I am still figuring out how best to use if statements and am not yet sure what usage patterns will emerge.
Maybe it's more a matter of adding a warning about this to the documentation, if one isn't there already.