
How to add an XLA custom call op?

Open: dinghaodhd opened this issue 1 year ago • 3 comments

Hi, I am trying to add an XLA custom call op for third-party hardware, but I found that the custom call process is different from TF/XLA. For example, I cannot find `REGISTER_XLA_OP` in OpenXLA or torch_xla, and I don't know what to do next. Can you help me?

dinghaodhd, Jan 10 '24 02:01

Hi,

There are four different custom call API versions in XLA. For the first two versions, the expected function signature differs between the CPU and GPU backends (see the `CustomCallApiVersion` enum in hlo.proto): https://github.com/openxla/xla/blob/76e1d730d4546b9a15c187102e3ee37734ff8ffa/xla/service/hlo.proto#L51-L111
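To make the differences concrete, here are the per-version signatures written as C++ function-pointer aliases, paraphrased from the comments in the hlo.proto revision linked above (check that file for the authoritative text). `XlaCustomCallStatus` and `CUstream` are local stand-ins so the sketch compiles on its own, and the alias names (`CpuV1`, `GpuV2`, ...) are made up for illustration:

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>

// Stand-ins so this sketch compiles without XLA or CUDA headers.
struct XlaCustomCallStatus;            // opaque type in XLA
typedef struct CUstream_st* CUstream;  // as declared by the CUDA driver API

// API v1 (API_VERSION_ORIGINAL)
using CpuV1 = void (*)(void* out, const void** in);
using GpuV1 = void (*)(CUstream stream, void** buffers, const char* opaque,
                       std::size_t opaque_len);

// API v2 (API_VERSION_STATUS_RETURNING): both backends gain a status argument.
using CpuV2 = void (*)(void* out, const void** in, XlaCustomCallStatus* status);
using GpuV2 = void (*)(CUstream stream, void** buffers, const char* opaque,
                       std::size_t opaque_len, XlaCustomCallStatus* status);

// API v3 (API_VERSION_STATUS_RETURNING_UNIFIED): CPU gains opaque/opaque_len;
// the GPU signature is the same as in v2.
using CpuV3 = void (*)(void* out, const void** in, const char* opaque,
                       std::size_t opaque_len, XlaCustomCallStatus* status);
using GpuV3 = GpuV2;

// A trivial CPU target, only to show a function that matches CpuV1.
void NoopCpuV1(void* /*out*/, const void** /*in*/) {}
```

API v4 (typed FFI) is deliberately left out; it does not follow this untyped pattern.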

API v1 to v3 all pass untyped buffers: on the CPU backend, inputs arrive as `void** inputs` and the output as `void* output` (on GPU, inputs and outputs share a single `void** buffers` array). Programmers have to manually reinterpret the type of each input argument; e.g., `inputs[0]` may be a pointer to a float array, `inputs[1]` may be a pointer to a `size_t` storing the length of that array, and so on.
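A minimal sketch of that manual reinterpretation for a CPU API v1-style target. The name `SumF32` and the operand layout are hypothetical; XLA only hands over the pointers, so what each operand means is up to the op and the code that emits the custom call:

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical API v1-style CPU target that sums a float array.
// Operand layout is this op's own convention:
//   in[0] -> const float*       : the array
//   in[1] -> const std::size_t* : number of elements in in[0]
//   out   -> float*             : receives the sum
void SumF32(void* out, const void** in) {
  assert(out != nullptr && in != nullptr);
  const float* data = static_cast<const float*>(in[0]);
  const std::size_t n = *static_cast<const std::size_t*>(in[1]);
  float total = 0.0f;
  for (std::size_t i = 0; i < n; ++i) total += data[i];
  *static_cast<float*>(out) = total;
}
```

Nothing checks that `in[1]` really points to a `size_t`; a mismatch between the op and its caller is silent undefined behavior, which is the main weakness of v1 to v3.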

API v4 (typed FFI) is the latest and most programmable version. It supports calls to functions with typed arguments (carrying type and size metadata). However, it is still under development: it is quite usable in the GPU backend now, but we haven't added support for the CPU backend yet. If you are adding a GPU op, I'd encourage using API v4. Once the v4 work is done for all backends, we plan to migrate all custom calls to v4 and deprecate the older versions.
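To illustrate what typed arguments buy you, here is a purely hypothetical sketch. `DType`, `BufferView` and `ScaleF32` are made-up names and are NOT XLA's FFI API (which was still evolving at the time of this comment); the point is only that a handler which receives element type and shape metadata can validate its inputs instead of trusting raw `void*` casts:

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical types, NOT XLA's FFI API: they only model
// "buffers that carry their element type and shape".
enum class DType { F32, S32 };

struct BufferView {
  DType dtype;
  std::vector<std::int64_t> dims;
  void* data;

  std::int64_t element_count() const {
    std::int64_t n = 1;
    for (std::int64_t d : dims) n *= d;
    return n;
  }
};

// Hypothetical typed handler: out = in * factor.
// Returns an empty string on success, otherwise an error message.
std::string ScaleF32(const BufferView& in, const BufferView& out, float factor) {
  if (in.dtype != DType::F32 || out.dtype != DType::F32)
    return "expected f32 buffers";
  if (in.dims != out.dims) return "shape mismatch";
  assert(in.data != nullptr && out.data != nullptr);
  const float* src = static_cast<const float*>(in.data);
  float* dst = static_cast<float*>(out.data);
  for (std::int64_t i = 0; i < in.element_count(); ++i) dst[i] = src[i] * factor;
  return "";
}
```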

You can see examples in custom_call_test.cc for both backends (only the GPU tests have API v4 examples):

GPU backend

API v1 tests:

https://github.com/openxla/xla/blob/76e1d730d4546b9a15c187102e3ee37734ff8ffa/xla/service/gpu/custom_call_test.cc#L81-L305

API v2 & v3 tests:

https://github.com/openxla/xla/blob/76e1d730d4546b9a15c187102e3ee37734ff8ffa/xla/service/gpu/custom_call_test.cc#L307-L347

API v4 test:

https://github.com/openxla/xla/blob/76e1d730d4546b9a15c187102e3ee37734ff8ffa/xla/service/gpu/custom_call_test.cc#L349-L669
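On the GPU backend (API v1 to v3), operands and results share one `void** buffers` array, with the operand buffers first and the result buffers after them. A host-only sketch of that indexing; `AddVec4` is hypothetical, `CUstream` and `XlaCustomCallStatus` are local stand-ins, and a real target would launch a kernel on `stream` instead of looping on the host:

```cpp
#include <cassert>
#include <cstddef>

// Stand-ins so the sketch compiles without CUDA or XLA headers.
typedef struct CUstream_st* CUstream;
struct XlaCustomCallStatus;

constexpr std::size_t kLen = 4;

// Hypothetical GPU-style (API v2/v3 signature) target: out = a + b for two
// 4-element f32 vectors. Buffer order: operands first, then results.
void AddVec4(CUstream stream, void** buffers, const char* opaque,
             std::size_t opaque_len, XlaCustomCallStatus* status) {
  (void)stream; (void)opaque; (void)opaque_len; (void)status;  // unused here
  assert(buffers != nullptr);
  const float* a = static_cast<const float*>(buffers[0]);  // operand 0
  const float* b = static_cast<const float*>(buffers[1]);  // operand 1
  float* out = static_cast<float*>(buffers[2]);            // result 0
  // A real GPU target would launch a kernel on `stream`; this loop runs on
  // the host purely to show the indexing.
  for (std::size_t i = 0; i < kLen; ++i) out[i] = a[i] + b[i];
}
```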

CPU backend

API v1 tests:

Custom functions: https://github.com/openxla/xla/blob/76e1d730d4546b9a15c187102e3ee37734ff8ffa/xla/tests/custom_call_test.cc#L41-L66

Custom call target registrations: https://github.com/openxla/xla/blob/76e1d730d4546b9a15c187102e3ee37734ff8ffa/xla/tests/custom_call_test.cc#L90-L93

Corresponding tests: https://github.com/openxla/xla/blob/76e1d730d4546b9a15c187102e3ee37734ff8ffa/xla/tests/custom_call_test.cc#L109-L242

API v2 test:

Custom functions: https://github.com/openxla/xla/blob/76e1d730d4546b9a15c187102e3ee37734ff8ffa/xla/tests/custom_call_test.cc#L68-L77

Custom call target registrations: https://github.com/openxla/xla/blob/76e1d730d4546b9a15c187102e3ee37734ff8ffa/xla/tests/custom_call_test.cc#L94-L95

Corresponding tests: https://github.com/openxla/xla/blob/76e1d730d4546b9a15c187102e3ee37734ff8ffa/xla/tests/custom_call_test.cc#L244-L322

API v3 test:

Custom function: https://github.com/openxla/xla/blob/76e1d730d4546b9a15c187102e3ee37734ff8ffa/xla/tests/custom_call_test.cc#L79-L86

Custom call target registration: https://github.com/openxla/xla/blob/76e1d730d4546b9a15c187102e3ee37734ff8ffa/xla/tests/custom_call_test.cc#L96

Corresponding test: https://github.com/openxla/xla/blob/76e1d730d4546b9a15c187102e3ee37734ff8ffa/xla/tests/custom_call_test.cc#L324-L346
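A sketch of a CPU API v3 target that uses both additions over v1: the status argument (from v2) to report errors, and the `opaque` string (added to CPU in v3) to receive a parameter. `ScaleVec3`, the decimal-text encoding of the parameter, `SetFailureStub` and the `XlaCustomCallStatus` struct are hypothetical stand-ins; with real XLA you would include xla/service/custom_call_status.h and call `XlaCustomCallStatusSetFailure` instead:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <string>

// Stand-in for XLA's opaque XlaCustomCallStatus, so the sketch runs without XLA.
struct XlaCustomCallStatus {
  bool failed = false;
  std::string message;
};

// Plays the role of XLA's XlaCustomCallStatusSetFailure(status, msg, len).
void SetFailureStub(XlaCustomCallStatus* status, const std::string& msg) {
  status->failed = true;
  status->message = msg;
}

// Hypothetical API v3 CPU target: out[i] = in[0][i] * scale for a 3-element
// f32 vector. The scale travels as decimal text in `opaque` (this op's own
// convention; XLA forwards the bytes without interpreting them).
void ScaleVec3(void* out, const void** in, const char* opaque,
               std::size_t opaque_len, XlaCustomCallStatus* status) {
  assert(status != nullptr);
  // Use the explicit length instead of relying on NUL termination.
  const std::string text(opaque, opaque_len);
  char* end = nullptr;
  const float scale = std::strtof(text.c_str(), &end);
  if (text.empty() || end != text.c_str() + text.size()) {
    SetFailureStub(status, "ScaleVec3: cannot parse opaque '" + text + "'");
    return;  // leave `out` untouched on failure
  }
  const float* x = static_cast<const float*>(in[0]);
  float* y = static_cast<float*>(out);
  for (std::size_t i = 0; i < 3; ++i) y[i] = x[i] * scale;
}
```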

penpornk, Feb 29 '24 10:02

There are more examples outside the OpenXLA repo, e.g., in JAX.

JAX examples

JAX has many XLA custom calls for linear algebra routines in both GPU and CPU backends.

GPU ops

Custom call registrations are in jaxlib/gpu/gpu_kernels.cc. Custom functions are implemented in jaxlib/gpu/*.cc.

Example: the first registration binds the function `GetrfBatched` to a custom call named "cublas_getrf_batched":

```cpp
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_getrf_batched", GetrfBatched,
                                         "CUDA");
```
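Roughly speaking, `XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM` runs at static-initialization time and records a (name, platform) → function pointer mapping in XLA's global custom call target registry, which is later looked up by target name and platform. The toy version below only mirrors that idea; `ToyRegistry`, `ToyRegistrar`, `MyHostKernel` and "my_host_kernel" are hypothetical and not XLA's implementation:

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>

// Hypothetical toy registry (not XLA's): maps a (target name, platform)
// pair to an untyped function pointer.
class ToyRegistry {
 public:
  static ToyRegistry& Global() {
    static ToyRegistry registry;
    return registry;
  }
  void Register(const std::string& name, void* fn, const std::string& platform) {
    targets_[{name, platform}] = fn;
  }
  void* Lookup(const std::string& name, const std::string& platform) const {
    auto it = targets_.find({name, platform});
    return it == targets_.end() ? nullptr : it->second;
  }

 private:
  std::map<std::pair<std::string, std::string>, void*> targets_;
};

// Registers during static initialization, like a registration macro would.
struct ToyRegistrar {
  ToyRegistrar(const char* name, void* fn, const char* platform) {
    ToyRegistry::Global().Register(name, fn, platform);
  }
};

void MyHostKernel(void* out, const void** in) { (void)out; (void)in; }

static ToyRegistrar my_host_kernel_registrar(
    "my_host_kernel", reinterpret_cast<void*>(&MyHostKernel), "Host");
```

Keying on the platform as well as the name is what lets the same target name be bound to different functions for "CUDA" and "Host".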

`GetrfBatched` in blas_kernels.cc uses API v3:

```cpp
void GetrfBatched(gpuStream_t stream, void** buffers, const char* opaque,
                  size_t opaque_len, XlaCustomCallStatus* status) {
  // ... (body omitted)
}
```
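The `opaque` bytes in this signature are commonly used to ship a small, trivially copyable descriptor (batch size, matrix size, ...) from the code that builds the custom call to the kernel; JAX's GPU kernels such as `GetrfBatched` appear to follow a similar pattern. A minimal sketch with hypothetical `BatchedDescriptor`, `PackOpaque` and `UnpackOpaque` (not JAX's own helpers):

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <string>
#include <type_traits>

// Hypothetical descriptor for a batched matrix op.
struct BatchedDescriptor {
  int batch;  // number of matrices
  int n;      // each matrix is n x n
};

// Builder side: copy the struct's bytes into the opaque string.
template <typename T>
std::string PackOpaque(const T& desc) {
  static_assert(std::is_trivially_copyable<T>::value,
                "descriptor must be trivially copyable");
  return std::string(reinterpret_cast<const char*>(&desc), sizeof(T));
}

// Kernel side: recover the struct from (opaque, opaque_len).
// Returns false on a length mismatch so the kernel can report an error.
template <typename T>
bool UnpackOpaque(const char* opaque, std::size_t opaque_len, T* out) {
  static_assert(std::is_trivially_copyable<T>::value,
                "descriptor must be trivially copyable");
  if (opaque == nullptr || opaque_len != sizeof(T)) return false;
  std::memcpy(out, opaque, sizeof(T));
  return true;
}
```

Copying raw struct bytes assumes the builder and the kernel agree on the struct layout, which holds when both are compiled into the same binary; otherwise an explicit, fixed-width encoding would be safer.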

CPU ops

Custom call registrations are in jaxlib/cpu/cpu_kernels.cc. Custom functions are implemented in jaxlib/cpu/*.cc.

Example: the first registration binds `Trsm<float>::Kernel` to a custom call named "blas_strsm":

```cpp
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("blas_strsm", Trsm<float>::Kernel,
                                         "Host");
```

`Trsm<T>::Kernel` in lapack_kernels.cc uses API v2:

```cpp
template <typename T>
void Trsm<T>::Kernel(void* out, void** data, XlaCustomCallStatus*) {
  // ... (body omitted)
}
```
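A hypothetical kernel written in the same style: a templated struct with a static `Kernel` using the API v2 CPU signature, so one definition can be instantiated per element type. `Axpy` and its operand layout (size and scalar passed as operand buffers) are assumptions for illustration, and `XlaCustomCallStatus` is a local stand-in:

```cpp
#include <cassert>
#include <cstdint>

struct XlaCustomCallStatus;  // stand-in; unused here, as in the signature above

// Hypothetical templated API v2 CPU kernel. Assumed operand layout:
//   data[0] -> const std::int32_t* : n
//   data[1] -> const T*            : alpha (scalar)
//   data[2] -> const T*            : x (n elements)
//   data[3] -> const T*            : y (n elements)
//   out     -> T*                  : alpha * x + y
template <typename T>
struct Axpy {
  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
};

template <typename T>
void Axpy<T>::Kernel(void* out, void** data, XlaCustomCallStatus*) {
  const std::int32_t n = *static_cast<const std::int32_t*>(data[0]);
  const T alpha = *static_cast<const T*>(data[1]);
  const T* x = static_cast<const T*>(data[2]);
  const T* y = static_cast<const T*>(data[3]);
  T* result = static_cast<T*>(out);
  for (std::int32_t i = 0; i < n; ++i) result[i] = alpha * x[i] + y[i];
}

// One explicit instantiation per element type.
template struct Axpy<float>;
template struct Axpy<double>;
```

Each instantiation, e.g. `Axpy<float>::Kernel`, would then be registered under its own target name for the "Host" platform, just as `Trsm<float>::Kernel` is registered as "blas_strsm".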

penpornk, Feb 29 '24 10:02

Is it possible to add some out-of-tree examples? Custom calls work naturally in an in-tree build, but become complex out-of-tree, since some needed declarations still live in OpenXLA itself (the FFI namespace, `ServiceExecutableRunOptions`, etc.). A possible solution is to expose these dependencies as a new submodule of OpenXLA.

Zantares, May 14 '24 09:05