Determine how to support JAX's callback system
JAX supports registering callbacks that can be triggered from compiled code running in the backend environment. Supporting this may require an implementation via the mhlo.custom_call system, or an alternative integration through the PJRT events system.
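For reference, the following is a minimal sketch of the kind of program a backend would need to handle. It assumes a recent JAX release and uses `jax.pure_callback` rather than the `host_callback` API linked later in this thread. The helper name `host_sin` is hypothetical. The jitted function calls back into plain Python/NumPy on the host mid-computation, so the backend must be able to pause compiled execution, hand buffers to Python, and resume with the result.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_sin(x):
    # Hypothetical host-side function: runs in the Python interpreter with
    # NumPy, outside the compiled program.
    x = np.asarray(x)
    return np.sin(x).astype(x.dtype)


@jax.jit
def f(x):
    # pure_callback needs the result shape/dtype up front, because the
    # compiler treats the callback as an opaque operation.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sin, out_spec, x) + 1.0


x = jnp.arange(3, dtype=jnp.float32)
print(f(x))
```

Any backend that runs this program through JAX has to provide some route back to the host for the callback. That requirement is what the options above (a custom call or PJRT-level integration) would need to satisfy.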
Or via something like https://github.com/openxla/iree/tree/main/samples/py_custom_module if this is a callback into Python.
Is it this: https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html ?
Can this issue be closed?