[Bug]: IDAKLU Jax interface reorders inputs

Open · anoushka2000 opened this issue 10 months ago · 2 comments

PyBaMM Version

25.1.1

Python Version

3.12

Describe the bug

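The function wrapper for the JAX primitive (f(t, inputs) in the IDAKLUJax class) calls tree_flatten on the inputs dictionary, which reorders the inputs and leads to incorrect solutions. jax.tree_util.tree_flatten flattens a dict in sorted key order rather than insertion order, so the flattened input values no longer line up with the order the solver expects.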
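A minimal illustration, independent of PyBaMM, of how flattening the inputs dict reorders it. It uses the same keys as the reproduction below with placeholder values 1.0 to 5.0. The input_names list and the flatten_in_order helper are hypothetical and only sketch one order-preserving alternative. They are not part of the IDAKLUJax API, and the order IDAKLUJax actually expects is assumed here (insertion order), not checked.

```python
import jax

# Same keys as the `inputs` dict in the reproduction, in the same order.
# Values are placeholders chosen so the resulting order is easy to read.
inputs = {
    "Separator thickness [m]": 1.0,
    "Separator porosity": 2.0,
    "Positive electrode conductivity [S.m-1]": 3.0,
    "Negative electrode conductivity [S.m-1]": 4.0,
    "Current function [A]": 5.0,
}

# tree_flatten visits dict keys in sorted order, not insertion order.
leaves, treedef = jax.tree_util.tree_flatten(inputs)
print(leaves)  # [5.0, 4.0, 3.0, 2.0, 1.0]

# Hypothetical: the order in which the solver expects its input values.
# Assumed here to be the insertion order of `inputs`.
input_names = list(inputs.keys())


def flatten_in_order(values, names):
    """Hypothetical helper: build the input vector in an explicit key order."""
    return [values[name] for name in names]


print(flatten_in_order(inputs, input_names))  # [1.0, 2.0, 3.0, 4.0, 5.0]
```

With these key names, sorted order puts four of the five values in a different position from insertion order; only the positive electrode conductivity stays in place. Any consumer that assumes insertion order would therefore pair those four values with the wrong parameters.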

Steps to Reproduce

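import os

# Set before JAX is imported (pybamm may import it) so the variable is picked up
os.environ["JAX_TRACEBACK_FILTERING"] = "off"

import numpy as np
import pybamm
from solver_wrapper import IDAKLUJax

pybamm.set_logging_level("DEBUG")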

param = pybamm.ParameterValues("OKane2022")

model = pybamm.lithium_ion.SPMe()
geometry = model.default_geometry

var = pybamm.standard_spatial_vars
var_pts = {var.x_n: 20, var.x_s: 20, var.x_p: 20, var.r_n: 10, var.r_p: 10}
submesh_types = model.default_submesh_types

t_eval = np.linspace(0, 3600, 3600)
output_variables = ["Voltage [V]"]

inputs = {
    "Separator thickness [m]": param["Separator thickness [m]"],
    "Separator porosity": param["Separator porosity"],
    "Positive electrode conductivity [S.m-1]": param[
        "Positive electrode conductivity [S.m-1]"
    ],
    "Negative electrode conductivity [S.m-1]": param[
        "Negative electrode conductivity [S.m-1]"
    ],
    "Current function [A]": 5.0,
}

param.update(inputs)
param.process_geometry(geometry)
param.update({key: "[input]" for key in inputs.keys()})

param.process_model(model, inplace=True)

mesh = pybamm.Mesh(geometry, submesh_types, var_pts)
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)

disc.process_model(model, inplace=True)

solver = pybamm.IDAKLUSolver(
    rtol=1e-6,
    atol=1e-6,
    output_variables=output_variables,
)

jax_solver = IDAKLUJax(
    solver,
    model,
    t_eval,
    output_variables=output_variables,
    calculate_sensitivities=True,
    t_interp=None,
)

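# Evaluate the JAX expression for all output variables
f = jax_solver.get_jaxpr()
data = f(t_eval, inputs)

# Evaluate a single output variable through the same wrapper
voltage = jax_solver.get_var("Voltage [V]")(t_eval, inputs)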

Relevant log output


anoushka2000 · Feb 14 '25 18:02

@BradyPlanden Could you take a look at this?

kratman · Feb 14 '25 18:02

Yep, I can take a look at this.

BradyPlanden · Feb 20 '25 10:02