Enzyme-JAX
Lower MPI to `stablehlo.custom_call` ops
notes
- MPI types (e.g. `MPI_Comm`) are not ABI-stable across MPI implementations, but fortunately there is MPItrampoline, which provides ABI-stable wrappers for them:
  - `MPIABI_Comm`
  - `MPIABI_Request`
  - `MPIABI_Op`
- docs on how to register and use a `custom_call`:
  - https://jax.readthedocs.io/en/latest/ffi.html#c-interface
  - https://openxla.org/xla/custom_call
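Why the MPItrampoline wrappers matter: a raw `MPI_Comm` is an `int` in MPICH but a pointer to a struct in Open MPI, so its representation cannot be baked into a `custom_call` portably. The C sketch below assumes an ABI-stable handle (such as `MPIABI_Comm`) fits in a `uintptr_t`, which is an assumption to check against MPItrampoline's headers, and shows how such a handle could be serialized into fixed-width bytes, e.g. for a `custom_call`'s `backend_config` payload. `comm_handle_t`, `encode_handle`, and `decode_handle` are hypothetical names.

```c
#include <stdint.h>

/* Hypothetical stand-in for an ABI-stable handle such as MPIABI_Comm.
   ASSUMPTION: the real handle type is (or fits in) a uintptr_t. */
typedef uintptr_t comm_handle_t;

/* Serialize a handle into a fixed 8-byte little-endian buffer, so the
   byte layout does not depend on the MPI implementation or host ABI. */
void encode_handle(comm_handle_t h, unsigned char out[8]) {
  uint64_t v = (uint64_t)h;
  for (int i = 0; i < 8; ++i) {
    out[i] = (unsigned char)(v >> (8 * i));
  }
}

/* Inverse of encode_handle: rebuild the handle from its 8 bytes. */
comm_handle_t decode_handle(const unsigned char in[8]) {
  uint64_t v = 0;
  for (int i = 0; i < 8; ++i) {
    v |= (uint64_t)in[i] << (8 * i);
  }
  return (comm_handle_t)v;
}
```

Embedding a raw handle value this way only makes sense inside the process that created it; a communicator created at runtime would more plausibly be passed as a regular operand.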
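XLA's legacy CPU custom-call convention passes a plain C function with the signature `void(void* out, const void** in)`, where `in` holds one pointer per operand and `out` points at the result buffer (for a single, non-tuple result). The sketch below is a hypothetical `mpi_allreduce_sum_f32` target with an assumed operand layout. It does not call MPI: it stands in for a single-rank `MPI_Allreduce` with `MPI_SUM`, which reduces to a copy. Registration with XLA's custom-call target registry is left out because it needs XLA headers; the linked JAX page also documents the newer typed XLA FFI, which may be the better long-term target.

```c
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Hypothetical CPU custom-call target in XLA's legacy calling convention.
   Operand layout is an ASSUMPTION made for this sketch:
     in[0]: const float*   - local send buffer
     in[1]: const int64_t* - scalar element count
     out:   float*         - receive buffer (single, non-tuple result)
   A real target would call MPI_Allreduce(send, recv, count, MPI_FLOAT,
   MPI_SUM, comm); with one rank the reduction is just a copy. */
void mpi_allreduce_sum_f32(void* out, const void** in) {
  const float* send = (const float*)in[0];
  int64_t count = *(const int64_t*)in[1];
  float* recv = (float*)out;
  memcpy(recv, send, (size_t)count * sizeof(float));
}
```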
questions
- should we build libmpitrampoline into Enzyme-JAX, or should we link it at runtime?
- how do we register the `stablehlo.custom_call` targets? Do we register them during some initialization phase in Reactant.jl, or here in Enzyme-JAX?
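On the first question: linking at runtime would typically mean loading the library with POSIX `dlopen`/`dlsym` rather than linking it at build time. A minimal sketch of that, with a hypothetical helper `load_symbol`:

```c
#include <dlfcn.h>
#include <stdio.h>

/* Resolve `symbol` from the shared library `path` at runtime.
   Passing path == NULL searches the running program and the libraries it
   has already loaded. Returns NULL (after printing dlerror()) on failure.
   The library handle is intentionally never dlclose()d, so resolved
   symbols stay valid for the life of the process. */
void* load_symbol(const char* path, const char* symbol) {
  void* lib = dlopen(path, RTLD_NOW | RTLD_GLOBAL);
  if (lib == NULL) {
    fprintf(stderr, "dlopen failed: %s\n", dlerror());
    return NULL;
  }
  dlerror(); /* clear any stale error before calling dlsym */
  void* sym = dlsym(lib, symbol);
  const char* err = dlerror();
  if (err != NULL) {
    fprintf(stderr, "dlsym failed: %s\n", err);
    return NULL;
  }
  return sym;
}
```

Usage would look like `load_symbol("libmpitrampoline.so", "MPI_Init")`, where the library file name is an assumption. On glibc older than 2.34 this needs `-ldl` at link time.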
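On the second question: whichever side does it, the targets have to be registered before any module containing these custom calls is compiled, which makes an initialization hook a natural place. In XLA this goes through its custom-call target registry (JAX exposes `jax.ffi.register_ffi_target` for the FFI path). The C sketch below only models the idea with a hypothetical name-to-function-pointer table and an idempotent `register_mpi_targets()` that a host such as Reactant.jl could call once at startup; none of these names come from Enzyme-JAX or XLA.

```c
#include <stddef.h>
#include <string.h>

/* Hypothetical target signature, matching XLA's legacy CPU convention. */
typedef void (*custom_call_fn)(void* out, const void** in);

#define MAX_TARGETS 16

/* Hypothetical registry: target name -> function pointer. */
static struct {
  const char* name;
  custom_call_fn fn;
} registry[MAX_TARGETS];
static size_t num_targets = 0;

/* Returns 0 on success, -1 if the table is full or the name is taken. */
int register_target(const char* name, custom_call_fn fn) {
  if (num_targets == MAX_TARGETS) return -1;
  for (size_t i = 0; i < num_targets; ++i) {
    if (strcmp(registry[i].name, name) == 0) return -1;
  }
  registry[num_targets].name = name;
  registry[num_targets].fn = fn;
  ++num_targets;
  return 0;
}

/* Returns the registered function, or NULL if `name` is unknown. */
custom_call_fn lookup_target(const char* name) {
  for (size_t i = 0; i < num_targets; ++i) {
    if (strcmp(registry[i].name, name) == 0) return registry[i].fn;
  }
  return NULL;
}

/* Placeholder target; a real one would wrap an MPI call. */
static void mpi_barrier_stub(void* out, const void** in) {
  (void)out;
  (void)in;
}

/* One-time initialization hook; idempotent, so repeated calls are harmless. */
void register_mpi_targets(void) {
  static int done = 0;
  if (done) return;
  register_target("enzymexla_mpi_barrier", mpi_barrier_stub);
  done = 1;
}
```

This table is not thread-safe; a real hook would guard initialization, e.g. with `pthread_once` or C11 `call_once`.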