PyBaMM
[Bug]: IDAKLU Jax interface reorders inputs
PyBaMM Version
25.1.1
Python Version
3.12
Describe the bug
The function wrapper for the JAX primitive function (f(t, inputs) in the IDAKLUJax class) uses tree_flatten, which reorders the inputs (JAX flattens dicts in sorted-key order, not insertion order) and leads to incorrect solutions.
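For illustration, here is a minimal sketch of the reordering on its own, outside PyBaMM. It assumes only that JAX is installed. The keys match the reproduction script below, but the values are placeholders chosen so the ordering is easy to see. It does not reproduce the IDAKLUJax wrapper itself.

```python
import jax

# Same keys as the reproduction script, in insertion order.
# The values are placeholders (not real parameter values).
inputs = {
    "Separator thickness [m]": 1.0,
    "Separator porosity": 2.0,
    "Positive electrode conductivity [S.m-1]": 3.0,
    "Negative electrode conductivity [S.m-1]": 4.0,
    "Current function [A]": 5.0,
}

# tree_flatten returns the dict's leaves ordered by sorted key.
leaves, treedef = jax.tree_util.tree_flatten(inputs)

insertion_order = list(inputs.values())
sorted_key_order = [inputs[key] for key in sorted(inputs)]

print("insertion order :", insertion_order)
print("flattened leaves:", leaves)
```

Any code that consumes the flattened leaves positionally, while assuming they follow the dict's insertion order, would therefore pair values with the wrong parameters.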
Steps to Reproduce
import os
import numpy as np
import pybamm
from solver_wrapper import IDAKLUJax
os.environ["JAX_TRACEBACK_FILTERING"]="off"
pybamm.set_logging_level("DEBUG")
param = pybamm.ParameterValues("OKane2022")
model = pybamm.lithium_ion.SPMe()
geometry = model.default_geometry
var = pybamm.standard_spatial_vars
var_pts = {var.x_n: 20, var.x_s: 20, var.x_p: 20, var.r_n: 10, var.r_p: 10}
submesh_types = model.default_submesh_types
t_eval = np.linspace(0, 3600, 3600)
output_variables = ["Voltage [V]"]
inputs = {
    "Separator thickness [m]": param["Separator thickness [m]"],
    "Separator porosity": param["Separator porosity"],
    "Positive electrode conductivity [S.m-1]": param[
        "Positive electrode conductivity [S.m-1]"
    ],
    "Negative electrode conductivity [S.m-1]": param[
        "Negative electrode conductivity [S.m-1]"
    ],
    "Current function [A]": 5.0,
}
param.update(inputs)
param.process_geometry(geometry)
param.update({key: "[input]" for key in inputs.keys()})
param.process_model(model, inplace=True)
mesh = pybamm.Mesh(geometry, submesh_types, var_pts)
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
disc.process_model(model, inplace=True)
solver = pybamm.IDAKLUSolver(
    rtol=1e-6,
    atol=1e-6,
    output_variables=output_variables,
)
jax_solver = IDAKLUJax(
    solver,
    model,
    t_eval,
    output_variables=output_variables,
    calculate_sensitivities=True,
    t_interp=None,
)
f = jax_solver.get_jaxpr()
data = f(t_eval, inputs)
voltage = jax_solver.get_var("Voltage [V]")(t_eval, inputs)
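One possible way to avoid depending on flattening order is to build the value list from an explicit key order. The sketch below is pure Python and is not PyBaMM's implementation. The helper name flatten_inputs_in_order and the expected_keys argument are hypothetical, and the order the solver actually expects is an assumption that would need to come from the model.

```python
def flatten_inputs_in_order(inputs, expected_keys):
    """Return the values of `inputs` ordered by `expected_keys`.

    Hypothetical helper: `expected_keys` stands in for whatever order
    the solver expects; it is not taken from the PyBaMM API.
    """
    missing = [key for key in expected_keys if key not in inputs]
    unexpected = [key for key in inputs if key not in expected_keys]
    if missing or unexpected:
        raise KeyError(f"missing={missing}, unexpected={unexpected}")
    return [inputs[key] for key in expected_keys]


# Placeholder values; the keys are a subset of those used above.
example_inputs = {
    "Separator porosity": 0.5,
    "Current function [A]": 5.0,
}
example_order = ["Separator porosity", "Current function [A]"]
ordered_values = flatten_inputs_in_order(example_inputs, example_order)
```

Raising on missing or unexpected keys keeps a mismatch between the dict and the expected order from silently producing a wrong solution.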
Relevant log output
@BradyPlanden Could you take a look at this?
Yep, I can take a look at this.