How to add an XLA custom call op?
Hi, I am trying to add an XLA custom call op for third-party hardware, but I found that the custom call process differs from TF/XLA. For example, I cannot find REGISTER_XLA_OP in OpenXLA or torch_xla, and I do not know what to do next. Can you help me?
Hi,
There are four different custom call API versions in XLA. For the first two versions, the API (the expected function signature) also differs between the CPU and GPU backends: https://github.com/openxla/xla/blob/76e1d730d4546b9a15c187102e3ee37734ff8ffa/xla/service/hlo.proto#L51-L111
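To make the differences concrete, the v1 to v3 signatures can be written down as C++ function-pointer types. This is a sketch based on the comments in the hlo.proto file linked above; Stream and XlaCustomCallStatus are placeholders standing in for the backend stream type (e.g. CUstream) and the status type from xla/service/custom_call_status.h, so the snippet compiles on its own.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>

// Placeholders so this sketch compiles on its own. In a real build the stream
// type comes from CUDA/ROCm and XlaCustomCallStatus from
// xla/service/custom_call_status.h.
using Stream = void*;        // stands in for CUstream / hipStream_t
struct XlaCustomCallStatus;  // opaque status handle (placeholder)

// CPU backend: a separate output pointer and input pointer array.
using CpuV1 = void (*)(void* out, const void** in);
using CpuV2 = void (*)(void* out, const void** in, XlaCustomCallStatus* status);
using CpuV3 = void (*)(void* out, const void** in, const char* opaque,
                       std::size_t opaque_len, XlaCustomCallStatus* status);

// GPU backend: `buffers` holds the operands followed by the results.
using GpuV1 = void (*)(Stream stream, void** buffers, const char* opaque,
                       std::size_t opaque_len);
using GpuV2 = void (*)(Stream stream, void** buffers, const char* opaque,
                       std::size_t opaque_len, XlaCustomCallStatus* status);

// v3 ("unified") gives CPU the opaque/opaque_len parameters that GPU already
// had; the GPU signature itself is unchanged from v2.
using GpuV3 = GpuV2;
```

The main practical consequence is that a CPU v1 or v2 target receives no opaque string, so any static parameters have to be passed as extra operands instead.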
API v1 to v3 all receive untyped buffers: on CPU, the inputs arrive as const void** in and the output as void* out, while on GPU the operands followed by the results are passed in a single void** buffers array. Programmers have to manually reinterpret the type of each argument, e.g., inputs[0] may be a pointer to a float array, inputs[1] may be a pointer to a size_t storing the length of that array, etc.
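As a minimal sketch of what that manual reinterpretation looks like, here is a hypothetical CPU target using the v1 signature void(void* out, const void** in). The name SumF32 and the operand layout (a float array followed by its length) are assumptions for illustration; XLA does not check that the kernel's interpretation matches the operands in the HLO.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical CPU custom call target with the API v1 signature.
// The operand layout is a convention the caller and the kernel must agree on:
//   in[0] -> const float*        : input array
//   in[1] -> const std::size_t*  : number of elements in in[0]
//   out   -> float*              : single float receiving the sum
extern "C" void SumF32(void* out, const void** in) {
  const float* values = static_cast<const float*>(in[0]);
  const std::size_t n = *static_cast<const std::size_t*>(in[1]);
  float acc = 0.0f;
  for (std::size_t i = 0; i < n; ++i) acc += values[i];
  *static_cast<float*>(out) = acc;
}
```

Such a function could then be registered for the CPU backend under the "Host" platform, for example with XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM as in the JAX CPU example in this thread.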
API v4 (typed FFI) is the latest and most programmable version. It supports calls to functions with typed arguments that carry type and size metadata. However, it is still under development: it is quite usable in the GPU backend now, but we haven't added support for the CPU backend yet. If you are adding a GPU op, I'd encourage you to use API v4. Once the v4 work is done for all backends, we plan to migrate all custom calls to v4 and deprecate the older versions.
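The benefit of typed arguments can be illustrated with a hypothetical standalone model. This is not the xla::ffi API (the real typed FFI lives under xla/ffi/ in the XLA tree and has its own binding mechanism); TypedBuffer, DType, and SumF32Typed are invented names. The point is that each argument carries its element type and dimensions, so the kernel can validate them and report an error instead of silently reinterpreting void*.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical illustration only; NOT the xla::ffi API.
enum class DType { kF32, kS32 };

// A buffer that carries its element type and shape alongside the data pointer.
struct TypedBuffer {
  DType dtype;
  std::vector<std::int64_t> dims;  // empty dims means a scalar
  void* data;

  std::int64_t element_count() const {
    std::int64_t n = 1;
    for (std::int64_t d : dims) n *= d;
    return n;
  }
};

// Sums all elements of `x` into the scalar `result`.
// Returns an empty string on success, or an error message on a type/shape mismatch.
std::string SumF32Typed(const TypedBuffer& x, const TypedBuffer& result) {
  if (x.dtype != DType::kF32 || result.dtype != DType::kF32)
    return "expected f32 buffers";
  if (result.element_count() != 1) return "expected a scalar result";
  const float* v = static_cast<const float*>(x.data);
  float acc = 0.0f;
  for (std::int64_t i = 0; i < x.element_count(); ++i) acc += v[i];
  *static_cast<float*>(result.data) = acc;
  return "";
}
```

With untyped v1 to v3 buffers, the same mismatch would go unnoticed and the kernel would read garbage.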
You can find examples in custom_call_test.cc for both backends (only the GPU tests include API v4 examples):
GPU backend
API v1 tests:
https://github.com/openxla/xla/blob/76e1d730d4546b9a15c187102e3ee37734ff8ffa/xla/service/gpu/custom_call_test.cc#L81-L305
API v2 and v3 tests:
https://github.com/openxla/xla/blob/76e1d730d4546b9a15c187102e3ee37734ff8ffa/xla/service/gpu/custom_call_test.cc#L307-L347
API v4 test:
https://github.com/openxla/xla/blob/76e1d730d4546b9a15c187102e3ee37734ff8ffa/xla/service/gpu/custom_call_test.cc#L349-L669
CPU backend
API v1 tests:
Custom functions: https://github.com/openxla/xla/blob/76e1d730d4546b9a15c187102e3ee37734ff8ffa/xla/tests/custom_call_test.cc#L41-L66
Custom call target registrations: https://github.com/openxla/xla/blob/76e1d730d4546b9a15c187102e3ee37734ff8ffa/xla/tests/custom_call_test.cc#L90-L93
Corresponding tests: https://github.com/openxla/xla/blob/76e1d730d4546b9a15c187102e3ee37734ff8ffa/xla/tests/custom_call_test.cc#L109-L242
API v2 test:
Custom functions: https://github.com/openxla/xla/blob/76e1d730d4546b9a15c187102e3ee37734ff8ffa/xla/tests/custom_call_test.cc#L68-L77
Custom call target registrations: https://github.com/openxla/xla/blob/76e1d730d4546b9a15c187102e3ee37734ff8ffa/xla/tests/custom_call_test.cc#L94-L95
Corresponding tests: https://github.com/openxla/xla/blob/76e1d730d4546b9a15c187102e3ee37734ff8ffa/xla/tests/custom_call_test.cc#L244-L322
API v3 test:
Custom function: https://github.com/openxla/xla/blob/76e1d730d4546b9a15c187102e3ee37734ff8ffa/xla/tests/custom_call_test.cc#L79-L86
Custom call target registration: https://github.com/openxla/xla/blob/76e1d730d4546b9a15c187102e3ee37734ff8ffa/xla/tests/custom_call_test.cc#L96
Corresponding test: https://github.com/openxla/xla/blob/76e1d730d4546b9a15c187102e3ee37734ff8ffa/xla/tests/custom_call_test.cc#L324-L346
There are more examples outside the OpenXLA repo, e.g., in JAX.
JAX examples
JAX has many XLA custom calls for linear algebra routines in both GPU and CPU backends.
GPU ops
Custom call registrations are in jaxlib/gpu/gpu_kernels.cc. Custom functions are implemented in jaxlib/gpu/*.cc.
Example:
The first registration binds the function GetrfBatched to a custom call named "cublas_getrf_batched".
```cpp
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_getrf_batched", GetrfBatched,
                                         "CUDA");
```
GetrfBatched in blas_kernels.cc uses API v3:
```cpp
void GetrfBatched(gpuStream_t stream, void** buffers, const char* opaque,
                  size_t opaque_len, XlaCustomCallStatus* status) {
  // body elided; see jaxlib/gpu/blas_kernels.cc
}
```
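Since the v3 signature passes opaque and opaque_len, a common pattern is to serialize a small, trivially copyable descriptor struct on the host when building the custom call and deserialize it inside the kernel. jaxlib follows a similar pack/unpack pattern; the sketch below is standalone rather than jaxlib's code, and BatchedLuDescriptor and its fields are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <optional>
#include <string>
#include <type_traits>

// Hypothetical descriptor carrying static parameters of a batched LU kernel.
struct BatchedLuDescriptor {
  int batch_size;
  int n;
};

// Host side: copy the descriptor's bytes into the opaque string stored in the
// custom call instruction.
template <typename T>
std::string PackDescriptor(const T& descriptor) {
  static_assert(std::is_trivially_copyable_v<T>, "descriptor must be POD-like");
  return std::string(reinterpret_cast<const char*>(&descriptor), sizeof(T));
}

// Kernel side: rebuild the descriptor from (opaque, opaque_len), rejecting
// payloads whose size does not match.
template <typename T>
std::optional<T> UnpackDescriptor(const char* opaque, std::size_t opaque_len) {
  static_assert(std::is_trivially_copyable_v<T>, "descriptor must be POD-like");
  if (opaque_len != sizeof(T)) return std::nullopt;
  T descriptor;
  std::memcpy(&descriptor, opaque, sizeof(T));
  return descriptor;
}
```

Inside a real kernel such as GetrfBatched, a failed unpack would be reported through XlaCustomCallStatusSetFailure rather than an assert. Because the payload is a raw memory image, the host code and the kernel must be compiled with the same struct definition.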
CPU ops
Custom call registrations are in jaxlib/cpu/cpu_kernels.cc. Custom functions are implemented in jaxlib/cpu/*.cc.
Example:
The first registration binds Trsm<float>::Kernel to a custom call named "blas_strsm".
```cpp
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("blas_strsm", Trsm<float>::Kernel,
                                         "Host");
```
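On the original question of what replaces REGISTER_XLA_OP: registration here is keyed by a custom call target name and a platform string rather than by a TF op. The toy model below is hypothetical and is not XLA's CustomCallTargetRegistry implementation; it only illustrates the idea that a static registrar records a function pointer under a (name, platform) key, and the compiler later looks it up by the custom call's target name and the platform it is compiling for. For third-party hardware, this suggests the platform string used at registration should match the platform name your backend uses for the lookup.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>

// Toy registry (hypothetical): maps (target name, platform) -> function pointer.
class ToyCustomCallRegistry {
 public:
  static ToyCustomCallRegistry& Global() {
    static ToyCustomCallRegistry registry;  // constructed on first use
    return registry;
  }
  void Register(const std::string& name, void* fn, const std::string& platform) {
    targets_[std::make_pair(name, platform)] = fn;
  }
  // Returns nullptr when nothing is registered for this name on this platform.
  void* Lookup(const std::string& name, const std::string& platform) const {
    auto it = targets_.find(std::make_pair(name, platform));
    return it == targets_.end() ? nullptr : it->second;
  }

 private:
  std::map<std::pair<std::string, std::string>, void*> targets_;
};

// Its constructor runs during static initialization, before main().
struct ToyRegistrar {
  ToyRegistrar(const std::string& name, void* fn, const std::string& platform) {
    ToyCustomCallRegistry::Global().Register(name, fn, platform);
  }
};

// A hypothetical no-op CPU kernel registered for the "Host" platform only.
void MyHostKernel(void* out, const void** in) { (void)out; (void)in; }
static ToyRegistrar my_host_kernel_registrar(
    "my_host_kernel", reinterpret_cast<void*>(&MyHostKernel), "Host");
```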
Trsm<T>::Kernel in lapack_kernels.cc uses API v2:
```cpp
template <typename T>
void Trsm<T>::Kernel(void* out, void** data, XlaCustomCallStatus*) { /* body elided; see jaxlib/cpu/lapack_kernels.cc */ }
```
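Because Trsm<T>::Kernel is a static member of a class template, each instantiation is a separate function with the v2 CPU signature, and each can be registered under its own target name. The hypothetical Scale<T> below mirrors that shape with a trivial kernel; its operand layout is an invented convention, and XlaCustomCallStatus is a placeholder declaration standing in for the real type so the snippet compiles alone.

```cpp
#include <cassert>
#include <cstdint>

// Placeholder; the real type comes from xla/service/custom_call_status.h.
struct XlaCustomCallStatus;

// Hypothetical templated kernel with the same shape as Trsm<T>::Kernel.
// Operand layout (illustrative convention):
//   data[0] -> const std::int64_t* : element count n
//   data[1] -> const T*            : scale factor alpha
//   data[2] -> const T*            : input x of length n
//   out     -> T*                  : output alpha * x of length n
template <typename T>
struct Scale {
  static void Kernel(void* out, void** data, XlaCustomCallStatus* /*status*/) {
    const std::int64_t n = *static_cast<const std::int64_t*>(data[0]);
    const T alpha = *static_cast<const T*>(data[1]);
    const T* x = static_cast<const T*>(data[2]);
    T* y = static_cast<T*>(out);
    for (std::int64_t i = 0; i < n; ++i) y[i] = alpha * x[i];
  }
};
```

Scale<float>::Kernel and Scale<double>::Kernel could then be registered under two different target names, in the same way that Trsm<float>::Kernel is registered as "blas_strsm".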
Is it possible to add some out-of-tree examples? Custom calls work naturally in an in-tree build, but become complex out-of-tree because some required declarations still live inside OpenXLA itself (the FFI namespace, ServiceExecutableRunOptions, etc.). One possible solution is to expose these dependencies as a separate OpenXLA submodule.