
dr.syntax: AD gets disabled by variable use in while loop

Open dvicini opened this issue 1 year ago • 7 comments

I have some code that mixes AD with handwritten derivatives and loops (similar to PRB, for example).

It appears that within dr.syntax, any use of a differentiable variable inside a loop (here, a loop inside dr.suspend_grad()) disables gradient tracking for that variable.

Here is an example:

```python
import drjit as dr

@dr.syntax
def f():
  i = dr.zeros(dr.llvm.Int32, 10)
  result = dr.zeros(dr.llvm.ad.Float, 10)

  a = dr.linspace(dr.llvm.ad.Float, 0, 1, 10)
  dr.enable_grad(a)
  print(dr.grad_enabled(a))

  with dr.suspend_grad():
    while i < 5:
      result += a
      i += 1  # advance the counter so that the loop terminates

  print(dr.grad_enabled(a))

f()
```

This prints

```
True
False
```

But I would have expected this to print

```
True
True
```

The current behavior is a bit unintuitive and leads to a confusing loss of gradient tracking. To me it seems that the loop should have no influence on whether `a` has gradients enabled, similar to the pre-`dr.syntax` behavior.

dvicini, Aug 12 '24

Hi @dvicini

It looks like the syntax rewriting is a bit aggressive here: it thinks `a` is part of the loop state. This means that `a` gets re-assigned at the end of the loop, and hence inside the `dr.suspend_grad()` block, which is why it loses gradient tracking. Changing the loop to `while dr.hint(i < 5, exclude=[a]):` fixes the issue.

I'll look into why it thinks `a` should be in the state. As a general rule, `@dr.syntax` tends to be "safer" than necessary in order to guarantee that the loop is valid, but that can come at a cost, as you can see here.

njroussel, Aug 14 '24

OK, I don't think there's much we can do here.

In short, even though it's clear in your reproducer that `a` is not written to inside the loop, we cannot guarantee that in the general case. For example, a random number generator will effectively never appear on the left-hand side of an assignment (nobody writes `sampler = sampler.next_1d()`), but it must still be part of the loop state, because it can evolve implicitly through calls such as `sampler.next_1d()`. So in this case we still consider `a` part of the loop state: even though it's only used on the right-hand side of an assignment, it might have evolved implicitly.
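The following plain-Python sketch (no Dr.Jit involved) illustrates the argument. `Sampler`, `assigned_names`, and `referenced_names` are hypothetical helpers written only for this example; they are a toy model of two kinds of analysis, not how `@dr.syntax` is actually implemented.

```python
import ast

class Sampler:
    """Hypothetical toy RNG, not the Mitsuba or Dr.Jit sampler API."""

    def __init__(self, seed):
        self.state = seed

    def next_1d(self):
        # Advance the internal state in place (a simple LCG) and return a value in [0, 1)
        self.state = (self.state * 1103515245 + 12345) % 2**31
        return self.state / 2**31

# Source text of the loop below, analyzed with Python's `ast` module
LOOP_SRC = """
while i < 3:
    x = sampler.next_1d()
    i += 1
"""

def assigned_names(src):
    """Naive analysis: only names that appear as assignment targets."""
    return {n.id for n in ast.walk(ast.parse(src))
            if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Store)}

def referenced_names(src):
    """Conservative analysis: every name that is read or written."""
    return {n.id for n in ast.walk(ast.parse(src)) if isinstance(n, ast.Name)}

# Run the same loop for real
sampler, i = Sampler(seed=1), 0
while i < 3:
    x = sampler.next_1d()
    i += 1

print(sorted(assigned_names(LOOP_SRC)))    # ['i', 'x']: misses `sampler`
print(sorted(referenced_names(LOOP_SRC)))  # ['i', 'sampler', 'x']
print(sampler.state != 1)                  # True: `sampler` changed anyway
```

An LHS-only analysis would leave `sampler` out of the loop state even though its state changes every iteration, which is why the conservative choice also picks up read-only variables such as `a`.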

In these situations, the usual workflow is to add `print_code=True` to `@dr.syntax()` and inspect the rewritten function. It's usually fairly obvious when too many variables are included in the loop state; you can then list them in the `exclude` argument of a `dr.hint()` call.

njroussel, Aug 14 '24

Fair enough, thanks for checking. I have to say I wasn't super aware of the various debug options for dr.syntax, but with print_code=True and the exclude hints, this should be okay in practice.

dvicini, Aug 14 '24

The loop constructs, dr.syntax and dr.hint have extensive documentation. Please take a look and post an issue/PR if anything should be added. We should document any potential gotchas.

wjakob, Aug 17 '24

I've run into another, seemingly related problem: the following code silently produces a result of 0 instead of the expected 1. If I remove the if-statement, it works as expected. I tried adding `dr.hint(..., exclude=[b])`, but that produces an error: `RuntimeError: ad_traverse(): tried to forward-propagate derivatives across edge a1 -> a2, which lies outside of the current dr.isolate_grad() scope`

```python
import drjit as dr
import mitsuba as mi

mi.set_variant('llvm_ad_rgb')

@dr.syntax(print_code=True)
def f():
  param = mi.Float(0.0)
  dr.enable_grad(param)
  dr.set_grad(param, 1.0)

  a = dr.linspace(mi.Float, 1, 2, 16) + param

  result = mi.Float(0.0)
  b = dr.gather(mi.Float, a, 3)
  if b == b: # Always true
    result += dr.forward_to(b)  # Fails silently

  # Doing the same without the if-statement works as expected
  # result += dr.forward_to(b)

  return result


result = f()
print(result)
```

dvicini, Oct 04 '24

The last issue you posted is unrelated to the problem with `dr.suspend_grad()` and symbolic operations. Could you create a separate issue for it? It's actually not clear how this should behave in general. Suppose we have an arbitrarily nested sequence of symbolic operations, and the user calls `dr.forward()` at the innermost level. The system would then have to travel back in time, so to speak, and forward-propagate derivatives into each of the outer scopes.

The way we handled this in Mitsuba before was that you had to do a forward AD pass outside of the symbolic region, whose derivative values could then be picked up inside it. But this of course isn't fully general.

wjakob, Oct 05 '24

Yes, you're right. I created a separate issue to track this: #295

dvicini, Oct 05 '24

Fixed in #299

njroussel, Nov 19 '24