Can't use cirq.Simulator() in a multiprocessing closure (unable to pickle)
When using Python's multiprocessing, the function passed to the pool must be pickleable so that it can be shipped off to the subprocesses. This means you can't use lambdas, among other things. A common workaround is to define a helper closure class:
import multiprocessing

import cirq

class _SimulateClosure:
    def __init__(self, simulator):
        self.simulator = simulator

    def __call__(self, circuit):
        return self.simulator.simulate(circuit)

def do_it():
    circuits = ...
    _simulate = _SimulateClosure(cirq.Simulator())
    with multiprocessing.Pool() as pool:
        pool.map(_simulate, circuits)
Fails! With the very cryptic "TypeError: can't pickle module objects". But I'm not pickling a module object!! Well, after some digging and a hat tip to https://stackoverflow.com/a/59832602, we actually are: our old friend np.random, which is the default random generator. A partial solution is to explicitly construct cirq.Simulator(seed=np.random.RandomState()), but be very careful; see below.
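For concreteness, a minimal sketch of that partial workaround (my illustration): pass an explicit RandomState, which carries its own picklable state, instead of letting the seed default to the np.random module.

import pickle

import numpy as np
import cirq

# A RandomState instance pickles cleanly; the np.random module does not.
simulator = cirq.Simulator(seed=np.random.RandomState())
pickle.dumps(simulator)  # should no longer raise TypeError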
I agree this is an issue. However, I think there's a subtle problem with your solution: seed=np.random.RandomState() will cause the simulator in each process to get the same sequence of random numbers, since the RandomState's internal state is pickled in the parent and copied verbatim into every worker.
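To make the failure mode concrete (a demonstration I'm adding, not from the thread): unpickled copies of a RandomState all resume from the same internal state, so they emit identical streams.

import pickle

import numpy as np

rs = np.random.RandomState()
payload = pickle.dumps(rs)  # roughly what Pool.map does once per worker
copy_a = pickle.loads(payload)
copy_b = pickle.loads(payload)
# Both copies were frozen at the same state, so their outputs coincide.
assert copy_a.randint(10**6) == copy_b.randint(10**6)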
Yes, that's very important to point out. In my specific case it works, since I'm just using the simulate() method (without measurements or any stochasticity). If I could pass something that says "no randomness please" I would, but None causes it to default to np.random. I'll update the original post so no one gets the wrong idea.
Could you have each process create its own simulator instead of pickling one?
Otherwise, maybe we extract the RNG from the simulator itself and instead pass it explicitly into run() etc.? Then each process can just create its own RNG to use.
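A sketch of the first suggestion (my illustration, with a hypothetical _simulate helper): build the simulator inside the worker so that only circuits cross the process boundary.

import multiprocessing

import cirq

def _simulate(circuit):
    # Nothing unpicklable is shipped to the worker; the simulator is
    # constructed locally. Note that with the default fork start method,
    # every worker inherits the same global np.random state, which only
    # matters if the simulation actually draws random numbers.
    simulator = cirq.Simulator()
    return simulator.simulate(circuit)

def do_it(circuits):
    with multiprocessing.Pool() as pool:
        return pool.map(_simulate, circuits)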
I have a tiny bit of experience with this from having hit similar problems in the past. The Python multiprocessing package uses the (basic) Python pickle module, which has several limitations, like the inability to handle module objects or lambdas. There are alternative serialization packages that don't have those limitations; dill and cloudpickle are probably the most popular, but of course the problem is how to get Python's multiprocessing to actually use them.
I think the following covers the most common solutions:
1. Use the third-party multiprocess package instead of multiprocessing. It uses dill instead of pickle and can handle lambdas and other things.
2. Monkeypatch pickle to use a different pickling function. It needs to be done before importing multiprocessing, e.g., like this:

   import pickle
   import cloudpickle
   pickle.Pickler = cloudpickle.Pickler
   import multiprocessing
   # ... rest of the code ...

3. (Possibly) use the shared memory feature introduced in Python 3.8's version of multiprocessing to avoid the need for pickling at all.
In terms of what's viable for Cirq, it seems like (3) might not be an option because of the variety of computing backends that might be used by people running Cirq, so that would leave (1) or (2).
Approach (2) is hacky, but if the situation where the problem occurs is entirely wrapped up inside Cirq somewhere, and only in a single module, then maybe that wouldn't be too bad?
Otherwise, (1) is in some ways cleaner. There are not that many places in Cirq where multiprocessing is currently used (at least in this particular repo; I haven't checked other repos), so maybe a first pass is to see whether replacing all uses of multiprocessing with multiprocess produces a still-working system without the pickling problem.
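For illustration (my sketch; it assumes multiprocess is installed from PyPI and relies on it mirroring the multiprocessing API, which it does by design): with dill doing the serialization, even the lambda-based version should work.

import multiprocess

import cirq

def do_it(circuits):
    simulator = cirq.Simulator()
    with multiprocess.Pool() as pool:
        # dill can serialize lambdas and module references that
        # the standard pickle module rejects.
        return pool.map(lambda c: simulator.simulate(c), circuits)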
A couple of concerns about (1) come to mind. First, users might end up using multiprocessing in their own code. It's not clear if that would be a problem per se: the multiprocess package is a fork of multiprocessing with only the serialization parts changed (it's not doing gnarly stuff to Python core functions or whatever), so it seems likely that if a user invoked multiprocessing in their code, there shouldn't be interactions. Second, both dill and cloudpickle seem to be slower than pickle. The impact probably depends on what exactly is being pickled, how often, how many objects, etc. Both of these concerns could undoubtedly be explored in some benchmark tests.
We could also change the constructor's default argument from np.random to None and then have

if random_state is None:
    random_state = np.random
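Note that for this to fix pickling, the None-to-np.random substitution would have to happen lazily, when randomness is actually needed, rather than the module being stored on the instance at construction time. A sketch with hypothetical names, not Cirq's actual code:

import numpy as np

class Simulator:
    def __init__(self, seed=None):
        # Store the seed as given; storing np.random itself would make
        # the instance unpicklable again.
        self._seed = seed

    def _prng(self):
        # Resolve the default only at the point of use.
        return np.random if self._seed is None else self._seed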
You have to be a little bit careful because if you get a bunch of these simulators creating themselves at the same time in each process, they could happen to get seeded to the same state if they draw their seed from the current time.
Did we give up on this problem? Just ran into it myself with cirq.DensityMatrixSimulator, and it is indeed very annoying.
You have to be a little bit careful because if you get a bunch of these simulators creating themselves at the same time in each process, they could happen to get seeded to the same state if they draw their seed from the current time.
You can generate a list of unique seeds in the main process and pass them, zipped with the circuits, to the mapped worker function. That way each function call would have a unique seed with which to create its own simulator. In fact, the initial example, had it been pickleable, would have run all simulations with the same seed: multiprocessing forks child processes by default, so they would end up with numpy RNGs in identical states.
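A sketch of that seed-zipping approach (my illustration; helper names are hypothetical):

import multiprocessing

import numpy as np
import cirq

def _simulate(circuit_and_seed):
    circuit, seed = circuit_and_seed
    # Each call seeds its own simulator, so workers never share RNG state
    # even though fork copies the parent's globals.
    simulator = cirq.Simulator(seed=np.random.RandomState(seed))
    return simulator.simulate(circuit)

def do_it(circuits):
    # One seed per circuit, drawn in the parent process. Collisions over
    # 2**31 values are unlikely for a modest number of circuits.
    seeds = np.random.randint(2**31 - 1, size=len(circuits))
    with multiprocessing.Pool() as pool:
        return pool.map(_simulate, list(zip(circuits, seeds)))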
It's still kind of annoying that you need to generate a bunch of random seeds in the case where you're not actually using any randomness, i.e. if you're calling simulate(). For @jarthurgross, there's even less randomness needed when doing a density matrix simulation :)