
[BUG] [FRONTEND] Nested AutoGraph functions result in AutoGraph warnings

Opened by josh146.

Shortcut: https://app.shortcut.com/xanaduai/story/51540/nested-autograph-conversions-fail

When a function decorated with @qjit(autograph=True) calls another function with the same decorator, calling the outer function raises several AutoGraph warnings. Despite these warnings, AutoGraph still appears to convert both functions correctly.

For example, consider:

from catalyst import qjit

@qjit(autograph=True)
def f(x):
    if x > 5:
        y = x ** 2
    else:
        y = x ** 3
    return y

@qjit(autograph=True)
def g(x, n):
    for i in range(n):
        x = x + f(x)
    return x

Calling the inner function on its own works without any warnings:

>>> f(0.4)
array(0.064)

However, calling the outer function raises the following warning (emitted twice), although a result is still returned:

>>> g(0.4, 6)
WARNING:tensorflow:AutoGraph could not transform <function outer_factory.<locals>.inner_factory.<locals>.f_1 at 0x7ad76c5b3ac0> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: closure mismatch, requested ('ag__',), but source function had ()
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING: AutoGraph could not transform <function outer_factory.<locals>.inner_factory.<locals>.f_1 at 0x7ad76c5b3ac0> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: closure mismatch, requested ('ag__',), but source function had ()
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
array(22.14135448)
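To check that the warnings do not change the numerical result, the same program can be run as plain Python with neither qjit nor AutoGraph. The sketch below is a hypothetical cross-check (f_ref and g_ref are illustrative names, not part of Catalyst); it should reproduce the compiled outputs shown above up to floating-point rounding.

```python
# Plain-Python reference for the example above, with no qjit and no AutoGraph.
# f_ref and g_ref are hypothetical names used only for this cross-check.


def f_ref(x: float) -> float:
    """Same branch logic as the compiled f: square above 5, cube otherwise."""
    if x > 5:
        return x ** 2
    return x ** 3


def g_ref(x: float, n: int) -> float:
    """Same loop as the compiled g: add f(x) to x, n times."""
    for _ in range(n):
        x = x + f_ref(x)
    return x


print(f_ref(0.4))     # approximately 0.064
print(g_ref(0.4, 6))  # approximately 22.14135448
```

Agreement here only indicates that the converted program computes the expected value; it does not explain the closure-mismatch warning itself.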

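Inspecting the jaxpr suggests that AutoGraph is nonetheless working correctly: the Python for loop in g is captured as a for_loop primitive, and the if/else from the nested call to f is captured as a cond primitive inside the loop body: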

>>> g.jaxpr
{ lambda ; a:f64[] b:i64[]. let
    c:f64[] = for_loop[
      apply_reverse_transform=False
      body_jaxpr={ lambda ; d:i64[] e:f64[]. let
          f:bool[] = gt e 5.0
          g:f64[] = cond[
            branch_jaxprs=[{ lambda ; a:f64[] b:f64[]. let c:f64[] = integer_pow[y=2] a in (c,) }, { lambda ; a:f64[] b:f64[]. let c:f64[] = integer_pow[y=3] b in (c,) }]
          ] f e e
          h:f64[] = add e g
        in (h,) }
      body_nconsts=0
    ] 0 b 1 0 a
  in (c,) }

josh146, Dec 05 '23 16:12