
Add support for MLIR-based expression evaluation

jsbrittain opened this pull request • 5 comments

Description

Add a new expression evaluation backend to the IDAKLU solver. MLIR expression evaluation is supported by lowering PyBaMM's Jax-based expressions into MLIR; the resulting modules are then compiled and executed as part of the IDAKLU solver using IREE.

To enable the IREE/MLIR backend, set the (new) PYBAMM_IDAKLU_EXPR_IREE compiler flag to ON via an environment variable and install PyBaMM using the developer method (PYBAMM_IDAKLU_EXPR_IREE is OFF by default):

export PYBAMM_IDAKLU_EXPR_IREE=ON
nox -e pybamm-requires && nox -e dev

Expression evaluation in IDAKLU is enabled by constructing the model using Jax expressions (model.convert_to_format="jax") and setting the solver backend (jax_evaluator="iree"). Example:

import pybamm
import numpy as np

model = pybamm.lithium_ion.SPM()
model.convert_to_format = "jax"
geometry = model.default_geometry
param = model.default_parameter_values
param.process_model(model)
param.process_geometry(geometry)  # process the same geometry object passed to Mesh below
mesh = pybamm.Mesh(geometry, model.default_submesh_types, model.default_var_pts)
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
disc.process_model(model)

solver = pybamm.IDAKLUSolver(
	root_method="hybr",  # change from default ("casadi")
	options={"jax_evaluator": "iree"}
)
solution = solver.solve(model, np.linspace(0, 3600, 2500))

print(solution["Voltage [V]"].entries[:100])

Note that IREE currently only supports single-precision floating-point operations, which requires the model to be demoted from 64-bit to 32-bit precision before the solver can run. This is handled within the solver logic, but the operation is performed in-place on the PyBaMM battery model (a warning is displayed when this happens). Operating at lower precision requires tolerances to be relaxed for convergence on larger [e.g. DFN] models, and leads to memory transfers and type casting in the solver which currently cause slow-downs (at least until 64-bit computation is natively supported).
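As a minimal, self-contained illustration of why demotion forces the tolerances to be relaxed (using only NumPy, not the solver itself): float32 carries roughly 7 decimal digits, so any quantity below its machine epsilon simply vanishes when a 64-bit value is cast down:

```python
import numpy as np

# float32 machine epsilon is ~1.2e-7, versus ~2.2e-16 for float64,
# so solver tolerances tighter than ~1e-7 are not meaningful after demotion
eps32 = np.finfo(np.float32).eps
eps64 = np.finfo(np.float64).eps

x64 = np.float64(1.0) + np.float64(1e-10)  # representable in 64-bit
x32 = np.float32(x64)                      # demotion discards the 1e-10

assert x64 > 1.0
assert x32 == np.float32(1.0)
assert eps32 > 1e-8 and eps64 < 1e-15
```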

Comparative performance on the above SPM problem on an Apple M2 MacBook Pro (with events=[] to allow comparison with the JaxSolver):

  • IDAKLU-IREE: 1.4 secs (1.3 secs to demote and compile the expressions; <0.1 secs for each subsequent solve).
  • IDAKLU-Casadi: 0.14 secs (<0.1 secs setup; <0.1 secs for each subsequent solve).
  • JaxSolver [BDF]: 1.0 secs (0.9 secs compilation; 0.1 secs for each subsequent solve).

Substituting a DFN model (and relaxing atol to 1e-1), the times become:

  • IDAKLU-IREE: 12.1 secs (8.7 secs to demote and compile the expressions; 3.4 secs for each subsequent solve).
  • IDAKLU-Casadi: 0.6 secs (0.3 secs setup; 0.3 secs for each subsequent solve).
  • JaxSolver [BDF]: 22.8 secs (7.6 secs compilation; 15.2 secs for each subsequent solve).

There is a noticeable performance deficit for the IDAKLU-IREE solver compared to Casadi, due to 1) initial compilation of MLIR to bytecode, 2) demotion strategies, and 3) memory transfers and type casting in the solver. We anticipate improvements on the second and third points once IREE natively supports 64-bit computation. Because our IREE approach compiles only the model expressions (not the solver), compilation times quickly out-perform the JaxSolver as model complexity and the number of time steps grow, while also taking full advantage of the capabilities already provided by the IDAKLU solver, such as events. The IREE/MLIR approach offers a pathway to compiling expressions for a wide variety of backends, including Metal and CUDA, although additional code adjustments (principally host/device transfers) will be required before those can be supported.
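The compile-once, solve-many pattern behind these numbers can be sketched with a small (hypothetical) timing helper; `solver` and `model` stand in for any solver/model pair set up as in the example above, and the first call is expected to absorb the one-off compilation/setup cost:

```python
import time

def time_solves(solver, model, t_eval, repeats=3):
    """Time repeated solves: the first call typically includes
    compilation/setup, later calls measure the steady-state solve cost."""
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        solver.solve(model, t_eval)
        timings.append(time.perf_counter() - start)
    return timings
```

Comparing `timings[0]` against `timings[1:]` separates the compilation overhead from the per-solve cost reported above.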

Resolves #3826

Type of change

Please add a line in the relevant section of CHANGELOG.md to document the change (include PR #) - note reverse order of PR #s. If necessary, also add to the list of breaking changes.

  • [X] New feature (non-breaking change which adds functionality)
  • [ ] Optimization (back-end change that speeds up the code)
  • [ ] Bug fix (non-breaking change which fixes an issue)

Key checklist:

  • [X] No style issues: $ pre-commit run (or $ nox -s pre-commit) (see CONTRIBUTING.md for how to set this up to run automatically when committing locally, in just two lines of code)
  • [X] All tests pass: $ python run-tests.py --all (or $ nox -s tests)
  • [X] The documentation builds: $ python run-tests.py --doctest (or $ nox -s doctests)

You can run integration tests, unit tests, and doctests together at once, using $ python run-tests.py --quick (or $ nox -s quick).

Further checks:

  • [X] Code is commented, particularly in hard-to-understand areas
  • [X] Tests added that prove fix is effective or that feature works

jsbrittain avatar Jun 19 '24 11:06 jsbrittain

Note: Codacy seems to be struggling with the template / inheritance structures in C++, hence the inclusion of additional // cppcheck-suppress comments in the IDAKLU solver.

jsbrittain avatar Jun 19 '24 13:06 jsbrittain

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 99.56%. Comparing base (ab7348f) to head (d574d8c). Report is 222 commits behind head on develop.

Additional details and impacted files
@@            Coverage Diff            @@
##           develop    #4199    +/-   ##
=========================================
  Coverage    99.55%   99.56%            
=========================================
  Files          288      288            
  Lines        21897    22086   +189     
=========================================
+ Hits         21800    21989   +189     
  Misses          97       97            


codecov[bot] avatar Jun 19 '24 15:06 codecov[bot]

cc'ing @cringeyburger here because this PR adds and modifies a lot of compiled code; we will need to make a few adjustments in the migration to scikit-build-core as needed. Though, as long as the wheel builds pass (@jsbrittain, could you please trigger them on your fork?), it should be fine.

@agriyakhetarpal Wheels build fine: https://github.com/jsbrittain/PyBaMM/actions/runs/9665563143

jsbrittain avatar Jun 25 '24 17:06 jsbrittain

Thanks @jsbrittain, this looks excellent. I've made a few suggestions below, see what you think. Also, is jax_evaluator="iree" needed in the options? Is there a case where you want to convert to jax but not use the jax_evaluator?

@martinjrobins yes, there is actually an existing python-idaklu interface that will run if we don't redirect using the (new) jax_evaluator option. I think it's a legacy item (idaklu/python.cpp), and it can be quite slow, even on these toy examples.

jsbrittain avatar Jun 27 '24 13:06 jsbrittain

That reminds me, we should get rid of python-idaklu; I don't think it serves a useful purpose anymore. I'll add an issue.

martinjrobins avatar Jun 27 '24 14:06 martinjrobins

Hi again, @jsbrittain – https://github.com/pybamm-team/PyBaMM/pull/4205 is trying to migrate PyBaMM's package structure to an src layout from the current flat one, which will move all the files from pybamm/ into src/pybamm/. It has a lot of potential for delays because of inevitable merge conflicts across several PRs, so we plan to merge it as soon as stable v24.5 hits next week, just a heads up :)

agriyakhetarpal avatar Jul 12 '24 13:07 agriyakhetarpal