PyBaMM
Use an iterative solver (GMRES) for jax_bdf_solver
Once #1235 is done, the main bottleneck in jax_bdf_solver for the DFN model will be the linear solve with the Jacobian. It currently uses an LU decomposition of the dense Jacobian, which is O(n^3) in the number of states n. JAX is about to merge a new GMRES solver (https://github.com/google/jax/pull/4832), so we can use that in jax_bdf_solver instead.
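For illustration, here is a minimal sketch of swapping a dense solve for GMRES using `jax.scipy.sparse.linalg.gmres`. The tridiagonal test matrix, the sizes and the tolerances are made up for the example and are not taken from jax_bdf_solver; the point is only that GMRES needs a function computing `A @ v`, not the matrix itself.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres

# Double precision so the tight tolerance below is reachable
jax.config.update("jax_enable_x64", True)

n = 50
# Hypothetical, well-conditioned tridiagonal matrix standing in for the Jacobian
A = 4.0 * jnp.eye(n) - jnp.eye(n, k=1) - jnp.eye(n, k=-1)
b = jnp.ones(n)

# Dense direct solve (LU under the hood): O(n^3)
x_dense = jnp.linalg.solve(A, b)

# GMRES only needs the action of A on a vector
def matvec(v):
    return A @ v

x_gmres, _ = gmres(matvec, b, tol=1e-10, restart=n, maxiter=10)

residual = jnp.linalg.norm(matvec(x_gmres) - b)
print("GMRES residual norm:", residual)
```

Here `matvec` still multiplies by a dense matrix for simplicity; in a real solver it would be replaced by a sparse or automatically differentiated product so the dense Jacobian is never formed.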
Is it still the case that compiling takes much much longer than solving?
Yup, compiling is very slow. My thinking is that this solver is mainly good for parameter estimation or inference, where you need to call the solver many, many times, so the initial compile time is not relevant.
Yep, makes sense. Is JAX going to allow sparse Jacobians?
No, JAX doesn't directly support sparse matrices. Instead, it implements vjps (vector-Jacobian products) based on the rhs that you supply. These vjps are calculated automatically via automatic differentiation. So my approach in #1235 was to manually implement a sparse matrix-vector product in JAX, so that JAX can calculate the vjp from that, which is then used by the iterative solver referenced in this issue.
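A rough sketch of that idea, not the actual code from #1235: a hand-written sparse matrix-vector product in COO form, which JAX can then differentiate with `jax.vjp` (and `jax.jvp`). The 3x3 matrix and the names `rows`, `cols`, `data` and `sparse_matvec` are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical sparse matrix in COO form, equal to
# [[2, 0, 1],
#  [0, 3, 0],
#  [4, 0, 5]]
rows = jnp.array([0, 0, 1, 2, 2])
cols = jnp.array([0, 2, 1, 0, 2])
data = jnp.array([2.0, 1.0, 3.0, 4.0, 5.0])
n = 3

def sparse_matvec(v):
    # y[i] = sum over stored entries k with rows[k] == i of data[k] * v[cols[k]]
    return jnp.zeros(n).at[rows].add(data * v[cols])

v = jnp.array([1.0, 2.0, 3.0])
w = jnp.array([1.0, 0.0, 1.0])

# Forward product A @ v
y = sparse_matvec(v)

# Vector-Jacobian product w^T A, obtained by reverse-mode autodiff
_, vjp_fn = jax.vjp(sparse_matvec, v)
(wT_A,) = vjp_fn(w)

# Jacobian-vector product A @ w, obtained by forward-mode autodiff
_, A_w = jax.jvp(sparse_matvec, (v,), (w,))

print(y, wT_A, A_w)
```

Only the nonzero entries are stored and touched, so the cost of each product scales with the number of nonzeros rather than n^2.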
I'm keen to get jax_bdf_solver working well with DFN. I've found 10x speedups (once you compile it!) over the casadi solver for SPM, SPMe and other small problems, but at the moment DFN is waaaaaaay slower simply because the number of states is much larger, so the O(n^3) starts to dominate.
Yeah would be great to get it working for DFN!
Note I recently fixed some overheads in the casadi solver for SPM. There is now solution.integration_time, which tells you how long just the integration step took. The integration step for SPM is ~600 µs. What is it for JAX?
Nice :) Is this in the develop branch, or elsewhere? Do you have a benchmark script handy that I can base mine on?
It's in the develop branch now, but I don't have a benchmark script. I've just been running the DFN.py example and printing solution.solve_time and solution.integration_time, replacing the model with SPM or SPMe.
Here are the timings (in seconds) for issue-1235-jax-sparse-mat-vec, which is branched off develop. I just modified the DFN.py script as you said, and took the times from the 2nd solve so that the JAX compilation time isn't included:
SPM (no events) with casadi solver: solve time: 0.001014654990285635 solution time: 0.000488329998916015
SPM (no events) with jax solver (runge-kutta): solve time: 0.0007341720047406852 solution time: 8.744200749788433e-05
SPM (no events) with jax solver (bdf): solve time: 0.0007417150045512244 solution time: 8.520600385963917e-05
SPMe (no events) with casadi: solve time: 0.0035160900006303564 solution time: 0.0022497270110761747
SPMe (no events) with jax (BDF): solve time: 0.0030481529975077137 solution time: 0.00021930199000053108
UPDATE: timings updated as per commit 39b9fca4b
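The "time the 2nd solve" approach used for the timings above can be sketched in plain JAX as follows. The `step` function is a hypothetical stand-in for a compiled solve, not PyBaMM code; `block_until_ready()` is needed because JAX dispatches work asynchronously.

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def step(y):
    # Hypothetical stand-in for a solve: 50 explicit Euler steps of dy/dt = -y
    for _ in range(50):
        y = y + 0.01 * (-y)
    return y

def timed(f, *args):
    # Wait for the result so the measurement covers the actual computation
    t0 = time.perf_counter()
    out = f(*args).block_until_ready()
    return out, time.perf_counter() - t0

y0 = jnp.ones(100)

# First call traces and compiles the function, then runs it
out_first, t_first = timed(step, y0)
# Second call reuses the compiled code, so it measures run time only
out_second, t_second = timed(step, y0)

print(f"first call (compile + run): {t_first:.6f} s")
print(f"second call (run only):     {t_second:.6f} s")
```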
Very nice!
Note to self: I also need to change the mass matrix input to a function that calculates the action of the mass matrix on a vector, to make the solver entirely matrix-free.
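A minimal sketch of what "entirely matrix-free" could look like, assuming the Newton iteration in a BDF step solves a linear system with an operator of the form M - c*J, where c is a step-size-dependent scalar. The `rhs`, `mass_diag`, `c` and the function names are all hypothetical and are not the jax_bdf_solver interface.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres

jax.config.update("jax_enable_x64", True)

def rhs(y):
    # Hypothetical right-hand side: decay with coupling to the neighbouring state
    return -2.0 * y + 0.5 * jnp.roll(y, 1)

# Hypothetical diagonal mass matrix; the zero entry mimics an algebraic equation
mass_diag = jnp.array([1.0, 1.0, 2.0, 0.0])

def mass_action(v):
    # Action of the mass matrix on a vector, without storing M as a matrix
    return mass_diag * v

y = jnp.array([1.0, 0.5, 0.25, 0.125])
c = 0.1  # assumed scalar, proportional to the step size

def newton_operator(v):
    # (M - c*J) @ v, with J @ v computed by forward-mode autodiff of rhs
    _, Jv = jax.jvp(rhs, (y,), (v,))
    return mass_action(v) - c * Jv

b = jnp.array([1.0, 2.0, 3.0, 4.0])
x, _ = gmres(newton_operator, b, tol=1e-12, restart=4, maxiter=5)

# For checking only: assemble the dense operator column by column
dense_op = jax.vmap(newton_operator)(jnp.eye(4)).T
x_ref = jnp.linalg.solve(dense_op, b)
print(x, x_ref)
```

Neither the mass matrix nor the Jacobian is ever formed in the GMRES path; the dense assembly at the end exists only to verify the result on this tiny example.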