trax
trax copied to clipboard
TypeError: unsupported operand type(s) for ==: 'Array' and 'tuple'
Description
...
Environment information
OS: <your answer here>
$ pip freeze | grep trax
trax==1.4.1
$ pip freeze | grep tensor
tensorboard==2.11.2
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.11.0
tensorflow-datasets==4.8.2
tensorflow-estimator==2.11.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.30.0
tensorflow-metadata==1.12.0
tensorflow-text==2.11.0
$ pip freeze | grep jax
jax==0.4.2
jaxlib==0.4.2
$ python -V
Python 3.10.9+
For bugs: reproduction and error logs
# Steps to reproduce:
...
training_loop = train_model(model, train_task, eval_task, 100, output_dir_expand)
# Error logs:
...
TypeError Traceback (most recent call last)
Cell In[41], line 1
----> 1 training_loop = train_model(model, train_task, eval_task, 100, output_dir_expand)
Cell In[40], line 15, in train_model(classifier, train_task, eval_task, n_steps, output_dir)
4 '''
5 Input:
6 classifier - the model you are building
(...)
12 trainer - trax trainer
13 '''
14 ### START CODE HERE (Replace instances of 'None' with your code) ###
---> 15 training_loop = training.Loop(
16 model=classifier, # The learning model
17 tasks=train_task, # The training task
18 eval_tasks=[eval_task], # The evaluation task
19 output_dir=output_dir) # The output directory
21 training_loop.run(n_steps = n_steps)
22 ### END CODE HERE ###
23
24 # Return the training_loop, since it has the model.
File ~/.pyenv/versions/3.10-dev/envs/trax/lib/python3.10/site-packages/trax/supervised/training.py:294, in Loop.__init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
289 layer.weights, layer.state = tl.on_cpu(self._unreplicate(
290 _make_weights_and_state_same_across_hosts(
291 self._for_n_devices(weights_and_state))))
293 # Load checkpoint if it exists.
--> 294 self.load_checkpoint()
296 # Prepare eval components.
297 self._eval_at = eval_at or default_at
File ~/.pyenv/versions/3.10-dev/envs/trax/lib/python3.10/site-packages/trax/supervised/training.py:944, in Loop.load_checkpoint(self, directory, filename)
940 for (trainer, slots) in zip(self._trainer_per_task, d['slots_per_task']):
941 matched_flat_slots = _match_by_shape(
942 self._to_bits(_flatten_and_remove_empty(trainer.slots)),
943 _flatten_and_remove_empty(slots))
--> 944 matched_slots, _ = fastmath.tree_unflatten(
945 self._from_bits(matched_flat_slots),
946 trainer.slots, copy_from_tree=[None, ()])
947 trainer.slots = matched_slots
948 self._step = d['step']
File ~/.pyenv/versions/3.10-dev/envs/trax/lib/python3.10/site-packages/trax/fastmath/numpy.py:244, in tree_unflatten(flat, tree, copy_from_tree)
242 new_tree, rest = [], flat
243 for t in tree:
--> 244 new_t, rest = tree_unflatten(rest, t, copy_from_tree=copy_from_tree)
245 new_tree.append(new_t)
246 new_tree = tuple(new_tree) if isinstance(tree, tuple) else new_tree
File ~/.pyenv/versions/3.10-dev/envs/trax/lib/python3.10/site-packages/trax/fastmath/numpy.py:244, in tree_unflatten(flat, tree, copy_from_tree)
242 new_tree, rest = [], flat
243 for t in tree:
--> 244 new_t, rest = tree_unflatten(rest, t, copy_from_tree=copy_from_tree)
245 new_tree.append(new_t)
246 new_tree = tuple(new_tree) if isinstance(tree, tuple) else new_tree
File ~/.pyenv/versions/3.10-dev/envs/trax/lib/python3.10/site-packages/trax/fastmath/numpy.py:239, in tree_unflatten(flat, tree, copy_from_tree)
216 def tree_unflatten(flat, tree, copy_from_tree=None):
217 """Unflatten a list into a tree given the tree shape as second argument.
218
219 Args:
(...)
237 more were provided than the number of leaves of tree (useful for recursion).
238 """
--> 239 if copy_from_tree is not None and tree in copy_from_tree:
240 return tree, flat
241 if isinstance(tree, (list, tuple)):
File ~/.pyenv/versions/3.10-dev/envs/trax/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:4972, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
4970 return binary_op(*args)
4971 if isinstance(other, _rejected_binop_types):
-> 4972 raise TypeError(f"unsupported operand type(s) for {opchar}: "
4973 f"{type(args[0]).__name__!r} and {type(args[1]).__name__!r}")
4974 return NotImplemented
TypeError: unsupported operand type(s) for ==: 'Array' and 'tuple'
Yesterday I encountered the same error. As a temporary workaround, you can delete the model file that was created and run again; the error will then be gone.
I also ran into this, but I would actually like to be able to load the last checkpoint, so deleting the model is not a solution for me. Is there any way to resume training from a saved checkpoint?