catalyst icon indicating copy to clipboard operation
catalyst copied to clipboard

[BUG] ``qjit`` function cannot catch keyword arguments

Open tzunghanjuang opened this issue 1 year ago • 0 comments

Issue description

~~It looks like qjit is not fully compatible with functools.partial.~~ ~~The following example triggers a TypeError: fn() missing 1 required positional argument: 'y' error.~~

The main reason is that a qjit-decorated function is not able to accept keyword arguments; the feature does not appear to be supported. The following example shows that functools.partial without keyword arguments still works, whereas any call that passes keyword arguments fails.

from catalyst import qjit
import functools

@qjit
def fn(x, y):
    return x * y

functools.partial(fn, y=1)(3) # This fails!
functools.partial(fn, 3)(1)   # This works.
fn(x=3, y=1)                  # This fails!

Tracebacks


---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[4], line 9
      6     return x * y
      8 #functools.partial(fn, y= 1)(3)
----> 9 fn(x=3, y=1)

File ~/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/Workspace/catalyst-clone/frontend/catalyst/jit.py:514, in QJIT.__call__(self, *args, **kwargs)
    511 if EvaluationContext.is_tracing():
    512     return self.user_function(*args, **kwargs)
--> 514 requires_promotion = self.jit_compile(args)
    516 # If we receive tracers as input, dispatch to the JAX integration.
    517 if any(isinstance(arg, jax.core.Tracer) for arg in tree_flatten(args)[0]):

File ~/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/Workspace/catalyst-clone/frontend/catalyst/jit.py:585, in QJIT.jit_compile(self, args)
    581 # Capture with the patched conversion rules
    582 with Patcher(
    583     (ag_primitives, "module_allowlist", self.patched_module_allowlist),
    584 ):
--> 585     self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(args)
    587 self.mlir_module, self.mlir = self.generate_ir()
    588 self.compiled_function, self.qir = self.compile()

File ~/Workspace/catalyst-clone/frontend/catalyst/debug/instruments.py:143, in instrument.<locals>.wrapper(*args, **kwargs)
    140 @functools.wraps(fn)
    141 def wrapper(*args, **kwargs):
    142     if not InstrumentSession.active:
--> 143         return fn(*args, **kwargs)
    145     with ResultReporter(stage_name, has_finegrained) as reporter:
    146         fn_results, wall_time, cpu_time = time_function(fn, args, kwargs)

File ~/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/Workspace/catalyst-clone/frontend/catalyst/jit.py:643, in QJIT.capture(self, args)
    637 full_sig = merge_static_args(dynamic_sig, args, static_argnums)
    639 with Patcher(
    640     (qml.QNode, "__call__", QFunc.__call__),
    641 ):
    642     # TODO: improve PyTree handling
--> 643     jaxpr, out_type, treedef = trace_to_jaxpr(
    644         self.user_function, static_argnums, abstracted_axes, full_sig, {}
    645     )
    647 return jaxpr, out_type, treedef, dynamic_sig

File ~/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/Workspace/catalyst-clone/frontend/catalyst/jax_tracer.py:445, in trace_to_jaxpr(func, static_argnums, abstracted_axes, args, kwargs)
    440     with EvaluationContext(EvaluationMode.CLASSICAL_COMPILATION):
    441         make_jaxpr_kwargs = {
    442             "static_argnums": static_argnums,
    443             "abstracted_axes": abstracted_axes,
    444         }
--> 445         jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
    447 return jaxpr, out_type, out_treedef

File ~/Workspace/catalyst-clone/frontend/catalyst/jax_extras/tracing.py:586, in make_jaxpr2.<locals>.make_jaxpr_f(*args, **kwargs)
    584     f, out_tree_promise = flatten_fun(f, in_tree)
    585     f = annotate(f, in_type)
--> 586     jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
    587 closed_jaxpr = ClosedJaxpr(jaxpr, consts)
    588 return closed_jaxpr, out_type, out_tree_promise()

File ~/.local/lib/python3.10/site-packages/jax/_src/profiler.py:336, in annotate_function.<locals>.wrapper(*args, **kwargs)
    333 @wraps(func)
    334 def wrapper(*args, **kwargs):
    335   with TraceAnnotation(name, **decorator_kwargs):
--> 336     return func(*args, **kwargs)
    337   return wrapper

File ~/.local/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py:2324, in trace_to_jaxpr_dynamic2(fun, debug_info)
   2322 with core.new_main(DynamicJaxprTrace, dynamic=True) as main:  # type: ignore
   2323   main.jaxpr_stack = ()  # type: ignore
-> 2324   jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
   2325   del main, fun
   2326 return jaxpr, out_type, consts

File ~/.local/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py:2339, in trace_to_subjaxpr_dynamic2(fun, main, debug_info)
   2337 in_tracers = _input_type_to_tracers(trace.new_arg, in_avals)
   2338 in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
-> 2339 ans = fun.call_wrapped(*in_tracers_)
   2340 out_tracers = map(trace.full_raise, ans)
   2341 jaxpr, out_type, consts = frame.to_jaxpr2(out_tracers)

File ~/.local/lib/python3.10/site-packages/jax/_src/linear_util.py:191, in WrappedFun.call_wrapped(self, *args, **kwargs)
    188 gen = gen_static_args = out_store = None
    190 try:
--> 191   ans = self.f(*args, **dict(self.params, **kwargs))
    192 except:
    193   # Some transformations yield from inside context managers, so we have to
    194   # interrupt them before reraising the exception. Otherwise they will only
    195   # get garbage-collected at some later time, running their cleanup tasks
    196   # only after this exception is handled, which can corrupt the global
    197   # state.
    198   while stack:

TypeError: fn() missing 2 required positional arguments: 'x' and 'y'

tzunghanjuang avatar Jun 14 '24 17:06 tzunghanjuang