PyBaMM
Add support for MLIR-based expression evaluation
Description
Provide an MLIR-based backend for expression evaluation that can be utilised by SUNDIALS for either CPU or GPU-accelerated solves.
Motivation
At present the PyBaMM IDAKLU solver makes use of SUNDIALS, which provides GPU support. This support cannot be harnessed, however, because IDAKLU's method of function evaluation relies on CasADi, which would require user-side compilation of PyBaMM models. Of PyBaMM's solvers, currently only JAX (a Python library) translates naturally to GPU. By providing an MLIR-based backend, the code required for expression evaluation can be lowered to CPU or GPU devices, allowing expression trees to be passed for evaluation at runtime. There have been previous efforts to expose GPU support for SUNDIALS in PyBaMM [e.g. #2644], which stalled because the model equations could not easily be lowered to device (GPU) code.
Possible Implementation
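MLIR (Multi-Level Intermediate Representation) is a compiler framework in which code is progressively lowered through a series of intermediate representations, ultimately reaching LLVM's Intermediate Representation (IR), which can then be compiled for multiple platforms, including GPU. This is also (broadly) the mechanism by which JAX supports multiple platforms.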
Additional context
https://github.com/pybamm-team/PyBaMM/issues/3766#issuecomment-1910627686
After discussion with @jsbrittain: we could possibly reuse the JAX backend to do JAX -> HLO -> XLA -> execute in C, using the C API for the XLA compiler (https://github.com/openxla/xla/tree/main/xla/examples/axpy). XLA uses MLIR anyway, I think!
The problem is always sparsity :) OpenXLA is planning to add sparsity support, but it is not there yet (https://github.com/openxla/stablehlo/blob/main/rfcs/20230210-sparsity.md)
JAX has sparse matrix support, so I wonder what sort of HLO it writes out for that?
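One way to look into this is to lower a small sparse matrix-vector product and read the emitted module text. The sketch below assumes the experimental `jax.experimental.sparse` BCOO format; the matrix, the vector, and the `matvec` function are made up for illustration and are not taken from any PyBaMM model.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

# Hypothetical 3x3 matrix with three non-zero entries.
dense = jnp.array([[2.0, 0.0, 0.0],
                   [0.0, 0.0, 3.0],
                   [0.0, 1.0, 0.0]])
M_sp = sparse.BCOO.fromdense(dense)  # experimental batched-COO format
x = jnp.array([1.0, 2.0, 3.0])


def matvec(M, v):
    # Sparse matrix times dense vector.
    return M @ v


# Lower without executing, then dump the MLIR module as text for inspection.
lowered = jax.jit(matvec).lower(M_sp, x)
sparse_hlo = lowered.as_text()

# Run the jitted function as a sanity check against the dense product.
y = jax.jit(matvec)(M_sp, x)
```

Printing `sparse_hlo` shows which operations JAX actually emits for the sparse product, which is the question raised above; the sketch makes no claim about what that text contains.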
JAX to HLO example: https://jax.readthedocs.io/en/latest/aot.html
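As a rough illustration of the ahead-of-time path described on that page, the following is a minimal sketch of lowering a jitted function to an MLIR module and then compiling it. The `rhs` function is a hypothetical stand-in for a model expression, not an actual PyBaMM right-hand side, and the exact dialect of the text output may depend on the JAX version.

```python
import jax
import jax.numpy as jnp


def rhs(y, t):
    # Hypothetical stand-in for a model expression (not a PyBaMM equation).
    return -0.5 * y + jnp.sin(t)


y0 = jnp.ones(4, dtype=jnp.float32)
t0 = jnp.float32(0.0)

# Stage 1: trace and lower for these argument shapes/dtypes, without running.
lowered = jax.jit(rhs).lower(y0, t0)

# The lowered computation as MLIR text (StableHLO in recent JAX releases).
hlo_text = lowered.as_text()

# Stage 2: compile with XLA for the default backend (CPU or GPU) and execute.
compiled = lowered.compile()
out = compiled(y0, t0)
```

The text in `hlo_text` is the kind of artefact that could, in principle, be handed to an XLA compiler outside Python; whether that fits the IDAKLU solver is the open question in this thread.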