[TPU] TypeError with JAX 0.3.17 in Google Colab
Description
Running the following lines on a TPU runtime in Google Colab with JAX 0.3.17:
jax_devices = jax.devices()
jax.device_get(jax.device_put_replicated(jnp.arange(10), jax_devices))
throws the following error:
TypeError: int() argument must be a string, a bytes-like object or a number, not 'jaxlib.tpu_client_extension.PyTpuBuffer'
With JAX < 0.3.17, the same code produces the expected output:
array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], dtype=int32)
A self-contained notebook that reproduces this issue is available as a GitHub Gist.
Linked to an issue opened in Google Colab.
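For comparison, here is a minimal sketch of the same round trip that does not go through device_put_replicated: it puts one copy of the array on each visible device with jax.device_put, pulls the copies back with jax.device_get, and stacks them on the host. This is an illustrative alternative, not a drop-in replacement (device_put_replicated returns a single array spread across the devices), and the helper name replicate_and_fetch is hypothetical. On a Colab TPU runtime with 8 cores it should give an (8, 10) int32 array like the expected output above; on CPU it runs with however many devices JAX reports.

```python
import numpy as np
import jax
import jax.numpy as jnp


def replicate_and_fetch(x, devices):
    """Hypothetical helper: copy x onto each device, then bring every copy back to the host."""
    # One explicit device_put per device; each result is committed to that device.
    copies = [jax.device_put(x, d) for d in devices]
    # device_get accepts a pytree (here, a list) and returns NumPy arrays.
    host_copies = jax.device_get(copies)
    # Stack the per-device copies into shape (num_devices, *x.shape).
    return np.stack(host_copies)


jax_devices = jax.devices()
result = replicate_and_fetch(jnp.arange(10), jax_devices)
print(result.shape, result.dtype)
print(result)
```

If this per-device version works while device_put_replicated fails, that points at the replicated-transfer path rather than at basic host/device transfers.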
What jax/jaxlib version are you using?
v0.3.17
Which accelerator(s) are you using?
TPU
Additional System Info
Google Colab
Hi, I suspect this may be related to the general errors people have been seeing on Colab TPU for the last few weeks; see https://github.com/googlecolab/colabtools/issues/3009
It may be that the TPU driver version mentioned there is also causing the issue you see here.
This should be fixed in the most recent jax & jaxlib release. Thanks!
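Since the fix is said to be in a release newer than the broken 0.3.17, one quick check on a Colab runtime is to read the installed versions from the package metadata. A minimal sketch, assuming only that anything newer than 0.3.17 is worth trying (the thread does not name the exact fixed version); parse_release and newer_than_broken are hypothetical helper names.

```python
from importlib.metadata import PackageNotFoundError, version

BROKEN = (0, 3, 17)  # release reported as broken in this thread


def parse_release(v: str) -> tuple:
    """Keep only the leading purely numeric components, e.g. '0.4.1.dev2023' -> (0, 4, 1)."""
    parts = []
    for piece in v.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def newer_than_broken(v: str) -> bool:
    # Tuple comparison is numeric per component, so '0.3.9' < '0.3.17' as intended.
    return parse_release(v) > BROKEN


for pkg in ("jax", "jaxlib"):
    try:
        installed = version(pkg)
        print(pkg, installed, "newer than 0.3.17:", newer_than_broken(installed))
    except PackageNotFoundError:
        print(pkg, "is not installed")
```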
I am still getting this error, even after starting my code in Colab with:
# get the latest JAX and jaxlib
!pip install --upgrade -q jax jaxlib
# Colab runtime set to TPU accel
import requests
import os
if 'TPU_DRIVER_MODE' not in globals():
    url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver_nightly'
    resp = requests.post(url)
    TPU_DRIVER_MODE = 1
# TPU driver as backend for JAX
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)
I would suggest not doing all of this manually; instead, use:
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
The advice above was for working around an issue with this built-in setup. If that does not work, then please open a bug showing what error you are seeing.
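Whichever setup is used, a quick way to confirm which backend JAX actually picked up is to group jax.devices() by platform. A sketch under stated assumptions: the helper summarize_devices is hypothetical, and the setup call is deliberately left out so the snippet also runs on a CPU-only runtime. On a working Colab TPU runtime one would expect the counts to show TPU devices (the expected output earlier in this thread has 8 rows, one per device).

```python
from collections import Counter

import jax


def summarize_devices():
    """Hypothetical helper: count visible devices per platform (e.g. 'cpu', 'gpu', 'tpu')."""
    return Counter(d.platform for d in jax.devices())


print("default backend:", jax.default_backend())
print("devices per platform:", dict(summarize_devices()))
```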
Still having this issue... any developments?
Can you say more about where you're seeing this issue? I just ran the following on a fresh Colab TPU runtime and got the expected output:
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
import jax
import jax.numpy as jnp
jax_devices = jax.devices()
print(jax_devices)
jax.device_get(jax.device_put_replicated(jnp.arange(10), jax_devices))
The error message is: "UnfilteredStackTrace: TypeError: JAX encountered invalid PRNG key data: expected key_data to have ndim, shape, and dtype attributes. Got <jaxlib.tpu_client_extension.PyTpuBuffer object at 0x7f764be19cc0>
The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified."
The code works fine when run on CPU.
Thanks for the lightning-fast reply. I don't understand how you do it!
Thanks - that looks unrelated to this TPU issue. It would be helpful if you could open a new discussion, and if possible include a minimal reproducible example of the code that led to the error: it's hard to guess what might have caused it from the traceback alone.
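As a reference point for the PRNG error quoted above (which says JAX expected key data with ndim, shape, and dtype attributes but received a raw PyTpuBuffer), here is what well-formed key data looks like. A sketch assuming default JAX settings, under which jax.random.PRNGKey returns an ordinary uint32 array; it is not a diagnosis of that error, which, as noted, needs a minimal reproducer.

```python
import jax
import jax.numpy as jnp

# Legacy-style key: under default settings this is a uint32 array of shape (2,).
key = jax.random.PRNGKey(0)
print(key.ndim, key.shape, key.dtype)

# The attributes named in the error message are all present on a valid key.
for attr in ("ndim", "shape", "dtype"):
    assert hasattr(key, attr), attr

# Splitting and sampling work as usual with well-formed key data.
key, subkey = jax.random.split(key)
sample = jax.random.normal(subkey, (3,))
print(sample.shape)
```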