use an iterative solver (gmres) for jax_bdf_solver

Open · martinjrobins opened this issue 5 years ago · 10 comments

Once #1235 is done, the main slowdown for the jax_bdf_solver on the DFN model is the linear solve with the Jacobian. It currently uses an LU decomposition of the dense Jacobian, which is O(n^3). JAX is about to merge a new GMRES solver (https://github.com/google/jax/pull/4832), which we can use in the jax_bdf_solver.
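A minimal sketch of the idea, assuming a hypothetical `rhs` (a stand-in 1D diffusion problem, not a PyBaMM model) and a Newton system of the form (I - γJ)Δx = b as it appears in a BDF step. The Jacobian is never formed: `jax.jvp` supplies the product J·v, and `jax.scipy.sparse.linalg.gmres`, which accepts a function in place of a matrix, solves the system. The helper names `newton_operator` and `newton_step` are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres

# Use float64 so the iterative solve can be checked tightly against a dense solve.
jax.config.update("jax_enable_x64", True)


def rhs(y):
    # Hypothetical stiff test problem (not a PyBaMM model): 1D second-difference
    # diffusion with zero boundary values, plus a mild cubic nonlinearity.
    lap = -2.0 * y
    lap = lap.at[1:].add(y[:-1])
    lap = lap.at[:-1].add(y[1:])
    return lap - 0.1 * y**3


def newton_operator(y, gamma):
    # Returns v -> (I - gamma * J(y)) v. J(y) is never built: jax.jvp
    # differentiates rhs in forward mode to get J(y) @ v directly.
    def matvec(v):
        _, jv = jax.jvp(rhs, (y,), (v,))
        return v - gamma * jv

    return matvec


def newton_step(y, b, gamma):
    # Solve (I - gamma * J(y)) dx = b with restarted GMRES instead of a dense LU.
    dx, _ = gmres(newton_operator(y, gamma), b, tol=1e-10)
    return dx
```

Each GMRES iteration costs one forward-mode pass through `rhs` rather than work on an n-by-n matrix. For stiffer or larger systems, a preconditioner (the `M` argument of `gmres`) would likely be needed for fast convergence.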

martinjrobins, Nov 10 '20 14:11

Is it still the case that compiling takes much much longer than solving?

valentinsulzer, Nov 10 '20 16:11

Yup, compiling is very slow. My view of this solver is that it is mainly good for parameter estimation or inference, where you need to call the solver many, many times, so the initial compile time is not relevant.

martinjrobins, Nov 10 '20 16:11

Yep, makes sense. Is JAX going to support sparse Jacobians?

valentinsulzer, Nov 10 '20 17:11

No, JAX doesn't directly support sparse matrices. Instead, it automatically differentiates the rhs that you supply, giving Jacobian-vector products (jvps) and vector-Jacobian products (vjps) without ever forming the Jacobian. So my approach in #1235 was to manually implement a sparse matrix-vector product in JAX; JAX can then differentiate through it to give the Jacobian-vector products that the iterative solver referenced in this issue needs.
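The sketch below illustrates that approach under stated assumptions: the COO layout, `coo_matvec`, and the tridiagonal stand-in matrix are hypothetical and are not the code from #1235. Because the product is written as a gather followed by a scatter-add, JAX can differentiate through `rhs`: `jax.jvp` gives J·v (the product GMRES needs) and `jax.vjp` gives the transposed product `J.T @ u`.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)


def coo_matvec(rows, cols, vals, x, n_rows):
    # Sparse matrix-vector product for a matrix stored as COO triplets:
    # y[i] = sum of vals[k] * x[cols[k]] over all k with rows[k] == i.
    # Gather (x[cols]) and scatter-add (.at[rows].add) are both differentiable.
    return jnp.zeros(n_rows, dtype=x.dtype).at[rows].add(vals * x[cols])


def laplacian_coo(n):
    # Hypothetical stand-in for a discretization matrix: tridiagonal [1, -2, 1].
    i = np.arange(n)
    rows = np.concatenate([i, i[1:], i[:-1]])
    cols = np.concatenate([i, i[:-1], i[1:]])
    vals = np.concatenate([-2.0 * np.ones(n), np.ones(n - 1), np.ones(n - 1)])
    return jnp.asarray(rows), jnp.asarray(cols), jnp.asarray(vals)


N = 6
ROWS, COLS, VALS = laplacian_coo(N)


def rhs(y):
    # Hypothetical right-hand side: sparse linear operator plus a constant source.
    return coo_matvec(ROWS, COLS, VALS, y, N) + 1.0


def jacobian_vector_product(y, v):
    # Forward mode: J(y) @ v, with no Jacobian matrix ever built.
    return jax.jvp(rhs, (y,), (v,))[1]


def vector_jacobian_product(y, u):
    # Reverse mode: J(y).T @ u.
    _, pullback = jax.vjp(rhs, y)
    return pullback(u)[0]
```

Here the Jacobian of `rhs` is just the sparse matrix itself, so both products can be checked against a dense copy of it.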

I'm keen to get jax_bdf_solver working well with DFN. I've found 10x speedups (once it's compiled!) over the casadi solver for SPM, SPMe, and other small problems, but at the moment DFN is way slower, simply because the number of states is much larger, so the O(n^3) cost of the dense LU factorization starts to dominate.
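For a rough sense of the scaling (standard flop counts, not measurements from PyBaMM): with n states, a dense LU factorization grows cubically, while k iterations of unrestarted GMRES cost k Jacobian-vector products plus the Arnoldi orthogonalization. For a discretized PDE such as DFN, nnz(J) grows roughly linearly in n, so if GMRES converges in k ≪ n iterations (which in practice usually depends on a good preconditioner), the second line grows far more slowly than the first.

```latex
\begin{aligned}
\text{dense LU:}\quad & \tfrac{2}{3}n^{3} \text{ flops to factor}
  \;+\; 2n^{2} \text{ flops per pair of triangular solves} \\
\text{GMRES, } k \text{ iterations:}\quad & k \cdot O\big(\mathrm{nnz}(J)\big)
  \;+\; O\big(k^{2} n\big)
\end{aligned}
```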

martinjrobins, Nov 10 '20 18:11

Yeah, it would be great to get it working for DFN! Note that I recently fixed some overheads in the casadi solver for SPM. There is now solution.integration_time, which tells you how long just the integration step took. The integration step for SPM is ~600 µs. What is it for JAX?

valentinsulzer, Nov 10 '20 18:11

Nice :) Is this in the develop branch, or elsewhere? Do you have a benchmark script handy that I can base mine on?

martinjrobins, Nov 10 '20 20:11

It's in the develop branch now, but I don't have a benchmark script. I've just been running the DFN.py example and printing solution.solve_time and solution.integration_time, replacing the model with SPM or SPMe.

valentinsulzer, Nov 11 '20 03:11

Here are the timings for issue-1235-jax-sparse-mat-vec, which is branched off develop. I modified the DFN.py script as you suggested and took the times from the second solve, so that JAX compilation time isn't included:

| Model (no events) | Solver | Solve time (s) | Solution time (s) |
|---|---|---|---|
| SPM | casadi | 0.001014654990285635 | 0.000488329998916015 |
| SPM | jax (Runge-Kutta) | 0.0007341720047406852 | 8.744200749788433e-05 |
| SPM | jax (BDF) | 0.0007417150045512244 | 8.520600385963917e-05 |
| SPMe | casadi | 0.0035160900006303564 | 0.0022497270110761747 |
| SPMe | jax (BDF) | 0.0030481529975077137 | 0.00021930199000053108 |

UPDATE: timings updated as per commit 39b9fca4b
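A minimal sketch of that measurement pattern, assuming a hypothetical jitted `integrate` function (explicit Euler via `jax.lax.scan`) in place of the PyBaMM solver: the first call pays for tracing and XLA compilation, while the second reuses the cached executable. Because JAX dispatches work asynchronously, `block_until_ready()` is called before the clock is stopped.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def integrate(y0, dt):
    # Hypothetical stand-in for a solve: 1000 explicit Euler steps of dy/dt = -y.
    def step(y, _):
        return y - dt * y, None

    y, _ = jax.lax.scan(step, y0, None, length=1000)
    return y


def timed_call(f, *args):
    # Wall-clock time of one call, waiting for the asynchronous result to finish.
    start = time.perf_counter()
    out = f(*args)
    out.block_until_ready()
    return out, time.perf_counter() - start


if __name__ == "__main__":
    y0 = jnp.ones(100)
    _, t_first = timed_call(integrate, y0, 1e-3)  # includes tracing + compilation
    _, t_second = timed_call(integrate, y0, 1e-3)  # reuses the compiled executable
    print(f"first call: {t_first:.4f} s, second call: {t_second:.6f} s")
```

The same caching is what makes the compile cost acceptable for parameter estimation: repeated calls with inputs of the same shapes and dtypes reuse the compiled executable.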

martinjrobins, Nov 11 '20 06:11

Very nice!

valentinsulzer, Nov 11 '20 13:11

Note to self: I also need to change the mass matrix input to a function that computes the action of the mass matrix on a vector, so that the solver is entirely matrix-free.
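A minimal sketch of what such a matrix-free interface could look like, assuming a hypothetical `mass_action` callable and a toy semi-explicit DAE (not PyBaMM's model format). The BDF Newton system (M - γJ)Δx = b is expressed purely as a function of v, built from `mass_action(v)` and a `jax.jvp` Jacobian-vector product, and passed to `jax.scipy.sparse.linalg.gmres`. The zero entry in the toy mass matrix stands for an algebraic equation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres

jax.config.update("jax_enable_x64", True)

N = 10
# Hypothetical diagonal mass matrix: N-1 differential states, 1 algebraic state.
MASS_DIAG = jnp.concatenate([jnp.ones(N - 1), jnp.zeros(1)])


def mass_action(v):
    # Action of the mass matrix on a vector; no matrix is stored or passed around.
    return MASS_DIAG * v


def rhs(y):
    # Hypothetical semi-explicit DAE: diffusion for y[:-1] and the algebraic
    # constraint 0 = y[-1] - sum(y[:-1]) for the last state.
    u = y[:-1]
    du = -2.0 * u
    du = du.at[1:].add(u[:-1])
    du = du.at[:-1].add(u[1:])
    alg = y[-1] - jnp.sum(u)
    return jnp.concatenate([du, alg[None]])


def newton_operator(y, gamma):
    # v -> (M - gamma * J(y)) v, using only mass_action and a Jacobian-vector product.
    def matvec(v):
        _, jv = jax.jvp(rhs, (y,), (v,))
        return mass_action(v) - gamma * jv

    return matvec


def newton_step(y, b, gamma):
    # Matrix-free Newton solve with GMRES; neither M nor J is ever formed.
    dx, _ = gmres(newton_operator(y, gamma), b, tol=1e-10)
    return dx
```

Even though M is singular here, M - γJ is not, because the algebraic row picks up its entries from γJ; the iterative solver only ever sees the combined operator.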

martinjrobins, Nov 20 '20 06:11