
jit cache key should include device (really, device type)

Open • mattjj opened this issue 2 years ago • 1 comment

from jax import device_put, devices, jit
import jax.numpy as np
import numpy as onp


def f(x):
    return np.sum(x)


f0 = jit(f, device=devices()[0])
f1 = jit(f, device=devices()[1])

data = onp.random.rand(10000)

x0 = device_put(data, device=devices()[0])
x1 = device_put(data, device=devices()[1])

print("f on gpu 0:", f0(x0))
print("f on gpu 1:", f1(x1))

When I run this code on my machine, I get errors like this:

f on gpu 0: 4949.328
2020-04-15 13:00:27.288577: E external/org_tensorflow/tensorflow/compiler/xla/python/local_client.cc:758] Execution of replica 0 failed: Invalid argument: executable is built for device CUDA:0 of type "GeForce RTX 2080 Ti"; cannot run it on device CUDA:1 of type "Tesla P100-PCIE-16GB"
Traceback (most recent call last):
  File "test3.py", line 21, in <module>
    print("f on gpu 1:", f1(x1))
  File "/usr/local/lib/python3.6/dist-packages/jax/api.py", line 150, in f_jitted
    out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend)
  File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 592, in call_bind
    outs = primitive.impl(f, *args, **params)
  File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py", line 402, in _xla_call_impl
    return compiled_fun(*args)
  File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py", line 486, in _execute_compiled
    out_bufs = compiled.Execute(input_bufs).destructure()
RuntimeError: Invalid argument: executable is built for device CUDA:0 of type "GeForce RTX 2080 Ti"; cannot run it on device CUDA:1 of type "Tesla P100-PCIE-16GB"

Could you give me some advice?

Originally posted by @caihao in https://github.com/google/jax/issues/1899#issuecomment-614026930

Posted by mattjj on Oct 05 '22 19:10
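The issue title says the `jit` cache key should include the device (really, the device type). The plain-Python toy below illustrates why. The names `FakeDevice`, `Executable` and `CompileCache` are hypothetical and are not JAX internals. They only mimic the failure in the traceback above, where an executable built for a "GeForce RTX 2080 Ti" was reused on a "Tesla P100-PCIE-16GB" because the cache key ignored the device.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Tuple


@dataclass(frozen=True)
class FakeDevice:
    """Hypothetical stand-in for a device: an id plus its kind (model name)."""
    id: int
    device_kind: str


@dataclass
class Executable:
    """Hypothetical compiled artifact that only runs on one device kind."""
    fn: Callable
    device_kind: str

    def run(self, device: FakeDevice, *args):
        # Mimics the runtime check that produced the error in the report.
        if device.device_kind != self.device_kind:
            raise RuntimeError(
                f"executable is built for device of type {self.device_kind!r}; "
                f"cannot run it on device of type {device.device_kind!r}"
            )
        return self.fn(*args)


class CompileCache:
    """Toy jit-style cache; key_on_device_kind=False reproduces the bug."""

    def __init__(self, key_on_device_kind: bool = True):
        self.key_on_device_kind = key_on_device_kind
        self._cache: Dict[Tuple, Executable] = {}
        self.compiles = 0

    def call(self, fn: Callable, device: FakeDevice, *args):
        # The fix from the title: make the device kind part of the key.
        key = (fn, device.device_kind) if self.key_on_device_kind else (fn,)
        exe = self._cache.get(key)
        if exe is None:
            # "Compile" for the kind of device seen on this cache miss.
            self.compiles += 1
            exe = Executable(fn, device.device_kind)
            self._cache[key] = exe
        return exe.run(device, *args)


gpu0 = FakeDevice(0, "GeForce RTX 2080 Ti")
gpu1 = FakeDevice(1, "Tesla P100-PCIE-16GB")

# Buggy key (function only): the second call reuses the RTX 2080 Ti executable.
buggy = CompileCache(key_on_device_kind=False)
print(buggy.call(sum, gpu0, [1.0, 2.0, 3.0]))  # 6.0
try:
    buggy.call(sum, gpu1, [1.0, 2.0, 3.0])
except RuntimeError as err:
    print("buggy cache:", err)

# Fixed key (function, device kind): each kind gets its own executable.
fixed = CompileCache()
print(fixed.call(sum, gpu0, [1.0, 2.0, 3.0]))  # 6.0
print(fixed.call(sum, gpu1, [1.0, 2.0, 3.0]))  # 6.0
print("compiles:", fixed.compiles)  # compiles: 2
```

Keying on the device kind rather than the device id follows the "(really, device type)" in the title. It is an assumption of this sketch that two devices of the same model could share one executable.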

Hi @mattjj

It looks like this issue has been resolved in later versions of JAX. I tried to reproduce it with the latest JAX version (0.4.26) on a cloud VM with 4 T4 GPUs, and it now runs without any error. Please see the screenshot below for reference.

[screenshot]

Thank you.

Posted by rajasekharporeddy on Apr 18 '24 18:04
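Below is a minimal sketch for re-checking this on a recent JAX release. The original repro passes `device=` to `jit`. This version instead commits each input to a device with `jax.device_put` and uses a single `jax.jit` function, relying on computation following the committed input. It loops over whatever `jax.devices()` reports. It only exercises the original failure mode if those devices differ in `device_kind`, since the error required an executable built for one GPU model to run on another.

```python
import jax
import jax.numpy as jnp
import numpy as np


def f(x):
    return jnp.sum(x)


# A single jitted function with no device= argument; placement follows
# the committed input array.
f_jit = jax.jit(f)

data = np.random.rand(10000).astype(np.float32)

results = {}
for dev in jax.devices():
    # Passing an explicit device commits the array to that device.
    x = jax.device_put(data, dev)
    out = f_jit(x)
    # The result should stay on the same device as its input.
    assert out.devices() == {dev}
    print(f"f on {dev} ({dev.device_kind}):", out)
    results[dev.id] = float(out)
```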