
refactor multiprocessing and multiple inputs

Open martinjrobins opened this issue 9 months ago • 7 comments

Description

Refactor the current multiprocessing implementation to push this responsibility onto the solver.

Related issues: #3987, #3713, #2645, #2644, #3910

Motivation

At the moment, users can request multiple parallel solves by passing in a list of inputs:

    solver.solve(t_eval, inputs=[{'a': 1}, {'a': 2}])

For all solvers apart from the Jax solver, this uses the Python multiprocessing library to run the solves in parallel. However, this causes issues on Windows (#3987), and it could probably be made much more efficient by handling the parallelism at the solver level rather than up in Python.
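The per-input fan-out can be sketched as a parallel map over the inputs list (a simplification: the real code uses `multiprocessing`, and `solve_one` here is a hypothetical stand-in for `self._integrate` using an analytic decay model):

```python
import math
from concurrent.futures import ThreadPoolExecutor

def solve_one(t_eval, inputs):
    # Hypothetical stand-in for a single solve: y' = -a*y with y(0) = 1,
    # evaluated analytically at each requested time point.
    a = inputs["a"]
    return [math.exp(-a * t) for t in t_eval]

def solve_all(t_eval, inputs_list):
    # One task per parameter set, mirroring how the current implementation
    # maps self._integrate over model_inputs_list.
    with ThreadPoolExecutor() as pool:
        return list(pool.map(lambda inp: solve_one(t_eval, inp), inputs_list))

solutions = solve_all([0.0, 1.0], [{"a": 1.0}, {"a": 2.0}])
```

Threads are used here only to keep the sketch self-contained; PyBaMM uses processes, which is precisely what breaks on Windows, where the `spawn` start method requires picklable work items.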

Possible Implementation

I would propose that when a list of inputs of size n is detected, the model equations are duplicated n times before being passed to the solver. All the solvers we use have built-in parallel functionality that can be utilised to run this in parallel, and once we have a decent GPU solver this could be done on the GPU. After the solve is finished, the larger state vector is split apart and passed to multiple Solution objects.
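As a rough illustration of the duplication idea (toy code, not PyBaMM's API): for the scalar decay model y' = -a*y, n input sets can be stacked into one n-dimensional system, integrated in a single call, and split back into per-input trajectories afterwards.

```python
def rhs_stacked(y, params):
    # n copies of y' = -a*y stacked into one state vector, one entry
    # per input set, as the proposal duplicates the model equations.
    return [-p["a"] * yi for p, yi in zip(params, y)]

def solve_stacked(t_eval, inputs_list, steps_per_interval=100):
    # Fixed-step RK4 on the stacked system; a real solver would do one
    # (internally parallel) integration over the whole state vector.
    n = len(inputs_list)
    y = [1.0] * n  # y(0) = 1 for every copy
    ys = [list(y)]
    for t0, t1 in zip(t_eval, t_eval[1:]):
        h = (t1 - t0) / steps_per_interval
        for _ in range(steps_per_interval):
            k1 = rhs_stacked(y, inputs_list)
            k2 = rhs_stacked([yi + h / 2 * k for yi, k in zip(y, k1)], inputs_list)
            k3 = rhs_stacked([yi + h / 2 * k for yi, k in zip(y, k2)], inputs_list)
            k4 = rhs_stacked([yi + h * k for yi, k in zip(y, k3)], inputs_list)
            y = [yi + h / 6 * (a + 2 * b + 2 * c + d)
                 for yi, a, b, c, d in zip(y, k1, k2, k3, k4)]
        ys.append(list(y))
    # Split the stacked trajectory back into one solution per input set.
    return [[row[i] for row in ys] for i in range(n)]

per_input = solve_stacked([0.0, 1.0], [{"a": 1.0}, {"a": 2.0}])
```

The solver sees one larger problem, so whatever internal parallelism (OpenMP threads, GPU kernels, vectorised Jacobians) it has applies automatically, with no Python-level process pool.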

Additional context

At the moment, this snippet captures how multiprocessing is implemented:

    ninputs = len(model_inputs_list)
    if ninputs == 1:
        new_solution = self._integrate(
            model,
            t_eval[start_index:end_index],
            model_inputs_list[0],
        )
        new_solutions = [new_solution]
    else:
        if model.convert_to_format == "jax":
            # Jax can parallelize over the inputs efficiently
            new_solutions = self._integrate(
                model,
                t_eval[start_index:end_index],
                model_inputs_list,
            )
        else:
            with mp.get_context(self._mp_context).Pool(processes=nproc) as p:
                new_solutions = p.starmap(
                    self._integrate,
                    zip(
                        [model] * ninputs,
                        [t_eval[start_index:end_index]] * ninputs,
                        model_inputs_list,
                    ),
                )
                p.close()
                p.join()
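Under the proposed refactor, the Pool branch above would disappear: the solver receives one stacked problem and the wrapper only has to split the result. A minimal sketch of that splitting step (names hypothetical; a real implementation would build pybamm Solution objects rather than nested lists):

```python
def split_batched_state(y_stacked, ninputs):
    # y_stacked: rows are time points; each row holds the n duplicated
    # copies of the state laid side by side (nstates values per copy).
    nstates = len(y_stacked[0]) // ninputs
    return [
        [row[i * nstates:(i + 1) * nstates] for row in y_stacked]
        for i in range(ninputs)
    ]

# Two time points, two inputs, two states per input.
batched = [[1, 2, 3, 4], [5, 6, 7, 8]]
per_input = split_batched_state(batched, 2)
```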

martinjrobins avatar May 14 '24 11:05 martinjrobins