catalyst
[BUG] [FRONTEND] Nested autograph functions result in autograph warnings
Shortcut: https://app.shortcut.com/xanaduai/story/51540/nested-autograph-conversions-fail
When using nested functions with autograph=True, various autograph warnings will be raised when calling the outermost function. However, autograph still appears to be working correctly.
For example, consider:
from catalyst import qjit

@qjit(autograph=True)
def f(x):
    if x > 5:
        y = x ** 2
    else:
        y = x ** 3
    return y

@qjit(autograph=True)
def g(x, n):
    for i in range(n):
        x = x + f(x)
    return x
Calling the inner function works fine:
>>> f(0.4)
array(0.064)
However, calling the outermost function will result in warnings:
>>> g(0.4, 6)
WARNING:tensorflow:AutoGraph could not transform <function outer_factory.<locals>.inner_factory.<locals>.f_1 at 0x7ad76c5b3ac0> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: closure mismatch, requested ('ag__',), but source function had ()
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING: AutoGraph could not transform <function outer_factory.<locals>.inner_factory.<locals>.f_1 at 0x7ad76c5b3ac0> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: closure mismatch, requested ('ag__',), but source function had ()
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
array(22.14135448)
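One possible reading of the "closure mismatch" message, offered as an assumption rather than a confirmed diagnosis: the name `outer_factory.<locals>.inner_factory.<locals>.f_1` looks like code AutoGraph has already generated, which would carry `ag__` as a free variable, whereas a function rebuilt from plain source has none. The sketch below uses plain Python only to show what such a difference in free variables looks like. The factory functions and the `ag__` stand-in are hypothetical and are not Catalyst or AutoGraph internals.

```python
def outer_factory():
    # Hypothetical stand-in for the runtime object that generated code captures.
    ag__ = object()

    def inner_factory():
        def f_1(x):
            # Referencing ag__ from an enclosing scope makes it a free variable of f_1.
            _ = ag__
            return x ** 3
        return f_1

    return inner_factory()


def plain_f(x):
    # A module-level function that references no enclosing-scope names.
    return x ** 3


converted_like = outer_factory()

# Free variables each code object expects to receive through its closure.
print(converted_like.__qualname__)
print(converted_like.__code__.co_freevars)
print(plain_f.__code__.co_freevars)
print(plain_f.__closure__)
```

If this reading is right, the warning would come from trying to transform `f` a second time while tracing `g`, which would also be consistent with the result still being correct.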
Examining the jaxpr suggests that AutoGraph is working correctly, since the Python for loop and if statement have been captured as for_loop and cond operations:
>>> g.jaxpr
{ lambda ; a:f64[] b:i64[]. let
    c:f64[] = for_loop[
      apply_reverse_transform=False
      body_jaxpr={ lambda ; d:i64[] e:f64[]. let
          f:bool[] = gt e 5.0
          g:f64[] = cond[
            branch_jaxprs=[{ lambda ; a:f64[] b:f64[]. let c:f64[] = integer_pow[y=2] a in (c,) }, { lambda ; a:f64[] b:f64[]. let c:f64[] = integer_pow[y=3] b in (c,) }]
          ] f e e
          h:f64[] = add e g
        in (h,) }
      body_nconsts=0
    ] 0 b 1 0 a
  in (c,) }
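As an independent cross-check of the numerical results reported above, the same logic can be run as ordinary Python without `qjit`. This is a minimal sketch: `f_reference` and `g_reference` are hypothetical names introduced here, and the comparison values are the ones printed in the REPL sessions in this report.

```python
import math


def f_reference(x):
    # Same branching as f, evaluated eagerly by the Python interpreter.
    if x > 5:
        y = x ** 2
    else:
        y = x ** 3
    return y


def g_reference(x, n):
    # Same loop as g, calling the uncompiled reference function.
    for _ in range(n):
        x = x + f_reference(x)
    return x


# Compare against the values printed by the compiled functions in this report.
print(math.isclose(f_reference(0.4), 0.064, rel_tol=1e-6))
print(math.isclose(g_reference(0.4, 6), 22.14135448, rel_tol=1e-6))
```

Agreement here supports the observation that the warnings do not change the computed result for this input.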