jit cache key should include device (really, device type)
from jax import device_put, devices, jit
import jax.numpy as np
import numpy as onp
def f(x):
    return np.sum(x)
f0 = jit(f, device=devices()[0])
f1 = jit(f, device=devices()[1])
data = onp.random.rand(10000)
x0 = device_put(data, device=devices()[0])
x1 = device_put(data, device=devices()[1])
print("f on gpu 0:", f0(x0))
print("f on gpu 1:", f1(x1))
When I run this code on my machine, I get the following error:
f on gpu 0: 4949.328
2020-04-15 13:00:27.288577: E external/org_tensorflow/tensorflow/compiler/xla/python/local_client.cc:758] Execution of replica 0 failed: Invalid argument: executable is built for device CUDA:0 of type "GeForce RTX 2080 Ti"; cannot run it on device CUDA:1 of type "Tesla P100-PCIE-16GB"
Traceback (most recent call last):
  File "test3.py", line 21, in <module>
    print("f on gpu 1:", f1(x1))
  File "/usr/local/lib/python3.6/dist-packages/jax/api.py", line 150, in f_jitted
    out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend)
  File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 592, in call_bind
    outs = primitive.impl(f, *args, **params)
  File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py", line 402, in _xla_call_impl
    return compiled_fun(*args)
  File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py", line 486, in _execute_compiled
    out_bufs = compiled.Execute(input_bufs).destructure()
RuntimeError: Invalid argument: executable is built for device CUDA:0 of type "GeForce RTX 2080 Ti"; cannot run it on device CUDA:1 of type "Tesla P100-PCIE-16GB"
Could you give me some advice?
Originally posted by @caihao in https://github.com/google/jax/issues/1899#issuecomment-614026930
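To make the title concrete, here is a toy sketch in plain Python of why a compilation cache key needs the device type. It is not JAX's actual cache: `Device`, `Executable`, `CompileCache`, and `total` are hypothetical names made up for illustration, and the "compile" step just records which device kind the artifact was built for.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class Device:
    """Hypothetical stand-in for an accelerator (not a JAX class)."""
    id: int
    kind: str


@dataclass(frozen=True)
class Executable:
    """Hypothetical compiled artifact that only runs on one device kind."""
    fn: Callable
    built_for: str

    def run(self, device, *args):
        if device.kind != self.built_for:
            raise RuntimeError(
                f"executable is built for {self.built_for!r}; "
                f"cannot run it on {device.kind!r}"
            )
        return self.fn(*args)


class CompileCache:
    """Toy cache; the flag controls whether the key includes device kind."""

    def __init__(self, include_device_kind):
        self.include_device_kind = include_device_kind
        self._cache = {}
        self.compile_count = 0

    def _key(self, fn, device):
        return (fn, device.kind) if self.include_device_kind else (fn,)

    def get(self, fn, device):
        key = self._key(fn, device)
        if key not in self._cache:
            # "Compile" for the requesting device and remember the result.
            self.compile_count += 1
            self._cache[key] = Executable(fn, device.kind)
        return self._cache[key]


def total(xs):
    return sum(xs)


gpu0 = Device(0, "GeForce RTX 2080 Ti")
gpu1 = Device(1, "Tesla P100-PCIE-16GB")

# Key without device kind: the executable built for gpu0 is reused on gpu1.
buggy = CompileCache(include_device_kind=False)
buggy.get(total, gpu0).run(gpu0, [1, 2, 3])
try:
    buggy.get(total, gpu1).run(gpu1, [1, 2, 3])
    buggy_failed = False
except RuntimeError:
    buggy_failed = True

# Key with device kind: each device type gets its own executable.
fixed = CompileCache(include_device_kind=True)
r0 = fixed.get(total, gpu0).run(gpu0, [1, 2, 3])
r1 = fixed.get(total, gpu1).run(gpu1, [1, 2, 3])
```

With the device kind left out of the key, the second lookup hits the entry built for the first GPU and fails in the same way as the error above. With it included, the toy cache compiles once per device type.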
Hi @mattjj
Looks like this issue has been resolved in later versions of JAX. I tried to reproduce it with the latest JAX version, 0.4.26, on a cloud VM with 4 T4 GPUs, and the code now runs without any error. Please find the screenshot below for reference.
Thank you.