Expand the `ffi_call` API to support more use cases
The high-level change in this PR is that `ffi_call` now returns a callable. Before this change, the `ffi_call` syntax was:

    ffi_call("target_name", output_type, *input_args, vmap_method="...", **input_kwargs)

After this change, that syntax is deprecated and replaced with:

    ffi_call("target_name", output_type, vmap_method="...")(*input_args, **input_kwargs)

I propose that the old syntax continue to work for 6 months. `jax.extend` doesn't formally need to support the full deprecation cycle, but keeping the old syntax working is the polite thing to do.
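To make the two calling conventions concrete, here is a toy, pure-Python sketch of how a single `ffi_call` could accept both the old and the new syntax during the deprecation period. It is hypothetical and is not JAX's implementation: nothing is lowered, and the returned callable just records the operands and attributes it would forward to the FFI handler. The warning text, the returned dictionary, and the `"f32[3]"` string (standing in for a real `jax.ShapeDtypeStruct`) are all assumptions for illustration.

```python
import warnings


def ffi_call(target_name, result_shape_dtypes, *deprecated_args,
             vmap_method=None, **deprecated_kwargs):
    """Toy stand-in for ffi_call (hypothetical, not JAX's implementation)."""

    def call(*args, **attributes):
        # Positional arguments are operands; keyword arguments are the
        # attributes that would be forwarded to the FFI handler.
        return {
            "target_name": target_name,
            "result_shape_dtypes": result_shape_dtypes,
            "vmap_method": vmap_method,
            "operands": args,
            "attributes": attributes,
        }

    if deprecated_args or deprecated_kwargs:
        # Old syntax: operands and attributes were passed to ffi_call itself.
        # An old-style call with no operands and no attributes is
        # indistinguishable from the new syntax in this sketch.
        warnings.warn(
            "Passing input arguments directly to ffi_call is deprecated; "
            "use ffi_call(...)(*args, **kwargs) instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return call(*deprecated_args, **deprecated_kwargs)
    # New syntax: return the callable and let the caller supply the inputs.
    return call


# New syntax: control arguments first, then operands and attributes.
new = ffi_call("my_target", "f32[3]", vmap_method="sequential")(1.0, 2.0, alpha=0.5)

# Old syntax: still works during the deprecation period, but warns.
old = ffi_call("my_target", "f32[3]", 1.0, 2.0, vmap_method="sequential", alpha=0.5)
```

In this sketch both calls produce the same record, so existing callers keep working while the warning points them at the new form.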
Motivation
This change is motivated by an effort to replace more of the `mlir.custom_call` uses that occur in the wild. Currently, `ffi_call` is missing some key features: (1) fine-grained control over lowering (input/output aliases, memory layouts, ...), and (2) full customization of behavior under transforms. This PR is the first step toward addressing (1).
Concretely, I want to add arguments like `input_output_aliases`, `input_layouts`, and `api_version` (the specific design is still under consideration) to `ffi_call`. These arguments could be supported by the old API, e.g.:

    ffi_call("target_name", output_type, *input_args, vmap_method="...",
             **input_kwargs, input_output_aliases={0: 0})

but this introduces some potential issues:
- This is starting to get hard to read, because it's not clear which arguments control the behavior of the call and which are passed to the FFI as attributes. This was already a problem with `vmap_method`, and it gets worse as we add more control parameters.
- It's not straightforward to raise errors for unexpected arguments. This came up in https://github.com/jax-ml/jax/issues/24131, where an unrecognized argument (the new `vmap_method` argument) ended up being passed to the FFI handler as an attribute, raising a confusing error. Separating the call site for the control arguments from the one for the attributes makes it more straightforward to raise an appropriate error message.
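The error-reporting issue can be illustrated with two toy, pure-Python signatures (both hypothetical, not JAX's implementation). With a `**kwargs` catch-all, a misspelled or unsupported control argument (here `vmap_methd`) is silently treated as an FFI attribute, a failure mode similar to the one in the linked issue. When the control arguments are keyword-only parameters of the outer call with no catch-all, Python itself raises a `TypeError` naming the bad argument. `input_output_aliases` appears only as a placeholder for the kind of control argument this PR wants to add.

```python
def old_style_ffi_call(target_name, result_shape_dtypes, *args,
                       vmap_method=None, **attributes):
    # Any keyword that isn't a recognized control argument lands in
    # `attributes` and would be forwarded to the FFI handler.
    return {"vmap_method": vmap_method, "operands": args,
            "attributes": attributes}


def new_style_ffi_call(target_name, result_shape_dtypes, *,
                       vmap_method=None, input_output_aliases=None):
    # Control arguments are keyword-only and there is no **kwargs catch-all,
    # so Python rejects unknown names with a TypeError at this call site.
    # `input_output_aliases` is a hypothetical future control argument.
    def call(*args, **attributes):
        # Only the returned callable accepts arbitrary FFI attributes.
        return {"vmap_method": vmap_method,
                "input_output_aliases": input_output_aliases,
                "operands": args, "attributes": attributes}
    return call


# Old style: a misspelled control argument silently becomes an attribute.
result = old_style_ffi_call("my_target", "f32[]", 1.0, vmap_methd="sequential")
print(result["attributes"])  # {'vmap_methd': 'sequential'}

# New style: the same typo fails immediately at the call site.
try:
    new_style_ffi_call("my_target", "f32[]", vmap_methd="sequential")
except TypeError as err:
    print("TypeError:", err)
```

The design choice here is simply to let the language enforce the boundary: control arguments are validated by the outer signature, while the inner callable stays free to accept whatever attributes the FFI handler expects.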