
Support keyword arguments for qjit-compiled functions

Open mehrdad2m opened this issue 1 year ago • 3 comments

Context:

Currently in Catalyst, qjit does not support keyword arguments to functions. For example:

from catalyst import qjit

@qjit()
def f(x, y):
    return x * y

result = f(3, y=2)
assert result == f(2, 3)

This raises the following exception:

TypeError: fn() missing 2 required positional arguments: 'x' and 'y'

The main reason for this is that kwargs is not passed to the jit_compile function.

Description of the Change: We pass kwargs down to jit_compile, then to capture, and finally to the trace_to_jaxpr function, where they are used to modify the tracers. Note that once the function is captured to jaxpr, the keyword format of an input argument is lost and it is treated like any other argument. Therefore, we need to pass the keyword arguments as ordinary arguments when calling the compiled function.
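The last step, turning keyword arguments back into ordinary arguments before calling the compiled function, can be illustrated in plain Python. This is only a sketch of the general idea, not the code in this PR: `bind_to_positional` is a hypothetical helper, and it assumes the function has no `*args` or `**kwargs` parameters.

```python
import inspect


def bind_to_positional(fn, args, kwargs):
    """Hypothetical helper: order args and kwargs by fn's signature.

    Assumes fn has only plain parameters (no *args / **kwargs).
    """
    bound = inspect.signature(fn).bind(*args, **kwargs)
    bound.apply_defaults()
    # bound.arguments preserves the parameter order of the signature.
    return tuple(bound.arguments.values())


def f(x, y):
    return x * y


# The call f(3, y=2) becomes the purely positional call f(3, 2).
flat_args = bind_to_positional(f, (3,), {"y": 2})
result = f(*flat_args)
```

Binding against the signature also reports missing or unexpected keyword names as a `TypeError` at the call site, before anything is traced.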

Benefits: Users can pass keyword arguments when using qjit.

Possible Drawbacks: static_argnums is still not supported in combination with keyword arguments. To support this, we first need to support static_argnames so that static kwargs can be filtered out.
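As a rough illustration of what filtering static kwargs by name could look like, the sketch below splits a kwargs dictionary into static and dynamic parts. `split_static_kwargs` is a hypothetical helper written for this example; it is not part of Catalyst, and an actual static_argnames implementation may work differently.

```python
def split_static_kwargs(kwargs, static_argnames):
    """Hypothetical helper: separate kwargs named as static from the rest."""
    static = {k: v for k, v in kwargs.items() if k in static_argnames}
    dynamic = {k: v for k, v in kwargs.items() if k not in static_argnames}
    return static, dynamic


# "n" is declared static; "x" stays a regular (traced) argument.
static, dynamic = split_static_kwargs({"x": 1.5, "n": 3}, static_argnames={"n"})
```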

Related GitHub Issues: [sc-67133]

mehrdad2m · Aug 07 '24 20:08