
Pickling a JAX array does not preserve its original device

Open ayaka14732 opened this issue 3 years ago • 12 comments

Description

What I hope to achieve

The JAX arrays are created on CPU in a subprocess invoked via multiprocessing. I expect that when the arrays are sent to the main process, they remain on the same device (i.e. CPU). However, after they are sent to the main process, they are placed on the default device (i.e. GPU) instead.

import jax
import jax.numpy as np
import multiprocessing

def f():
    jax.config.update('jax_platforms', 'cpu')
    a = np.zeros((2,))
    print(a.device())  # TFRT_CPU_0
    return a

def main():
    ctx = multiprocessing.get_context('spawn')
    with ctx.Pool(1) as p:
        a = p.starmap(f, ((),))[0]
    print(a.device())  # gpu:0, expected TFRT_CPU_0

if __name__ == '__main__':
    main()

This also applies to the underlying pickle module:

import jax
import jax.numpy as np
import pickle

jax.config.update('jax_platforms', 'cpu')

a = np.zeros((2,))
print(a.device())  # TFRT_CPU_0

with open('1.dat', 'wb') as f:
    pickle.dump(a, f)

# In a second, fresh interpreter (without the jax_platforms setting):
import pickle

with open('1.dat', 'rb') as f:
    a = pickle.load(f)
print(a.device())  # gpu:0, expected TFRT_CPU_0
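A workaround (an editor's sketch, not part of the original report): re-place the loaded array with an explicit device argument, which does not depend on the default-device logic at all.

```python
import pickle

import jax
import jax.numpy as jnp

cpu = jax.devices('cpu')[0]

# Round-trip an array through pickle.
data = pickle.dumps(jnp.zeros((2,)))

# pickle.loads may place the result on the default accelerator;
# an explicit jax.device_put moves it back to the device we want.
a = jax.device_put(pickle.loads(data), cpu)
```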

What jax/jaxlib version are you using?

jax 0.3.17, jaxlib 0.3.15+cuda11.cudnn82

Which accelerator(s) are you using?

GPU

Additional System Info

Python 3.10.6, Arch Linux x86_64

ayaka14732 avatar Sep 07 '22 13:09 ayaka14732

This is expected behavior; see the description from the CHANGELOG entry when pickle support was added: https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0314-june-27-2022

The issue is that pickling and unpickling need not happen in the same environment, and so attempting to preserve the original device can be problematic; see https://github.com/google/jax/pull/10659 for further discussion, and let us know if you have any suggestions.

jakevdp avatar Sep 07 '22 15:09 jakevdp

@jakevdp Thank you! I agree that attempting to preserve the original device can be problematic. My suggestion is to prioritize implementing a with block for specifying the target device. For example, in PyTorch we can do:

with torch.cuda.device(2):  # specify the default device
    ...  # create arrays on the default device

I remember that this issue was discussed in https://github.com/google/jax/issues/8879

ayaka14732 avatar Sep 12 '22 03:09 ayaka14732

https://github.com/google/jax/pull/9118 may have added this context manager already, but I'm not sure. Checking with @skye ...

mattjj avatar Sep 12 '22 05:09 mattjj

Yes, I believe doing something like with jax.default_device(jax.devices("cpu")[0]) should work. Please let me know if it doesn't; I haven't tried it with pickle.

skye avatar Sep 12 '22 19:09 skye

Unpickling of an array uses jax.device_put with no device argument, so I believe the jax.default_device context manager should do the right thing.

jakevdp avatar Sep 12 '22 19:09 jakevdp

To put a fine point on it, the analogue of code like this:

with torch.cuda.device(2):  # specify the default device
    ...  # create array on the default device

Is something like this:

devices = jax.devices()
with jax.default_device(devices[2]):
  ... # create or load (e.g. unpickle) array on default device
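Applied to unpickling specifically, that pattern looks roughly like this (an editor's sketch; note that the follow-up comments below report it did not actually work on jax 0.3.17, where device_put ignored the context):

```python
import pickle

import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]

# Round-trip an array through pickle.
data = pickle.dumps(jnp.zeros((2,)))

# The intent: unpickling under the context should place the result
# on CPU instead of the default accelerator.
with jax.default_device(cpu):
    a = pickle.loads(data)
```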

I think this issue is resolved, but please reopen if not!

mattjj avatar Sep 12 '22 19:09 mattjj

@mattjj Thank you for the reply, but this issue is not resolved. jax.default_device works when creating the array, but not when loading the array from pickle:

import jax
import jax.numpy as np
import multiprocessing

def f():
    jax.config.update('jax_platforms', 'cpu')
    a = np.zeros((2,))
    print(a.device())  # TFRT_CPU_0
    return a

def main():
    device_cpu = jax.devices('cpu')[0]
    with jax.default_device(device_cpu):
        a = np.zeros((2,))
        print(a.device())  # TFRT_CPU_0

        ctx = multiprocessing.get_context('spawn')
        with ctx.Pool(1) as p:
            a = p.starmap(f, ((),))[0]
        print(a.device())  # gpu:0, expected TFRT_CPU_0

if __name__ == '__main__':
    main()

ayaka14732 avatar Sep 15 '22 04:09 ayaka14732

I found the cause of the problem.

Unpickling of an array uses jax.device_put with no device argument, so I believe the jax.default_device context manager should do the right thing.

But actually this does not work as expected:

import jax
import numpy as onp

device_cpu = jax.devices('cpu')[0]
with jax.default_device(device_cpu):
    a = onp.array([1., 2.])

    b = jax.device_put(a)
    print(b.device())  # gpu:0, not TFRT_CPU_0
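For comparison (an editor's sketch, not part of the original comment): passing the device explicitly, rather than relying on the context manager, does place the array on the requested device.

```python
import jax
import numpy as onp

device_cpu = jax.devices('cpu')[0]
a = onp.array([1., 2.])

# With an explicit device argument, device_put honors the request
# regardless of any surrounding default_device context.
b = jax.device_put(a, device_cpu)
```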

ayaka14732 avatar Sep 15 '22 05:09 ayaka14732

@mattjj Please reopen this

ayaka14732 avatar Sep 17 '22 00:09 ayaka14732

Reopening and assigning to @skye – it looks like device_put does not respect the default device context. Is this intended?

jakevdp avatar Sep 19 '22 18:09 jakevdp

I think that is not intended.

(@ayaka14732 thanks for following up and for the pings. Sorry I didn't notice your replies until now!)

mattjj avatar Sep 19 '22 18:09 mattjj

Sorry for prematurely closing this. I really should've added a test, and then I would've seen that it doesn't work!

On CPU at least, it looks like xla_extension.Client.buffer_from_pyval(x, None) doesn't respect the default device...

mattjj avatar Sep 19 '22 18:09 mattjj

Are there any updates on this? I'm wondering if I may be running into a related problem.

TL;DR: I recently updated jax, and code that used to run fine (multiprocessing with the forkserver context plus jax, with a default device so that all arrays are placed on CPU) is now crashing due to memory overload on the GPU. Monitoring the GPU, I can see that it starts by allocating 10 MB per process, which over time balloons to 100-400 MB or more, and ultimately crashes.

I suspect this is caused by pickle (which underlies multiprocessing) not playing well with device placement.

EDIT: confirmed that arrays are passed on the CPU device but are on the GPU once inside a pickled function. Fixed by moving the jax.default_device statement inside rather than outside the pickled function. I wonder if this rises to the level of a "Sharp Bit"? (Ideally this would also be tallied somewhere in a list of best practices for interfacing with multiprocessing.)
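The fix described in the EDIT above can be sketched like this (a hypothetical worker function; uses the spawn context as in the earlier snippets, and assumes CPU is the desired device in each subprocess):

```python
import multiprocessing

import jax
import jax.numpy as jnp

def worker(_):
    # Enter the default-device context *inside* the subprocess: the
    # context active in the parent does not carry over across pickling.
    with jax.default_device(jax.devices("cpu")[0]):
        return jnp.zeros((2,))

def main():
    ctx = multiprocessing.get_context("spawn")
    with ctx.Pool(1) as p:
        return p.map(worker, [None])[0]

if __name__ == "__main__":
    main()
```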

meereeum avatar Apr 05 '23 06:04 meereeum