
Add support for MLIR-based expression evaluation

jsbrittain opened this pull request 8 months ago · 5 comments

Description

Add a new expression evaluation backend to the IDAKLU solver. MLIR expression evaluation is now supported by lowering PyBaMM's Jax-based expressions into MLIR, which is then compiled and executed as part of the IDAKLU solver using IREE.

To enable the IREE/MLIR backend, set the (new) PYBAMM_IDAKLU_EXPR_IREE compiler flag to ON via an environment variable and install PyBaMM using the developer method (by default, PYBAMM_IDAKLU_EXPR_IREE is OFF):

export PYBAMM_IDAKLU_EXPR_IREE=ON
nox -e pybamm-requires && nox -e dev

Expression evaluation in IDAKLU is enabled by constructing the model using Jax expressions (model.convert_to_format="jax") and setting the solver backend (jax_evaluator="iree"). Example:

import pybamm
import numpy as np

model = pybamm.lithium_ion.SPM()
model.convert_to_format = "jax"
geometry = model.default_geometry
param = model.default_parameter_values
param.process_model(model)
param.process_geometry(geometry)
mesh = pybamm.Mesh(geometry, model.default_submesh_types, model.default_var_pts)
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
disc.process_model(model)

solver = pybamm.IDAKLUSolver(
	root_method="hybr",  # change from default ("casadi")
	options={"jax_evaluator": "iree"}
)
solution = solver.solve(model, np.linspace(0, 3600, 2500))

print(solution["Voltage [V]"].entries[:100])

Note that IREE currently only supports single-precision floating-point operations, so the model must be demoted from 64-bit to 32-bit precision before the solver can run. This is handled within the solver logic, but the demotion is performed in-place on the PyBaMM battery model (a warning is displayed when it runs). Operating at lower precision requires tolerances to be relaxed for convergence on larger (e.g. DFN) models, and introduces memory transfers and type casting in the solver which currently cause slow-downs (at least until 64-bit computation is natively supported).
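To illustrate why tolerances have to be relaxed after demotion, here is a minimal NumPy sketch (not PyBaMM code) of the 64-bit to 32-bit cast the solver performs: residuals that are resolvable at double precision fall below the resolution of float32 (machine epsilon around 1.2e-7), so a tight atol becomes unattainable.

```python
import numpy as np

# A value distinguishable from 1.0 at double precision only
x64 = np.array([1.0 + 1e-12], dtype=np.float64)

# Demote to single precision, analogous to the in-place demotion
# the IREE backend applies to the model
x32 = x64.astype(np.float32)

print(x64[0] - 1.0)               # non-zero: resolved in float64
print(x32[0] - np.float32(1.0))   # zero: lost below float32 resolution
```

Any solver tolerance tighter than roughly float32 epsilon therefore cannot be met once the expressions are evaluated at single precision.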

Comparative performance on the above SPM problem on an Apple M2 MacBook Pro (with events=[] to allow comparison with the JaxSolver):

  • IDAKLU-IREE: 1.4 secs (1.3 secs to demote and compile the expressions; <0.1 secs for each subsequent solve).
  • IDAKLU-Casadi: 0.14 secs (<0.1 secs setup; <0.1 secs for each subsequent solve).
  • JaxSolver [BDF]: 1.0 secs (0.9 secs compilation; 0.1 secs for each subsequent solve).

Substituting a DFN model (and relaxing atol to 1e-1), the times become:

  • IDAKLU-IREE: 12.1 secs (8.7 secs to demote and compile the expressions; 3.4 secs for each subsequent solve).
  • IDAKLU-Casadi: 0.6 secs (0.3 secs setup; 0.3 secs for each subsequent solve).
  • JaxSolver [BDF]: 22.8 secs (7.6 secs compilation; 15.2 secs for each subsequent solve).
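The timings above separate one-off setup/compilation cost from steady-state per-solve cost. A minimal harness along these lines can reproduce that split; `time_solve` is a hypothetical helper for illustration, not part of PyBaMM:

```python
import time

def time_solve(solve_fn, *args, repeats=3, **kwargs):
    """Time the first call (setup + compilation + first solve) and
    subsequent calls separately, returning (result, first, best_subsequent)."""
    t0 = time.perf_counter()
    result = solve_fn(*args, **kwargs)
    first = time.perf_counter() - t0

    subsequent = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        solve_fn(*args, **kwargs)
        subsequent.append(time.perf_counter() - t0)
    return result, first, min(subsequent)
```

For example, something like `time_solve(lambda: solver.solve(model, t_eval))` (with the solver and model set up as in the example above) would report the compile-dominated first call separately from the cheap repeat solves.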

There is a noticeable performance deficit for the IDAKLU-MLIR backend compared to Casadi, due to (1) initial compilation of MLIR to bytecode, (2) demotion strategies, and (3) memory transfers and type casting in the solver. We anticipate improvements on the second and third points once IREE natively supports 64-bit computation. Because our IREE approach compiles only the model expressions (not the solver), compilation times quickly out-perform the JaxSolver as model complexity and the number of time steps increase, while also taking full advantage of the capabilities already provided by the IDAKLU solver, such as events. The IREE/MLIR approach offers a pathway to compiling expressions for a wide variety of backends, including Metal and CUDA, although additional code adjustments (principally host/device transfers) will be required before those can be supported.

Resolves #3826

Type of change

Please add a line in the relevant section of CHANGELOG.md to document the change (include PR #) - note reverse order of PR #s. If necessary, also add to the list of breaking changes.

  • [X] New feature (non-breaking change which adds functionality)
  • [ ] Optimization (back-end change that speeds up the code)
  • [ ] Bug fix (non-breaking change which fixes an issue)

Key checklist:

  • [X] No style issues: $ pre-commit run (or $ nox -s pre-commit) (see CONTRIBUTING.md for how to set this up to run automatically when committing locally, in just two lines of code)
  • [X] All tests pass: $ python run-tests.py --all (or $ nox -s tests)
  • [X] The documentation builds: $ python run-tests.py --doctest (or $ nox -s doctests)

You can run integration tests, unit tests, and doctests together at once, using $ python run-tests.py --quick (or $ nox -s quick).

Further checks:

  • [X] Code is commented, particularly in hard-to-understand areas
  • [X] Tests added that prove fix is effective or that feature works

jsbrittain · Jun 19 '24 11:06