PyBaMM
refactor multiprocessing and multiple inputs
Description
Refactor the current multiprocessing implementation to push this responsibility onto the solver.
Related issues #3987, #3713, #2645, #2644, #3910
Motivation
At the moment users can request multiple parallel solves by passing in a list of inputs:
solver.solve(model, t_eval, inputs=[{"a": 1}, {"a": 2}])
For all solvers apart from the Jax solver, this uses the Python multiprocessing library to run these solves in parallel. However, this causes issues on Windows (#3987), and it could probably be much more efficient if done at the solver level rather than up in Python.
Possible Implementation
I would propose that if a list of inputs of size n is detected, the model equations are duplicated n times before being passed to the solver. All the solvers we use have built-in parallel functionality that can be utilised to run this in parallel, and once we have a decent GPU solver this could be done on the GPU. After the solve is finished, the larger state vector would be split apart and passed to multiple Solution objects.
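To make the proposal concrete, here is a minimal sketch of the "duplicate, solve once, split" idea in plain Python. It is not PyBaMM code: the toy model dy/dt = -a*y, the fixed-step RK4 integrator, and the helper names (rhs, stacked_rhs, rk4, solve_stacked) are all assumptions made for illustration. A real implementation would hand the stacked system to the solver's own parallel back end.

```python
import math


def rhs(t, y, inputs):
    """Toy single-model right-hand side, dy/dt = -a * y (stand-in for a real model)."""
    return [-inputs["a"] * yi for yi in y]


def stacked_rhs(t, y_stacked, inputs_list, n_states):
    """Evaluate the duplicated equations: one block of n_states per input set."""
    out = []
    for k, inputs in enumerate(inputs_list):
        block = y_stacked[k * n_states:(k + 1) * n_states]
        out.extend(rhs(t, block, inputs))
    return out


def rk4(f, y0, t_eval):
    """Fixed-step classical RK4; returns the state at every point in t_eval."""
    ys = [list(y0)]
    for t0, t1 in zip(t_eval[:-1], t_eval[1:]):
        h = t1 - t0
        y = ys[-1]
        k1 = f(t0, y)
        k2 = f(t0 + h / 2, [yi + h / 2 * k for yi, k in zip(y, k1)])
        k3 = f(t0 + h / 2, [yi + h / 2 * k for yi, k in zip(y, k2)])
        k4 = f(t1, [yi + h * k for yi, k in zip(y, k3)])
        ys.append(
            [
                yi + h / 6 * (a + 2 * b + 2 * c + d)
                for yi, a, b, c, d in zip(y, k1, k2, k3, k4)
            ]
        )
    return ys


def solve_stacked(y0, t_eval, inputs_list):
    """Solve all input sets in one integration, then split the state vector."""
    n_states = len(y0)
    n_inputs = len(inputs_list)
    # Duplicate the initial state once per input set
    y0_stacked = list(y0) * n_inputs
    ys = rk4(
        lambda t, y: stacked_rhs(t, y, inputs_list, n_states),
        y0_stacked,
        t_eval,
    )
    # Split the large state vector back into one trajectory per input set
    return [
        [row[k * n_states:(k + 1) * n_states] for row in ys]
        for k in range(n_inputs)
    ]


t_eval = [i * 0.01 for i in range(101)]
solutions = solve_stacked([1.0], t_eval, [{"a": 1}, {"a": 2}])
```

Each entry of solutions plays the role of one Solution object. Because the blocks are independent, the stacked Jacobian is block diagonal, which is what a solver-level parallel or GPU implementation could exploit.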
Additional context
At the moment this snippet captures how multiprocessing is implemented:
ninputs = len(model_inputs_list)
if ninputs == 1:
    new_solution = self._integrate(
        model,
        t_eval[start_index:end_index],
        model_inputs_list[0],
    )
    new_solutions = [new_solution]
else:
    if model.convert_to_format == "jax":
        # Jax can parallelize over the inputs efficiently
        new_solutions = self._integrate(
            model,
            t_eval[start_index:end_index],
            model_inputs_list,
        )
    else:
        with mp.get_context(self._mp_context).Pool(processes=nproc) as p:
            new_solutions = p.starmap(
                self._integrate,
                zip(
                    [model] * ninputs,
                    [t_eval[start_index:end_index]] * ninputs,
                    model_inputs_list,
                ),
            )
            p.close()
            p.join()