Trax doesn't support tuples in TrainTask
Description
TrainTask takes an argument labeled_data, which is documented as follows:
labeled_data: Iterator of batches of labeled data tuples. Each tuple has
1+ data (input value) tensors followed by 1 label (target value)
tensor. All tensors are NumPy ndarrays or their JAX counterparts.
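As a point of reference, here is a minimal NumPy-only sketch of one reading of this description: each batch is a flat tuple of arrays, with the data tensors first and the label tensor last. The generator name `flat_labeled_data` is hypothetical and is not part of Trax.

```python
import numpy as np


def flat_labeled_data(output_dim=1):
  """Hypothetical stream: each batch is a flat (inputs, targets) tuple."""
  inputs_batch = np.arange(8).reshape((8, 1))  # 8 items per batch
  targets_batch = np.pi * np.ones((8, output_dim))
  while True:
    # One data tensor followed by one label tensor, no nesting.
    yield (inputs_batch, targets_batch)


batch = next(flat_labeled_data())
all_arrays = all(isinstance(t, np.ndarray) for t in batch)
shapes = tuple(t.shape for t in batch)
print(all_arrays, shapes)
```

Every element of such a batch is an ndarray, so each one exposes a `.shape` attribute.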
From this description, it should follow that we can replace _very_simple_data in training_test.py with the following code:
def _very_simple_data(output_dim=1, rep=1):
  """Returns stream of labeled data that maps small integers to constant pi."""
  inputs_batch = np.arange(8).reshape((8, 1))  # 8 items per batch
  targets_batch = np.pi * np.ones((8, output_dim))
  labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
  while True:
    yield (labeled_batch,) * rep
but this then fails with the following error:
(...)
LayerError: Exception passing through layer Dense_3 (in init):
  layer created in file [...]/trax/supervised/training_test.py, line 441
  layer input shapes: (ShapeDtype{shape:(8, 1), dtype:int64}, ShapeDtype{shape:(8, 1), dtype:float64}, ShapeDtype{shape:(8, 1), dtype:float64})
  File [...]/trax/layers/core.py, line 111, in init_weights_and_state
    shape_w = (input_signature.shape[-1], self._n_units)
AttributeError: 'tuple' object has no attribute 'shape'
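To make the structure of the yielded batches explicit, here is a small NumPy-only sketch that does not import Trax. It reuses the generator from the snippet and adds a hypothetical helper, `describe`, that maps every array to its shape. It only illustrates the nesting and the resulting AttributeError in isolation; it does not claim to reproduce how Trax handles the batch internally.

```python
import numpy as np


def _very_simple_data(output_dim=1, rep=1):
  """Same generator as in the report: yields a tuple of tuples."""
  inputs_batch = np.arange(8).reshape((8, 1))  # 8 items per batch
  targets_batch = np.pi * np.ones((8, output_dim))
  labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
  while True:
    yield (labeled_batch,) * rep


def describe(item):
  """Hypothetical helper: recursively replaces arrays with their shapes."""
  if isinstance(item, tuple):
    return tuple(describe(x) for x in item)
  return item.shape


batch = next(_very_simple_data())
structure = describe(batch)
print(structure)  # (((8, 1), (8, 1), (8, 1)),)

# The first element of the batch is itself a tuple, so code that expects
# an array (or a single shape signature) in that position fails.
try:
  batch[0].shape
  error = None
except AttributeError as e:
  error = str(e)
print(error)
```

With rep=1 the batch is a 1-tuple whose only element is the (inputs, targets, weights) tuple, which appears consistent with the three-element "layer input shapes" tuple in the traceback.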
Environment information
OS: ubuntu 20.04
trax version: master
$ pip freeze | grep tensor
mesh-tensorflow==0.1.18
tensorboard==2.4.1
tensorboard-plugin-wit==1.8.0
tensorflow==2.4.1
tensorflow-datasets==4.2.0
tensorflow-estimator==2.4.0
tensorflow-hub==0.11.0
tensorflow-metadata==0.27.0
tensorflow-text==2.4.3
$ pip freeze | grep jax
jax==0.2.9
jaxlib==0.1.61
$ python -V
Python 3.8.5
Steps to reproduce:
Replace _very_simple_data in training_test.py with the version above, then run any test in that file that uses _very_simple_data.