
Trax doesn't support tuples in TrainTask

Open mtyrolski opened this issue 4 years ago • 0 comments

Description

TrainTask takes an argument labeled_data, which is documented as follows:

    labeled_data: Iterator of batches of labeled data tuples. Each tuple has
          1+ data (input value) tensors followed by 1 label (target value)
          tensor.  All tensors are NumPy ndarrays or their JAX counterparts.

From this description, it follows that we can replace _very_simple_data in training_test.py with the following code:

def _very_simple_data(output_dim=1, rep=1):
  """Returns stream of labeled data that maps small integers to constant pi."""
  inputs_batch = np.arange(8).reshape((8, 1))  # 8 items per batch
  targets_batch = np.pi * np.ones((8, output_dim))
  labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
  while True:
    yield (labeled_batch,) * rep

but then we get the following error:

(...)
LayerError: Exception passing through layer Dense_3 (in init):
  layer created in file [...]/trax/supervised/training_test.py, line 441
  layer input shapes: (ShapeDtype{shape:(8, 1), dtype:int64}, ShapeDtype{shape:(8, 1), dtype:float64}, ShapeDtype{shape:(8, 1), dtype:float64})

  File [...]/trax/layers/core.py, line 111, in init_weights_and_state
    shape_w = (input_signature.shape[-1], self._n_units)

AttributeError: 'tuple' object has no attribute 'shape'
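
To make the structure of the yielded item easier to see, here is a minimal NumPy-only sketch. It does not call Trax; it only re-creates the generator from above (renamed very_simple_data here, a hypothetical name for this illustration) and inspects one item. With rep=1 the item is an outer tuple holding one inner 3-tuple of arrays, and asking that inner tuple for .shape fails with the same message the traceback ends with. Whether this is exactly the access path taken inside Dense is an assumption based on the traceback, not something verified against the Trax source.

```python
import numpy as np


def very_simple_data(output_dim=1, rep=1):
    """Same stream as in the report; renamed for this standalone sketch."""
    inputs_batch = np.arange(8).reshape((8, 1))  # 8 items per batch
    targets_batch = np.pi * np.ones((8, output_dim))
    labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
    while True:
        # Yields a tuple of `rep` copies of the (inputs, targets, weights) tuple.
        yield (labeled_batch,) * rep


item = next(very_simple_data())

# The outer tuple has `rep` elements; each element is itself a 3-tuple of arrays.
outer_len = len(item)
inner_shapes = [array.shape for array in item[0]]

# A plain tuple has no `.shape` attribute, unlike the ndarrays it contains.
try:
    item[0].shape
    error_message = None
except AttributeError as err:
    error_message = str(err)

print(outer_len, inner_shapes, error_message)
```

The three inner shapes match the three ShapeDtype entries listed under "layer input shapes" in the error output.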

Environment information

OS: ubuntu 20.04

trax version: master

$ pip freeze | grep tensor
mesh-tensorflow==0.1.18
tensorboard==2.4.1
tensorboard-plugin-wit==1.8.0
tensorflow==2.4.1
tensorflow-datasets==4.2.0
tensorflow-estimator==2.4.0
tensorflow-hub==0.11.0
tensorflow-metadata==0.27.0
tensorflow-text==2.4.3


$ pip freeze | grep jax
jax==0.2.9
jaxlib==0.1.61


$ python -V
Python 3.8.5

Steps to reproduce:

Replace _very_simple_data with the version above and run any test in training_test.py that uses it.

mtyrolski, Mar 18 '21 22:03