Generalize host callback support in JAX and IFRT

Open copybara-service[bot] opened this issue 2 years ago • 0 comments

This change introduces general host callback support in IFRT and changes JAX to use this interface.

  • General host callback in IFRT: xla::ifrt::LoadedHostCallback represents an opaque reference to a loaded (ready-to-execute) host callback.
  1. xla::ifrt::LoadedHostCallback is created with the MakeLoadedHostCallback method of xla::ifrt::Compiler. This method takes an opaque LoadedHostCallbackMakeArgs, whose details can be customized based on the host callback type and the IFRT implementation.

  2. xla::ifrt::LoadedHostCallback is passed to the Compile method of xla::ifrt::Compiler when compiling an xla::ifrt::LoadedExecutable.

  3. When executing xla::ifrt::LoadedExecutable, the IFRT implementation will apply implementation-specific work to use the host callback referenced by xla::ifrt::LoadedHostCallback during the execution of the main computation.

  4. The xla::ifrt::LoadedHostCallback will outlive the xla::ifrt::LoadedExecutable.

  • JAX uses this new host callback interface. As before, JAX supports two distinct Python callback types:
  1. JAX xla::CpuCallback-based host callbacks. The memory address of xla::CpuCallback ("descriptor") is directly given to 'xla_python_cpu_callback' or 'xla_python_gpu_callback' CustomCall HLOs. This is used for CPU and GPU backends.

  2. PjRt xla::HostCallback-based host callbacks. These use host send/recv channel ids to specify the input/output channels to the main XLA computation, which contains host send/recv HLOs. This is used for TPU backends.
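
The four IFRT steps listed above (make, compile, execute, lifetime) can be illustrated with a minimal, self-contained sketch. Every type here is a hypothetical stand-in that lives in a `lifecycle_sketch` namespace; none of it is the real `xla::ifrt` API, whose signatures are not shown in this change description. The sketch assumes shared ownership is one way to guarantee that a callback lives at least as long as the executable that uses it, and it models the "main computation" as a trivial integer function.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

namespace lifecycle_sketch {

// Hypothetical stand-in for an opaque, ready-to-execute host callback.
class LoadedHostCallback {
 public:
  explicit LoadedHostCallback(std::function<int(int)> fn)
      : fn_(std::move(fn)) {}
  int Call(int x) const { return fn_(x); }

 private:
  std::function<int(int)> fn_;
};

// Hypothetical stand-in for the opaque make-args. In a real implementation
// the contents would depend on the callback type and the IFRT backend.
struct LoadedHostCallbackMakeArgs {
  std::function<int(int)> fn;
};

// Hypothetical stand-in for a loaded executable that uses host callbacks.
class LoadedExecutable {
 public:
  explicit LoadedExecutable(
      std::vector<std::shared_ptr<LoadedHostCallback>> callbacks)
      : callbacks_(std::move(callbacks)) {}

  // Step 3: the toy "main computation" doubles the input, then routes the
  // value through each host callback in order.
  int Execute(int input) const {
    int value = input * 2;
    for (const auto& cb : callbacks_) {
      assert(cb != nullptr);
      value = cb->Call(value);
    }
    return value;
  }

 private:
  // Step 4: holding shared references keeps every callback alive for at
  // least as long as this executable exists.
  std::vector<std::shared_ptr<LoadedHostCallback>> callbacks_;
};

// Hypothetical stand-in for the compiler interface.
class Compiler {
 public:
  // Step 1: turn opaque args into a loaded host callback.
  std::shared_ptr<LoadedHostCallback> MakeLoadedHostCallback(
      LoadedHostCallbackMakeArgs args) {
    return std::make_shared<LoadedHostCallback>(std::move(args.fn));
  }

  // Step 2: hand the loaded callbacks to compilation.
  std::unique_ptr<LoadedExecutable> Compile(
      std::vector<std::shared_ptr<LoadedHostCallback>> callbacks) {
    return std::make_unique<LoadedExecutable>(std::move(callbacks));
  }
};

// Runs all four steps and returns the result of the toy computation.
int RunLifecycleDemo() {
  Compiler compiler;
  auto callback =
      compiler.MakeLoadedHostCallback({[](int x) { return x + 1; }});
  auto executable = compiler.Compile({callback});
  callback.reset();  // The caller drops its reference; the executable keeps one.
  return executable->Execute(20);
}

}  // namespace lifecycle_sketch
```

Calling `lifecycle_sketch::RunLifecycleDemo()` still succeeds after the caller resets its own reference, because the executable holds its own reference to the callback.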

To overcome the conflict that JAX uses Python functions while the IFRT API and the PjRt-IFRT implementation are in C++, a C++ callback that can create these callbacks is constructed in the JAX Python binding and handed over to PjRt-IFRT. This enables PjRt-IFRT to (effectively) wrap Python functions into xla::ifrt::LoadedHostCallback without depending on pybind11 or other Python libraries, keeping PjRt-IFRT purely C++.
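
The handoff described in the paragraph above is a form of type erasure through a factory callback. Below is a minimal sketch of that pattern, not the actual JAX or PjRt-IFRT code. All names (`factory_sketch`, `HostCallbackFactory`, `Runtime`, `FakePythonFunction`, `PythonBackedCallback`) are hypothetical, and a plain struct stands in for a Python function so that the example compiles without any Python dependency.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <utility>

namespace factory_sketch {

// Hypothetical stand-in for a loaded host callback interface.
class LoadedHostCallback {
 public:
  virtual ~LoadedHostCallback() = default;
  virtual std::string Call(const std::string& arg) = 0;
};

// Type-erased factory: the only thing the runtime side ever sees.
using HostCallbackFactory =
    std::function<std::unique_ptr<LoadedHostCallback>()>;

// "Runtime" side (plays the role of PjRt-IFRT in this sketch). It knows the
// abstract interface and the factory type, and nothing about Python.
class Runtime {
 public:
  explicit Runtime(HostCallbackFactory factory)
      : factory_(std::move(factory)) {}

  std::string Run(const std::string& arg) {
    std::unique_ptr<LoadedHostCallback> cb = factory_();
    assert(cb != nullptr);
    return cb->Call(arg);
  }

 private:
  HostCallbackFactory factory_;
};

// "Binding" side (plays the role of the JAX Python binding). This struct is
// a placeholder for a real Python callable.
struct FakePythonFunction {
  std::string name;
};

// Concrete callback that wraps the placeholder Python function. Only the
// binding side needs to see this class.
class PythonBackedCallback : public LoadedHostCallback {
 public:
  explicit PythonBackedCallback(FakePythonFunction fn) : fn_(std::move(fn)) {}
  std::string Call(const std::string& arg) override {
    return fn_.name + "(" + arg + ")";
  }

 private:
  FakePythonFunction fn_;
};

// Builds the C++ factory on the binding side, capturing the Python-side
// object inside the closure.
HostCallbackFactory MakeFactory(FakePythonFunction fn) {
  return [fn = std::move(fn)]() -> std::unique_ptr<LoadedHostCallback> {
    return std::make_unique<PythonBackedCallback>(fn);
  };
}

// Wires both sides together and returns what the wrapped callback produced.
std::string RunFactoryDemo() {
  Runtime runtime(MakeFactory(FakePythonFunction{"py_fn"}));
  return runtime.Run("x");
}

}  // namespace factory_sketch
```

The design point is that `Runtime` compiles against `LoadedHostCallback` and `HostCallbackFactory` only, so any Python-aware code stays inside the closure built by the binding side.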

The host callback API does not include general API tests because the details of host callbacks are highly specific to IFRT implementations.

copybara-service[bot], Feb 22 '23 23:02