Generalize host callback support in JAX and IFRT
This change introduces general host callback support in IFRT and changes JAX to use this interface.
- General host callback in IFRT:
  - xla::ifrt::LoadedHostCallback represents an opaque reference to a loaded (ready-to-execute) host callback.
  - xla::ifrt::LoadedHostCallback is created with the MakeLoadedHostCallback method of xla::ifrt::Compiler. Note that this method takes an opaque LoadedHostCallbackMakeArgs whose details can be customized based on the host callback type and the IFRT implementation.
  - xla::ifrt::LoadedHostCallback is given to the Compile method of xla::ifrt::Compiler when compiling an xla::ifrt::LoadedExecutable.
  - When executing xla::ifrt::LoadedExecutable, the IFRT implementation will apply implementation-specific work to use the host callback referenced by xla::ifrt::LoadedHostCallback during the execution of the main computation.
  - The xla::ifrt::LoadedHostCallback will outlive the xla::ifrt::LoadedExecutable.
- JAX uses this new host callback interface. As before, JAX supports two distinct Python callback types:
  - JAX xla::CpuCallback-based host callbacks. The memory address of the xla::CpuCallback ("descriptor") is given directly to 'xla_python_cpu_callback' or 'xla_python_gpu_callback' CustomCall HLOs. This is used for CPU and GPU backends.
  - PjRt xla::HostCallback-based host callbacks. These use host send/recv channel ids to specify the input/output channels to the main XLA computation with host send/recv HLOs. This is used for TPU backends.
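The IFRT lifecycle listed in the first bullet (make the callback, pass it to compilation, execute, and keep the callback alive past the executable) can be sketched roughly as follows. This is only an illustration under stated assumptions: every class in the `sketch` namespace is a hypothetical stand-in that borrows names from the description, not the real xla::ifrt API, and `FnMakeArgs`, `FnLoadedHostCallback`, the integer "computation", and the use of `std::shared_ptr` are invented for the example.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-ins; these are NOT the real xla::ifrt classes.
namespace sketch {

// Opaque reference to a loaded (ready-to-execute) host callback.
class LoadedHostCallback {
 public:
  virtual ~LoadedHostCallback() = default;
};

// Opaque creation arguments; each implementation defines its own subclass.
struct LoadedHostCallbackMakeArgs {
  virtual ~LoadedHostCallbackMakeArgs() = default;
};

// Assumed implementation-specific arguments: a host function on ints.
struct FnMakeArgs : LoadedHostCallbackMakeArgs {
  std::function<int(int)> fn;
};

// Assumed implementation-specific callback type.
class FnLoadedHostCallback : public LoadedHostCallback {
 public:
  explicit FnLoadedHostCallback(std::function<int(int)> fn)
      : fn_(std::move(fn)) {}
  int Call(int x) const { return fn_(x); }

 private:
  std::function<int(int)> fn_;
};

class LoadedExecutable {
 public:
  explicit LoadedExecutable(
      std::vector<std::shared_ptr<LoadedHostCallback>> callbacks)
      : callbacks_(std::move(callbacks)) {}

  // Toy "main computation": adds 1, then routes the value through each
  // host callback. The downcast is the implementation-specific part.
  int Execute(int input) const {
    int value = input + 1;
    for (const auto& cb : callbacks_) {
      auto* fn_cb = dynamic_cast<const FnLoadedHostCallback*>(cb.get());
      assert(fn_cb != nullptr);
      value = fn_cb->Call(value);
    }
    return value;
  }

 private:
  std::vector<std::shared_ptr<LoadedHostCallback>> callbacks_;
};

class Compiler {
 public:
  // Interprets the opaque args; returns nullptr for unknown arg types.
  std::shared_ptr<LoadedHostCallback> MakeLoadedHostCallback(
      std::unique_ptr<LoadedHostCallbackMakeArgs> args) {
    auto* fn_args = dynamic_cast<FnMakeArgs*>(args.get());
    if (fn_args == nullptr) return nullptr;
    return std::make_shared<FnLoadedHostCallback>(std::move(fn_args->fn));
  }

  std::unique_ptr<LoadedExecutable> Compile(
      std::vector<std::shared_ptr<LoadedHostCallback>> callbacks) {
    return std::make_unique<LoadedExecutable>(std::move(callbacks));
  }
};

}  // namespace sketch

// Walks through make -> compile -> execute, then destroys the executable
// while the caller still holds the callback.
int RunLifecycleSketch(int input) {
  sketch::Compiler compiler;
  auto args = std::make_unique<sketch::FnMakeArgs>();
  args->fn = [](int x) { return x * 2; };
  std::shared_ptr<sketch::LoadedHostCallback> callback =
      compiler.MakeLoadedHostCallback(std::move(args));
  assert(callback != nullptr);

  std::unique_ptr<sketch::LoadedExecutable> executable =
      compiler.Compile({callback});
  int result = executable->Execute(input);

  executable.reset();                 // The executable goes away first...
  assert(callback.use_count() == 1);  // ...and the callback is still alive.
  return result;
}
```

Calling `RunLifecycleSketch(3)` runs the toy computation (3 + 1) and then the doubling host callback. Shared ownership is just one way to model the lifetime rule; the description does not say how the real implementation enforces it.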
To overcome the conflict that JAX uses Python functions while the IFRT API and
the PjRt-IFRT implementation are in C++, a C++ callback that can create these
callbacks is constructed in the JAX Python binding and handed over to
PjRt-IFRT. This enables PjRt-IFRT to (effectively) wrap Python functions
into xla::ifrt::LoadedHostCallback without depending on pybind11 or other
Python libraries, keeping PjRt-IFRT purely C++.
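The hand-over described above is essentially a type-erased factory: the binding layer builds a plain C++ callable that captures the Python-side state, and the runtime layer only ever sees that callable. A minimal sketch of the pattern follows, assuming hypothetical names throughout: `runtime_sketch`, `binding_sketch`, `HostCallbackFactory`, `FakePyFunction`, and the string-based `Invoke` signature are all invented for illustration and do not reflect the actual JAX or PjRt-IFRT code. `FakePyFunction` stands in for a Python callable so the example needs no Python headers.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <utility>

// "Runtime" side: pure C++, with no knowledge of the binding layer.
namespace runtime_sketch {

class LoadedHostCallback {
 public:
  virtual ~LoadedHostCallback() = default;
  virtual std::string Invoke(const std::string& arg) = 0;  // hypothetical
};

// The only thing the runtime receives from the binding layer.
using HostCallbackFactory =
    std::function<std::unique_ptr<LoadedHostCallback>()>;

// Creates the callback through the factory and invokes it once.
std::string CompileAndRun(const HostCallbackFactory& factory,
                          const std::string& arg) {
  std::unique_ptr<LoadedHostCallback> callback = factory();
  return callback->Invoke(arg);
}

}  // namespace runtime_sketch

// "Binding" side: the only place that knows about the Python-like object.
namespace binding_sketch {

// Stand-in for a Python callable; a real binding would hold a Python object.
struct FakePyFunction {
  std::string name;
  std::string Call(const std::string& arg) const {
    return name + "(" + arg + ")";
  }
};

class PyBackedCallback : public runtime_sketch::LoadedHostCallback {
 public:
  explicit PyBackedCallback(FakePyFunction fn) : fn_(std::move(fn)) {}
  std::string Invoke(const std::string& arg) override {
    return fn_.Call(arg);
  }

 private:
  FakePyFunction fn_;
};

// Builds the C++ callable that is handed over to the runtime.
runtime_sketch::HostCallbackFactory MakeFactory(FakePyFunction fn) {
  return [fn = std::move(fn)]()
             -> std::unique_ptr<runtime_sketch::LoadedHostCallback> {
    return std::make_unique<PyBackedCallback>(fn);
  };
}

}  // namespace binding_sketch

std::string RunFactorySketch() {
  auto factory = binding_sketch::MakeFactory({"py_fn"});
  return runtime_sketch::CompileAndRun(factory, "x");
}
```

In this sketch `runtime_sketch` depends only on the standard library, which mirrors the stated goal of keeping the runtime layer free of Python dependencies.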
The host callback API does not include general API tests because the details of host callbacks are highly specific to IFRT implementations.