TypeError: float() argument must be a string or a number, not 'jaxlib.tpu_client_extension.PyTpuBuffer'
Description
Hi, I am trying to follow this tutorial: https://github.com/google/trax/blob/master/trax/examples/NMT_with_Transformers_Reformers_using_Trax.ipynb. With the runtime set to TPU on Colab, it worked a couple of days ago, but it now crashes with the error:
TypeError: float() argument must be a string or a number, not 'jaxlib.tpu_client_extension.PyTpuBuffer'
The crash happens at the step that constructs the training loop: training_loop = training.Loop(model,.....
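For context, the failing cell builds the training loop roughly as follows (paraphrased from the notebook and from the call shown in the traceback below; model, train_task, eval_task, and output_dir all come from earlier cells of the tutorial):

from trax.supervised import training

# Paraphrased from the tutorial notebook; the model and task
# definitions come from the preceding cells.
training_loop = training.Loop(
    model,
    train_task,
    eval_tasks=[eval_task],
    output_dir=output_dir)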
Environment information
OS:
NAME="Ubuntu"
VERSION="18.04.6 LTS (Bionic Beaver)"
ID=ubuntu
ID_LIKE=debian
PRETTY_NAME="Ubuntu 18.04.6 LTS"
VERSION_ID="18.04"
HOME_URL="https://www.ubuntu.com/"
SUPPORT_URL="https://help.ubuntu.com/"
BUG_REPORT_URL="https://bugs.launchpad.net/ubuntu/"
PRIVACY_POLICY_URL="https://www.ubuntu.com/legal/terms-and-policies/privacy-policy"
VERSION_CODENAME=bionic
UBUNTU_CODENAME=bionic
$ pip freeze | grep trax
trax==1.4.1
$ pip freeze | grep tensor
tensorboard==2.10.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.10.0
tensorflow-datasets==4.6.0
tensorflow-estimator==2.10.0
tensorflow-gcs-config==2.8.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.26.0
tensorflow-metadata==1.10.0
tensorflow-probability==0.16.0
tensorflow-text==2.10.0
$ pip freeze | grep jax
jax==0.3.17
jaxlib @ https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.15+cuda11.cudnn805-cp37-none-manylinux2014_x86_64.whl
$ python -V
Python 3.7.13
For bugs: reproduction and error logs
# Steps to reproduce:
Open the notebook below in Colab, set the runtime type to TPU, and run the cells in order; the crash occurs at the cell that constructs training.Loop.
https://github.com/google/trax/blob/master/trax/examples/NMT_with_Transformers_Reformers_using_Trax.ipynb
...
# Error logs:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
[<ipython-input-8-2021642a85f0>](https://localhost:8080/#) in <module>
9 train_task,
10 eval_tasks=[eval_task],
---> 11 output_dir=output_dir)
16 frames
/usr/local/lib/python3.7/dist-packages/trax/supervised/training.py in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
278
279 # Create the optimizer for the training loss function.
--> 280 self._trainer_per_task = tuple(self._init_trainer(task) for task in tasks)
281
282 # Sync layers weights/state in memory effcient trainer layers.
/usr/local/lib/python3.7/dist-packages/trax/supervised/training.py in <genexpr>(.0)
278
279 # Create the optimizer for the training loss function.
--> 280 self._trainer_per_task = tuple(self._init_trainer(task) for task in tasks)
281
282 # Sync layers weights/state in memory effcient trainer layers.
/usr/local/lib/python3.7/dist-packages/trax/supervised/training.py in _init_trainer(self, task)
348 task.optimizer.tree_init(model_in_training.weights)
349 return optimizers.Trainer(
--> 350 model_in_training, task.optimizer, adasum=self._adasum)
351 # In the memory-efficient path, we initialize the model here.
352 blocks, loss_layer = optimizers.trainer.extract_reversible_blocks(
/usr/local/lib/python3.7/dist-packages/trax/optimizers/trainer.py in __init__(self, model_with_loss, optimizer, n_devices, adasum)
57 # optimizer slots and opt_params may need to be replicated
58 self._slots, self._opt_params = tl.on_cpu(tl.for_n_devices(
---> 59 (self._optimizer.slots, self._optimizer.opt_params), self._n_devices))
60
61 # accelerated version of model+loss to replicate weights and state
/usr/local/lib/python3.7/dist-packages/trax/layers/acceleration.py in on_cpu(x)
250 """Puts ``x`` in CPU memory in JAX."""
251 if fastmath.is_backend(fastmath.Backend.JAX):
--> 252 return jax.device_put(x, jax.devices('cpu')[0])
253 else:
254 return x
/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in device_put(x, device)
2722 """
2723 with config_explicit_device_put_scope():
-> 2724 return tree_map(lambda y: dispatch.device_put_p.bind(y, device=device), x)
2725
2726
/usr/local/lib/python3.7/dist-packages/jax/_src/tree_util.py in tree_map(f, tree, is_leaf, *rest)
203 leaves, treedef = tree_flatten(tree, is_leaf)
204 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 205 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
206
207 def build_tree(treedef, xs):
/usr/local/lib/python3.7/dist-packages/jax/_src/tree_util.py in <genexpr>(.0)
203 leaves, treedef = tree_flatten(tree, is_leaf)
204 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 205 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
206
207 def build_tree(treedef, xs):
/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in <lambda>(y)
2722 """
2723 with config_explicit_device_put_scope():
-> 2724 return tree_map(lambda y: dispatch.device_put_p.bind(y, device=device), x)
2725
2726
/usr/local/lib/python3.7/dist-packages/jax/core.py in bind(self, *args, **params)
323 assert (not config.jax_enable_checks or
324 all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 325 return self.bind_with_trace(find_top_trace(args), args, params)
326
327 def bind_with_trace(self, trace, args, params):
/usr/local/lib/python3.7/dist-packages/jax/core.py in bind_with_trace(self, trace, args, params)
326
327 def bind_with_trace(self, trace, args, params):
--> 328 out = trace.process_primitive(self, map(trace.full_raise, args), params)
329 return map(full_lower, out) if self.multiple_results else full_lower(out)
330
/usr/local/lib/python3.7/dist-packages/jax/core.py in process_primitive(self, primitive, tracers, params)
684
685 def process_primitive(self, primitive, tracers, params):
--> 686 return primitive.impl(*tracers, **params)
687
688 def process_call(self, primitive, f, tracers, params):
/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py in _device_put_impl(x, device)
1219 raise TypeError(
1220 f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
-> 1221 return aval_to_result_handler(device, a)(None, *device_put(x, device))
1222
1223
/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py in device_put(x, device)
1113 x = xla.canonicalize_dtype(x)
1114 try:
-> 1115 return device_put_handlers[type(x)](x, device)
1116 except KeyError as err:
1117 raise TypeError(f"No device_put handler for type: {type(x)}") from err
/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py in _device_put_array(x, device)
1124 if x.dtype == dtypes.float0:
1125 x = np.zeros(x.shape, dtype=np.dtype(bool))
-> 1126 return (backend.buffer_from_pyval(x, device),)
1127
1128 def _device_put_scalar(x, device):
/usr/local/lib/python3.7/dist-packages/jax/_src/device_array.py in __array__(self, dtype, context)
264
265 def __array__(self, dtype=None, context=None):
--> 266 return np.asarray(self._value, dtype=dtype)
267
268 setattr(device_array, "__array__", __array__)
/usr/local/lib/python3.7/dist-packages/jax/interpreters/pxla.py in _sda_value(self)
803 npy_value = np.empty(self.aval.shape, self.aval.dtype)
804 for i in self.one_replica_buffer_indices:
--> 805 npy_value[self.indices[i]] = np.asarray(self.device_buffers[i])
806 self._npy_value = npy_value
807 return self._npy_value
TypeError: float() argument must be a string or a number, not 'jaxlib.tpu_client_extension.PyTpuBuffer'
...
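Judging from the traceback, the failure is triggered when trax's tl.on_cpu tries to jax.device_put the TPU-replicated optimizer slots/params onto the CPU device, and converting the underlying PyTpuBuffer objects to a NumPy array fails. Below is a minimal sketch of what appears to be the same operation, based only on the traceback above (an assumption on my part, not a verified standalone reproduction):

import jax
import jax.numpy as jnp

# Replicate a small array across all TPU devices, similar to what
# trax's tl.for_n_devices does with optimizer slots and opt_params.
n = jax.device_count()
replicated = jax.pmap(lambda x: x)(jnp.zeros((n, 2)))

# Move the replicated (sharded) value to host CPU memory, as
# trax.layers.acceleration.on_cpu does. If my reading of the traceback
# is right, on this Colab TPU runtime (jax 0.3.17) this is the call
# that raises the TypeError about 'jaxlib.tpu_client_extension.PyTpuBuffer'.
cpu_device = jax.devices('cpu')[0]
jax.device_put(replicated, cpu_device)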