
Add support for MLIR-based expression evaluation

Open · jsbrittain opened this issue 1 year ago · 3 comments

Description

Provide an MLIR-based backend for expression evaluation that can be utilised by SUNDIALS for either CPU or GPU-accelerated solves.

Motivation

At present, the PyBaMM IDAKLU solver uses SUNDIALS, which supports GPUs. However, this support cannot be harnessed because IDAKLU's function evaluation relies on CasADi, and running those functions on a GPU would require user-side compilation of PyBaMM models. Of PyBaMM's solvers, only the JAX solver (built on the JAX Python library) currently translates naturally to GPU. An MLIR-based backend would allow the code required for expression evaluation to be lowered to CPU or GPU devices, so that expression trees could be passed for evaluation at runtime. There have been previous efforts to expose SUNDIALS GPU support in PyBaMM (e.g. #2644), but these stalled because the model equations could not easily be lowered to device (GPU) code.

Possible Implementation

MLIR (Multi-Level Intermediate Representation) lowers code in several stages, through progressively lower-level dialects, into LLVM's Intermediate Representation (IR), which can then be cross-compiled for multiple platforms, including GPUs. This is also, broadly, the mechanism by which JAX supports multiple platforms.
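To make the multi-level idea concrete from the JAX side, here is a minimal sketch that takes a toy right-hand-side function (a hypothetical stand-in for a PyBaMM expression tree, not PyBaMM code) and inspects two of its levels: the MLIR module produced by `jax.jit(...).lower(...)`, and the optimized HLO that XLA produces after `.compile()`. It assumes a recent JAX version, where `Lowered.as_text()` emits StableHLO MLIR by default.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy right-hand side standing in for a PyBaMM expression tree.
def rhs(y):
    return -0.5 * y + jnp.sin(y)

x = jnp.linspace(0.0, 1.0, 4, dtype=jnp.float32)

# Level 1: trace the function and lower it to an MLIR module
# (StableHLO in recent JAX versions).
lowered = jax.jit(rhs).lower(x)
mlir_text = lowered.as_text()

# Level 2: let XLA compile for the default backend (CPU, or GPU if one is
# available) and inspect the optimized HLO it produced.
compiled = lowered.compile()
hlo_text = compiled.as_text()

print(mlir_text)
print(hlo_text)
```

Below the HLO level, XLA's CPU and GPU backends generate machine code through LLVM. The same Python source therefore reaches different devices without user-side changes, which is the property this issue is after.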

Additional context

https://github.com/pybamm-team/PyBaMM/issues/3766#issuecomment-1910627686

jsbrittain, Feb 19 '24 12:02

After discussion with @jsbrittain: we could possibly reuse the JAX backend to go JAX -> HLO -> XLA -> execute in C, using the C API for the XLA compiler (https://github.com/openxla/xla/tree/main/xla/examples/axpy). XLA uses MLIR anyway, I think!
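A rough sketch of the Python half of this route follows, under several assumptions. The residual function and the matrix `A` are hypothetical stand-ins for a semi-discretised model. The file name `residual.mlir` is made up. The C/C++ driver that would parse the file and hand it to XLA is not shown. The model is lowered from shapes alone via `jax.ShapeDtypeStruct`, as in JAX's ahead-of-time workflow, and `lowered.compile()` stands in for the compile-and-execute steps the C side would perform.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical semi-discretised model: residual F(y) = A @ y + y**2,
# with a small dense matrix A standing in for a discretised operator.
A = jnp.array([[-2.0, 1.0, 0.0],
               [1.0, -2.0, 1.0],
               [0.0, 1.0, -2.0]], dtype=jnp.float32)

def residual(y):
    return A @ y + y**2

# Lower from shapes/dtypes only (no concrete data needed).
spec = jax.ShapeDtypeStruct((3,), jnp.float32)
lowered = jax.jit(residual).lower(spec)

# Save the MLIR (StableHLO) module as text. A C/C++ driver (assumed, not
# shown) would parse this file and pass it to XLA for compilation.
with open("residual.mlir", "w") as fh:  # hypothetical file name
    fh.write(lowered.as_text())

# Python stand-in for the C-side "compile and execute" steps.
compiled = lowered.compile()
y0 = np.array([0.1, 0.2, 0.3], dtype=np.float32)
out = compiled(y0)
print(out)
```

Exporting text keeps the handoff language-neutral. The C side then only needs to parse an MLIR module and never has to call back into Python.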

martinjrobins, Feb 24 '24 10:02

The problem is always sparsity :) OpenXLA is planning to add it, but it's not there yet (https://github.com/openxla/stablehlo/blob/main/rfcs/20230210-sparsity.md)

JAX has sparse matrix support, so I wonder what sort of HLO it writes out for that?
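One way to answer this empirically is sketched below, using JAX's experimental `jax.experimental.sparse.BCOO` format. A small tridiagonal, Jacobian-like matrix (a hypothetical example) multiplies a dense vector, and the lowered text is printed for inspection. Because BCOO stores `data` and `indices` as ordinary arrays, we would expect the product to appear on CPU as dense operations on those arrays (gathers, multiplies, scatter-adds or similar) rather than as a sparse tensor type. On GPU, JAX may emit a cuSPARSE custom call instead. This is an expectation to check against the printed output, not a guarantee.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

# A small tridiagonal, Jacobian-like matrix stored in JAX's BCOO sparse format.
dense = jnp.array([[2.0, -1.0, 0.0, 0.0],
                   [-1.0, 2.0, -1.0, 0.0],
                   [0.0, -1.0, 2.0, -1.0],
                   [0.0, 0.0, -1.0, 2.0]], dtype=jnp.float32)
M = sparse.BCOO.fromdense(dense)
v = jnp.arange(4, dtype=jnp.float32)

def matvec(M, v):
    return M @ v

# BCOO is a pytree (data + indices), so it can be passed straight through jit.
lowered = jax.jit(matvec).lower(M, v)
text = lowered.as_text()
y = jax.jit(matvec)(M, v)

print(M.nse)  # number of stored (non-zero) elements
print(text)   # inspect how the sparse product was expressed in the IR
```

Note that the sparsity pattern is baked into the shapes of `data` and `indices` at trace time, so a change in the pattern requires re-lowering.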

martinjrobins, Feb 24 '24 10:02

JAX-to-HLO example: https://jax.readthedocs.io/en/latest/aot.html

martinjrobins, Feb 24 '24 15:02