[TPU] TypeError with JAX 0.3.17 in Google Colab

Open camierjs opened this issue 2 years ago • 1 comments

Description

Running the following line on TPUs in Google Colab with JAX 0.3.17:

jax.device_get(jax.device_put_replicated(jnp.arange(10), jax_devices))

throws the following error:

TypeError: int() argument must be a string, a bytes-like object or a number, not 'jaxlib.tpu_client_extension.PyTpuBuffer'

JAX < 0.3.17 produces the expected output:

array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], dtype=int32)

A self-contained notebook that reproduces this issue is available as a GitHub Gist.

This is linked to an issue opened in the Google Colab issue tracker.

What jax/jaxlib version are you using?

v0.3.17

Which accelerator(s) are you using?

TPU

Additional System Info

Google Colab

camierjs, Sep 12 '22 15:09
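The version fields in a report like the one above can also be collected programmatically. The following is a minimal sketch, not part of the original report: the helper name `installed_versions` is hypothetical, and it relies only on the standard library's `importlib.metadata`, so it also runs where jax is not installed.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib")):
    """Return {package: version string, or None if the package is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Package metadata not found in this environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Pasting the printed lines into a bug report avoids ambiguity when jax and jaxlib are at different versions.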

Hi - I suspect this may be related to the general errors people have been seeing on Colab TPU for the last few weeks; see https://github.com/googlecolab/colabtools/issues/3009

It may be that the TPU driver version mentioned there is also causing the issue you see here.

jakevdp, Sep 12 '22 16:09

This should be fixed in the most recent jax & jaxlib release. Thanks!

jakevdp, Oct 07 '22 03:10

I am getting this error currently, even after beginning my code in Colab with:

# get the latest JAX and jaxlib
!pip install --upgrade -q jax jaxlib

# Colab runtime set to TPU accel
import requests
import os
if 'TPU_DRIVER_MODE' not in globals():
  url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver_nightly'
  resp = requests.post(url)
  TPU_DRIVER_MODE = 1

# TPU driver as backend for JAX
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)

xaviergonzalez, Dec 28 '22 02:12
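To make the manual setup above easier to follow, here is a small sketch of how its two addresses are derived from `COLAB_TPU_ADDR`. The function names and the example address are hypothetical; the port `8475`, the `requestversion/tpu_driver_nightly` path, and the `grpc://` prefix are taken directly from the snippet in the comment above.

```python
def tpu_driver_request_url(colab_tpu_addr, driver="tpu_driver_nightly"):
    """Build the driver-version request URL used in the manual setup."""
    # Keep only the host part; the manual setup replaces the port with 8475.
    host = colab_tpu_addr.split(":")[0]
    return f"http://{host}:8475/requestversion/{driver}"


def jax_backend_target(colab_tpu_addr):
    """Build the gRPC target string assigned to jax_backend_target."""
    return "grpc://" + colab_tpu_addr


if __name__ == "__main__":
    addr = "10.0.0.2:8470"  # hypothetical value of os.environ['COLAB_TPU_ADDR']
    print(tpu_driver_request_url(addr))
    print(jax_backend_target(addr))
```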

I would suggest not doing all of this manually; instead, use:

import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

The advice above was for working around an issue with this built-in setup. If that does not work, then please open a bug showing what error you are seeing.

jakevdp, Dec 28 '22 17:12
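For notebooks that are run both inside and outside Colab, the built-in setup suggested above could be guarded as in the sketch below. This is an illustration under an assumption: it treats the presence of the `COLAB_TPU_ADDR` environment variable (the one read by the manual snippet earlier in the thread) as the signal for a Colab TPU runtime. The helper names are hypothetical; only `jax.tools.colab_tpu.setup_tpu()` comes from the thread.

```python
import os


def in_colab_tpu_runtime(environ=None):
    """Return True if the Colab TPU address variable is present (assumed signal)."""
    environ = os.environ if environ is None else environ
    return "COLAB_TPU_ADDR" in environ


def maybe_setup_colab_tpu():
    """Call the built-in setup only when a Colab TPU runtime is detected."""
    if not in_colab_tpu_runtime():
        return False
    # Imported lazily so this file can be loaded where jax is not installed.
    import jax.tools.colab_tpu

    jax.tools.colab_tpu.setup_tpu()
    return True
```

Calling `maybe_setup_colab_tpu()` once at the top of the notebook, before any other JAX work, matches the ordering used in the examples in this thread.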

Still having this issue... any developments?

wbrenton, Jan 18 '23 22:01

Can you say more about where you're seeing this issue? I just ran the following on a fresh Colab TPU runtime and got the expected output:

import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

import jax
import jax.numpy as jnp
jax_devices = jax.devices()
print(jax_devices)

jax.device_get(jax.device_put_replicated(jnp.arange(10), jax_devices))

jakevdp, Jan 18 '23 22:01
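When checking whether the original problem is still present, it can help to compare the result against the expected output instead of eyeballing it. The sketch below is one way to do that, using the same `jax.device_get` and `jax.device_put_replicated` calls as the snippet above; the helper names `expected_replicated` and `run_repro` are hypothetical, and `run_repro` assumes jax is installed and already set up for the runtime.

```python
def expected_replicated(n_devices, length=10):
    """Nested-list form of the expected result: one copy of range(length) per device."""
    return [list(range(length)) for _ in range(n_devices)]


def run_repro():
    """Run the call from this thread and compare it with the expected result."""
    # Imported here so the helper above stays usable without jax installed.
    import jax
    import jax.numpy as jnp

    devices = jax.devices()
    try:
        result = jax.device_get(jax.device_put_replicated(jnp.arange(10), devices))
    except TypeError as err:
        # This is the failure mode reported at the top of the thread.
        return f"TypeError: {err}"
    ok = result.tolist() == expected_replicated(len(devices))
    return "ok" if ok else f"unexpected result with shape {result.shape}"
```

A return value of `"ok"` means the replicated array has one row per device, as in the expected output quoted in the original report.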

Message is "UnfilteredStackTrace: TypeError: JAX encountered invalid PRNG key data: expected key_data to have ndim, shape, and dtype attributes. Got <jaxlib.tpu_client_extension.PyTpuBuffer object at 0x7f764be19cc0>

The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified."

The code works fine when run on CPU.

Thanks for the lightning-fast reply. I don't understand how you do it.

wbrenton, Jan 18 '23 22:01

Thanks - that looks unrelated to this TPU issue. It would be helpful if you could open a new discussion, and if possible include a minimal reproducible example of the code that led to the error: it's hard to guess what might have caused it from the traceback alone.

jakevdp, Jan 18 '23 22:01