trax
trax copied to clipboard
AttributeError: 'EvalTask' object has no attribute 'init'
Description
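Constructing `training.Loop` with the eval task and the output directory passed as the third and fourth positional arguments fails with `AttributeError: 'EvalTask' object has no attribute 'init'`. According to the signature shown in the traceback, the third positional parameter of `Loop.__init__` is `eval_model` (followed by `eval_tasks` and `output_dir`), so the `EvalTask` ends up being treated as the evaluation model.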
Environment information
OS: <your answer here>
$ pip freeze | grep trax
trax==1.3.6
$ pip freeze | grep tensor
mesh-tensorflow==0.1.16
tensor2tensor==1.15.7
tensorboard==2.3.0
tensorboard-plugin-wit==1.7.0
tensorflow==2.3.0
tensorflow-addons==0.11.2
tensorflow-datasets==3.2.1
tensorflow-estimator==2.3.0
tensorflow-gan==2.0.0
tensorflow-hub==0.9.0
tensorflow-metadata==0.27.0
tensorflow-probability==0.7.0
tensorflow-text==2.3.0
$ pip freeze | grep jax
jax==0.1.75
jaxlib==0.1.52
$ python -V
Python 3.7.4
For bugs: reproduction and error logs
# Steps to reproduce:
# Learning-rate schedule handed to train_model, which forwards it to
# training.TrainTask as lr_schedule.
# NOTE(review): presumably 400 is the number of warm-up steps and 0.01 the
# peak learning rate -- confirm against the trax.lr documentation.
lr_schedule = trax.lr.warmup_and_rsqrt_decay(400, 0.01)
def train_model(
    Siamese,
    TripletLoss,
    lr_schedule,
    train_generator=train_generator,
    val_generator=val_generator,
    output_dir='model/',
):
    """Build a trax training loop for the Siamese model.

    Args:
        Siamese: Callable that returns the model to train.
        TripletLoss: Callable that returns the loss layer; it is used both as
            the training loss and as the single evaluation metric.
        lr_schedule: Learning-rate schedule forwarded to the TrainTask.
        train_generator: Labeled data for the training task.
        val_generator: Labeled data for the evaluation task.
        output_dir: Directory handed to the loop as ``output_dir``; a leading
            ``~`` is expanded to the user's home directory.

    Returns:
        The constructed ``training.Loop``; it has not been run yet.
    """
    output_dir = os.path.expanduser(output_dir)

    train_task = training.TrainTask(
        labeled_data=train_generator,
        loss_layer=TripletLoss(),
        optimizer=trax.optimizers.Adam(learning_rate=0.01),
        lr_schedule=lr_schedule,
    )

    eval_task = training.EvalTask(
        labeled_data=val_generator,
        metrics=[TripletLoss()],
    )

    # Loop.__init__ takes (model, tasks, eval_model, eval_tasks, output_dir,
    # ...). Passing eval_task and output_dir positionally binds them to
    # eval_model and eval_tasks, which raises
    # "AttributeError: 'EvalTask' object has no attribute 'init'".
    # Keyword arguments route each value to the intended parameter.
    training_loop = training.Loop(
        Siamese(),
        train_task,
        eval_tasks=[eval_task],
        output_dir=output_dir,
        random_seed=31,
    )
    return training_loop
# Driver: build the loop from the module-level Siamese, TripletLoss and
# lr_schedule objects, then call run() with train_steps.
# NOTE(review): train_steps is presumably the number of training steps to
# execute -- confirm against the training.Loop.run documentation.
train_steps = 5
training_loop = train_model(Siamese, TripletLoss, lr_schedule)
training_loop.run(train_steps)
# Error logs:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-33-b8ba7e43cd69> in <module>
1 train_steps = 5
----> 2 training_loop = train_model(Siamese, TripletLoss, lr_schedule)
3 training_loop.run(train_steps)
<ipython-input-32-30d58e0f4b48> in train_model(Siamese, TripletLoss, lr_schedule, train_generator, val_generator, output_dir)
39 eval_task,
40 output_dir,
---> 41 random_seed=31)
42
43 return training_loop
~/opt/anaconda3/lib/python3.7/site-packages/trax/supervised/training.py in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer)
201 self._model.init(self._batch_signature)
202 self._eval_model.rng = self.new_rng()
--> 203 self._eval_model.init(self._batch_signature)
204
205 # To handle the above case (i.e. random_seed = None), we psum the weights
AttributeError: 'EvalTask' object has no attribute 'init'