Add an `ffi_call` function with a similar signature to `pure_callback`
Currently, JAX users who want to use XLA custom calls must interact with private APIs (e.g. core.Primitive) and MLIR. This makes for a poor developer experience, and it would be useful to provide a public API instead. This has been discussed previously in several contexts (including #12632), and this PR builds on ideas from that work to present a simple API that covers some core use cases.
There are more advanced use cases that would require finer-grained customization, and these would continue to rely on the private API. However, there appear to be real use cases that this simpler interface would satisfy.
Example
The general idea is to provide a function called (something like) jax.extend.ffi.ffi_call, with a signature similar to jax.pure_callback, that lowers to a custom call. For example, the existing implementation of lu_pivots_to_permutation on GPU (the only FFI custom call currently in jaxlib) could, to a first approximation, be written as:
```python
import jax
import jax.extend as jex
import numpy as np

def ffi_call_lu_pivots_to_permutation(pivots, permutation_size):
  dims = pivots.shape
  return jex.ffi.ffi_call(
      "cu_lu_pivots_to_permutation",
      # Output types are specified without reference to MLIR
      jax.ShapeDtypeStruct(
          shape=dims[:-1] + (permutation_size,),
          dtype=pivots.dtype,
      ),
      # Input arguments
      pivots,
      # Keyword arguments are passed to the FFI custom call as attributes
      permutation_size=np.int32(permutation_size),  # Note: np not jnp
  )
```
```python
from jax.lib import xla_client

# Register the compiled handler (elided here) as the custom call target.
xla_client.register_custom_call_target(
    "cu_lu_pivots_to_permutation", ..., platform="CUDA", api_version=1)
```
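For readers unfamiliar with this kernel, the following NumPy sketch shows what a target like cu_lu_pivots_to_permutation computes and why the output shape is dims[:-1] + (permutation_size,). It assumes the 0-based, LAPACK-style pivot convention used by jax.lax.linalg.lu_pivots_to_permutation; the function name lu_pivots_to_permutation_reference is hypothetical, and this is a reference model, not the CUDA implementation.

```python
import numpy as np


def lu_pivots_to_permutation_reference(pivots, permutation_size):
    """Hypothetical NumPy model of the "cu_lu_pivots_to_permutation" target.

    pivots has shape (..., k); pivots[..., i] = j records that row i was
    swapped with row j (0-based) during LU factorization.
    """
    pivots = np.asarray(pivots)
    batch_dims = pivots.shape[:-1]
    k = pivots.shape[-1]
    if k > permutation_size:
        raise ValueError("permutation_size must be at least the number of pivots")
    # Same shape rule as the ShapeDtypeStruct in the ffi_call example.
    out = np.empty(batch_dims + (permutation_size,), dtype=pivots.dtype)
    for idx in np.ndindex(*batch_dims):
        # Start from the identity permutation and apply the swaps in order.
        perm = np.arange(permutation_size, dtype=pivots.dtype)
        for i in range(k):
            j = pivots[idx + (i,)]
            perm[i], perm[j] = perm[j], perm[i]
        out[idx] = perm
    return out
```

Each row of pivots is processed independently, which is why the leading (batch) dimensions pass through unchanged. The `# Note: np not jnp` comment in the example presumably reflects that attributes are static values fixed when the custom call is lowered, so they should be concrete NumPy scalars with the dtype the handler expects.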
Platform-dependent behavior should be handled in user code with the help of lax.platform_dependent. (Currently this doesn't work, but @gnecula is looking into it.) Like jax.pure_callback, this could be combined with custom_jvp or custom_vjp to support autodiff. vmap is discussed below.
Batching
This proof-of-concept implementation includes a vectorized parameter that behaves the same as the vectorized parameter of jax.pure_callback (in fact, it uses exactly the same batching rule). The tl;dr is that when vectorized is False, the vmapped primitive executes the base custom call in a while loop over the batch, but when vectorized is True, it calls the same custom call once with an extra batch dimension on the left. This could work with the FFI interface, since the input buffers include dimension metadata, but it is a restrictive interface: the handler must correctly treat any leading dimensions as batch dimensions. Is there a better approach (don’t say custom_vmap! Or do...)?
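The two batching strategies can be modeled eagerly in NumPy. In this sketch, the hypothetical helper vmap_like_custom_call stands in for the batching rule and plain Python functions stand in for the opaque custom call. It is not the JAX batching rule itself, only an illustration of why vectorized=True places requirements on the handler.

```python
import numpy as np


def vmap_like_custom_call(kernel, batched_args, vectorized, **attrs):
    """Hypothetical eager model of the two batching strategies.

    Every array in batched_args has the same leading batch dimension.
    """
    if vectorized:
        # vectorized=True: a single call; the kernel sees the batch
        # dimension as an extra leading axis on every input.
        return kernel(*batched_args, **attrs)
    # vectorized=False: one call per batch element (the lowered program
    # uses a loop for this), then stack the per-element results.
    batch_size = batched_args[0].shape[0]
    results = [
        kernel(*(a[b] for a in batched_args), **attrs) for b in range(batch_size)
    ]
    return np.stack(results)


def kernel_1d_only(x):
    """A hypothetical handler that only understands unbatched (1-D) input."""
    if x.ndim != 1:
        raise ValueError("kernel_1d_only expects a 1-D input")
    return np.cumsum(x)


def kernel_batch_aware(x):
    """A hypothetical handler that treats leading dimensions as batch dimensions."""
    return np.cumsum(x, axis=-1)
```

A handler like kernel_1d_only works with vectorized=False but fails with vectorized=True, while a handler like kernel_batch_aware gives the same result either way. That asymmetry is the restriction described above.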
Alternatives
If we’re not totally wedded to aligning with the jax.pure_callback interface, a "builder" interface might be more future-proof. For example, the demo above would look something like:
```python
do_call = jex.ffi.make_ffi_call("cu_lu_pivots_to_permutation")
do_call(
    jax.ShapeDtypeStruct(
        shape=dims[:-1] + (permutation_size,),
        dtype=pivots.dtype,
    ),
    pivots,
    batch_size=np.int64(batch_size),
    pivot_size=np.int32(pivot_size),
    permutation_size=np.int32(permutation_size),
)
```
This has the potential benefit that do_call could carry metadata, such as a reference to the underlying core.Primitive, which users could then use for further customization.
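A rough pure-Python mock of the builder idea, meant to show the shape of the API rather than how it would lower: make_ffi_call, FfiCall, register_target, and ShapeDtype are all hypothetical names (ShapeDtype loosely mirrors jax.ShapeDtypeStruct), and targets here are Python functions executed eagerly instead of XLA custom calls.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Tuple

import numpy as np

# Hypothetical registry standing in for XLA's custom call target table.
_TARGETS: Dict[str, Callable[..., Any]] = {}


def register_target(name: str, fn: Callable[..., Any]) -> None:
    """Register a Python function under a target name (mock of registration)."""
    _TARGETS[name] = fn


@dataclass(frozen=True)
class ShapeDtype:
    """Hypothetical stand-in for jax.ShapeDtypeStruct."""
    shape: Tuple[int, ...]
    dtype: Any


@dataclass
class FfiCall:
    """Callable handle returned by the builder; it can carry metadata."""
    target_name: str
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __call__(self, result: ShapeDtype, *args: Any, **attrs: Any) -> np.ndarray:
        # Keyword arguments play the role of FFI attributes.
        out = np.asarray(_TARGETS[self.target_name](*args, **attrs))
        # In this eager mock, verify the result matches the declared output type.
        if out.shape != tuple(result.shape) or out.dtype != np.dtype(result.dtype):
            raise TypeError(
                f"{self.target_name} returned {out.shape}/{out.dtype}, "
                f"expected {tuple(result.shape)}/{np.dtype(result.dtype)}"
            )
        return out


def make_ffi_call(target_name: str, **metadata: Any) -> FfiCall:
    """Builder: bind the target name once, then call the handle many times."""
    return FfiCall(target_name, dict(metadata))
```

Because make_ffi_call returns an object rather than a result, the handle has a natural place to hold metadata; in a real implementation that could be the underlying core.Primitive.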