catalyst
catalyst copied to clipboard
[BUG] ``qjit`` function cannot accept keyword arguments
Issue description
~~It looks like qjit is not fully compatible with functools.partial.~~
~~The following example triggers a `TypeError: fn() missing 1 required positional argument: 'y'` error.~~
The main reason is that a qjit-decorated function cannot accept keyword arguments; this feature does not appear to be supported. The following example shows that functools.partial still works when it binds only positional arguments, but fails when it binds a keyword argument. Calling the qjit-decorated function directly with keyword arguments fails as well.
from catalyst import qjit
import functools
@qjit
def fn(x, y):
return x * y
functools.partial(fn, y=1)(3) # This fails!
functools.partial(fn, 3)(1) # This works.
fn(x=3, y=1) # This fails!
Tracebacks
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[4], line 9
6 return x * y
8 #functools.partial(fn, y= 1)(3)
----> 9 fn(x=3, y=1)
File ~/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File ~/Workspace/catalyst-clone/frontend/catalyst/jit.py:514, in QJIT.__call__(self, *args, **kwargs)
511 if EvaluationContext.is_tracing():
512 return self.user_function(*args, **kwargs)
--> 514 requires_promotion = self.jit_compile(args)
516 # If we receive tracers as input, dispatch to the JAX integration.
517 if any(isinstance(arg, jax.core.Tracer) for arg in tree_flatten(args)[0]):
File ~/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File ~/Workspace/catalyst-clone/frontend/catalyst/jit.py:585, in QJIT.jit_compile(self, args)
581 # Capture with the patched conversion rules
582 with Patcher(
583 (ag_primitives, "module_allowlist", self.patched_module_allowlist),
584 ):
--> 585 self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(args)
587 self.mlir_module, self.mlir = self.generate_ir()
588 self.compiled_function, self.qir = self.compile()
File ~/Workspace/catalyst-clone/frontend/catalyst/debug/instruments.py:143, in instrument.<locals>.wrapper(*args, **kwargs)
140 @functools.wraps(fn)
141 def wrapper(*args, **kwargs):
142 if not InstrumentSession.active:
--> 143 return fn(*args, **kwargs)
145 with ResultReporter(stage_name, has_finegrained) as reporter:
146 fn_results, wall_time, cpu_time = time_function(fn, args, kwargs)
File ~/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File ~/Workspace/catalyst-clone/frontend/catalyst/jit.py:643, in QJIT.capture(self, args)
637 full_sig = merge_static_args(dynamic_sig, args, static_argnums)
639 with Patcher(
640 (qml.QNode, "__call__", QFunc.__call__),
641 ):
642 # TODO: improve PyTree handling
--> 643 jaxpr, out_type, treedef = trace_to_jaxpr(
644 self.user_function, static_argnums, abstracted_axes, full_sig, {}
645 )
647 return jaxpr, out_type, treedef, dynamic_sig
File ~/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File ~/Workspace/catalyst-clone/frontend/catalyst/jax_tracer.py:445, in trace_to_jaxpr(func, static_argnums, abstracted_axes, args, kwargs)
440 with EvaluationContext(EvaluationMode.CLASSICAL_COMPILATION):
441 make_jaxpr_kwargs = {
442 "static_argnums": static_argnums,
443 "abstracted_axes": abstracted_axes,
444 }
--> 445 jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
447 return jaxpr, out_type, out_treedef
File ~/Workspace/catalyst-clone/frontend/catalyst/jax_extras/tracing.py:586, in make_jaxpr2.<locals>.make_jaxpr_f(*args, **kwargs)
584 f, out_tree_promise = flatten_fun(f, in_tree)
585 f = annotate(f, in_type)
--> 586 jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
587 closed_jaxpr = ClosedJaxpr(jaxpr, consts)
588 return closed_jaxpr, out_type, out_tree_promise()
File ~/.local/lib/python3.10/site-packages/jax/_src/profiler.py:336, in annotate_function.<locals>.wrapper(*args, **kwargs)
333 @wraps(func)
334 def wrapper(*args, **kwargs):
335 with TraceAnnotation(name, **decorator_kwargs):
--> 336 return func(*args, **kwargs)
337 return wrapper
File ~/.local/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py:2324, in trace_to_jaxpr_dynamic2(fun, debug_info)
2322 with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore
2323 main.jaxpr_stack = () # type: ignore
-> 2324 jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
2325 del main, fun
2326 return jaxpr, out_type, consts
File ~/.local/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py:2339, in trace_to_subjaxpr_dynamic2(fun, main, debug_info)
2337 in_tracers = _input_type_to_tracers(trace.new_arg, in_avals)
2338 in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
-> 2339 ans = fun.call_wrapped(*in_tracers_)
2340 out_tracers = map(trace.full_raise, ans)
2341 jaxpr, out_type, consts = frame.to_jaxpr2(out_tracers)
File ~/.local/lib/python3.10/site-packages/jax/_src/linear_util.py:191, in WrappedFun.call_wrapped(self, *args, **kwargs)
188 gen = gen_static_args = out_store = None
190 try:
--> 191 ans = self.f(*args, **dict(self.params, **kwargs))
192 except:
193 # Some transformations yield from inside context managers, so we have to
194 # interrupt them before reraising the exception. Otherwise they will only
195 # get garbage-collected at some later time, running their cleanup tasks
196 # only after this exception is handled, which can corrupt the global
197 # state.
198 while stack:
TypeError: fn() missing 2 required positional arguments: 'x' and 'y'