Pickling a JAX array does not preserve its original device
Description
What I hope to achieve
The JAX arrays are created in a subprocess, invoked via multiprocessing, on the CPU. I expect that when the arrays are sent to the main process, they remain on the same device (i.e. CPU). However, after they are sent to the main process, they are placed on the default device (i.e. GPU) instead.
import jax
import jax.numpy as np
import multiprocessing

def f():
    jax.config.update('jax_platforms', 'cpu')
    a = np.zeros((2,))
    print(a.device())  # TFRT_CPU_0
    return a

def main():
    ctx = multiprocessing.get_context('spawn')
    with ctx.Pool(1) as p:
        a = p.starmap(f, ((),))[0]
    print(a.device())  # gpu:0, expected TFRT_CPU_0

if __name__ == '__main__':
    main()
This also applies to the underlying pickle module:
import jax
import jax.numpy as np
import pickle

jax.config.update('jax_platforms', 'cpu')
a = np.zeros((2,))
print(a.device())  # TFRT_CPU_0
with open('1.dat', 'wb') as f:
    pickle.dump(a, f)

And then, in a fresh process (where the default device is the GPU):

import pickle

with open('1.dat', 'rb') as f:
    a = pickle.load(f)
print(a.device())  # gpu:0, expected TFRT_CPU_0
What jax/jaxlib version are you using?
jax 0.3.17, jaxlib 0.3.15+cuda11.cudnn82
Which accelerator(s) are you using?
GPU
Additional System Info
Python 3.10.6, Arch Linux x86_64
This is expected behavior; see the description from the CHANGELOG entry when pickle support was added: https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0314-june-27-2022
The issue is that pickling and unpickling need not happen in the same environment, and so attempting to preserve the original device can be problematic; see https://github.com/google/jax/pull/10659 for further discussion, and let us know if you have any suggestions.
@jakevdp Thank you! I agree that attempting to preserve the original device can be problematic. My suggestion is to give priority to implementing a with block for specifying the target device. For example, in PyTorch we can do:
with torch.cuda.device(2):  # specify default device
    ...  # create array on the default device
I remember that this issue was discussed in https://github.com/google/jax/issues/8879
https://github.com/google/jax/pull/9118 may have added this context manager already, but I'm not sure. Checking with @skye ...
Yes, I believe doing something like with jax.default_device(jax.devices("cpu")[0]) should work. Please let me know if it doesn't, I haven't tried it with pickle.
Unpickling of an array uses jax.device_put with no device argument, so I believe the jax.default_device context manager should do the right thing.
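A rough, illustrative sketch of that reconstruction path (the helper name here is hypothetical, not JAX's actual internal function):

```python
import jax
import numpy as onp

def unpickle_like(np_buffer):
    # Hypothetical stand-in for JAX's unpickling step: the array is
    # reduced to a host NumPy buffer and reconstructed via device_put
    # with no explicit device, so its placement follows whatever the
    # default device resolves to at load time.
    return jax.device_put(np_buffer)

restored = unpickle_like(onp.array([1.0, 2.0]))
```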
To put a fine point on it, the analogue of code like this:
with torch.cuda.device(2):  # specify default device
    ...  # create array on the default device
Is something like this:
devices = jax.devices()
with jax.default_device(devices[2]):
    ...  # create or load (e.g. unpickle) array on default device
I think this issue is resolved, but please reopen if not!
@mattjj Thank you for the reply, but this issue is not resolved. jax.default_device works when creating an array, but not when loading an array from a pickle:
import jax
import jax.numpy as np
import multiprocessing

def f():
    jax.config.update('jax_platforms', 'cpu')
    a = np.zeros((2,))
    print(a.device())  # TFRT_CPU_0
    return a

def main():
    device_cpu = jax.devices('cpu')[0]
    with jax.default_device(device_cpu):
        a = np.zeros((2,))
        print(a.device())  # TFRT_CPU_0
        ctx = multiprocessing.get_context('spawn')
        with ctx.Pool(1) as p:
            a = p.starmap(f, ((),))[0]
        print(a.device())  # gpu:0, expected TFRT_CPU_0

if __name__ == '__main__':
    main()
I found the cause of the problem. The earlier comment stated:

> Unpickling of an array uses jax.device_put with no device argument, so I believe the jax.default_device context manager should do the right thing.

But in fact this does not work as expected:
import jax
import numpy as onp

device_cpu = jax.devices('cpu')[0]
with jax.default_device(device_cpu):
    a = onp.array([1., 2.])
    b = jax.device_put(a)
    print(b.device())  # gpu:0, not TFRT_CPU_0
@mattjj Please reopen this
Reopening and assigning to @skye – it looks like device_put does not respect the default device context. Is this intended?
I think that is not intended.
(@ayaka14732 thanks for following up and for the pings. Sorry I didn't notice your replies until now!)
Sorry for prematurely closing this. I really should've added a test, and then I would've seen that it doesn't work!
On CPU at least, it looks like xla_extension.Client.buffer_from_pyval(x, None) doesn't respect the default device...
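As an interim workaround while that's the case, passing the device explicitly to jax.device_put sidesteps the default-device lookup entirely (a minimal sketch; on a CPU-only machine every device is already a CPU device, so the distinction is only visible with a GPU present):

```python
import jax
import numpy as onp

device_cpu = jax.devices('cpu')[0]
a = onp.array([1.0, 2.0])
# An explicit device argument pins the placement, so the
# default-device resolution is never consulted.
b = jax.device_put(a, device_cpu)
```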
Are there any updates on this? I'm wondering if I may be running into a related problem.
TL;DR: I recently updated jax, and code that used to run fine---using multiprocessing (forkserver) + jax (with default device / all arrays placed on CPU)---is now crashing due to memory overload on the GPU. Monitoring the GPU, I can see that it starts by allocating 10 MB per process, which over time balloons to 100-400+ MB and ultimately crashes.
I think this might be caused by pickle (underlying multiprocessing) not playing well with device placement.
EDIT: confirmed that arrays are passed on the CPU device, but are on the GPU once inside the pickled function. Fixed by moving the jax.default_device statement inside, rather than outside, the pickled function. I wonder if this rises to the level of a "Sharp Bit"? (And/or, ideally, it would be tallied somewhere in a list of best practices for interfacing with multiprocessing.)
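For reference, a minimal sketch of that workaround (the function name is illustrative; the multiprocessing plumbing is the same as in the snippets above):

```python
import jax
import jax.numpy as jnp

def worker(n):
    # The context manager lives *inside* the function executed by the
    # subprocess: the parent's default_device context does not travel
    # across process boundaries, so each worker must set it up itself
    # before creating any arrays.
    with jax.default_device(jax.devices('cpu')[0]):
        return jnp.zeros((n,))
```

In the pool, something like p.starmap(worker, ((2,),)) then creates the arrays on CPU inside each subprocess.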