trax
trax copied to clipboard
AttributeError: 'function' object has no attribute 'n_steps_per_checkpoint' for NLP Machine translation model
Description
I am facing an error when trying to run training.Loop.
Environment information
OS: Ubuntu 22.04
$ pip freeze | grep trax
trax==1.4.1
$ pip freeze | grep tensor
# your output here
$ pip freeze | grep jax
keras @ file:///home/builder/mesters/opt/envs/tensorflow/conda-bld/keras_1682445665871/work/keras-2.12.0-py2.py3-none-any.whl
safetensors==0.3.2
tensorboard @ file:///home/builder/mesters/opt/envs/tensorflow/conda-bld/tensorboard_1682445826165/work/tensorboard-2.12.1-py3-none-any.whl
tensorboard-data-server @ file:///croot/tensorboard-data-server_1681498183723/work/tensorboard_data_server-0.7.0-py3-none-manylinux2014_x86_64.whl
tensorboard-plugin-wit==1.6.0
tensorflow @ file:///home/builder/mesters/opt/envs/tensorflow/conda-bld/tensorflow-base_1682961422577/work/tensorflow_pkg/tensorflow-2.12.0-cp311-cp311-linux_x86_64.whl
tensorflow-datasets==4.9.2
tensorflow-estimator @ file:///home/builder/mesters/opt/envs/tensorflow/conda-bld/tensorflow-estimator_1682445976941/work/tensorflow_estimator-2.12.0-py2.py3-none-any.whl
tensorflow-hub==0.14.0
tensorflow-io-gcs-filesystem==0.33.0
tensorflow-metadata==1.14.0
tensorflow-text==2.12.1
tensorrt==8.6.1.post1
tensorrt-bindings==8.6.1
tensorrt-libs==8.6.1
$ python -V
3.11.4
def nmt_attention_model(input_vocab_size: int = 33300, target_vocab_size: int = 33300, d_model: int = 1024,
                        n_encoder_layers: int = 2, n_decoder_layers: int = 2, n_attn_heads: int = 1,
                        dropout: float = 0.0, mode: str = "train") -> tl.Serial:
    """Assemble the LSTM sequence-to-sequence translation model with attention.

    Args:
        input_vocab_size: Vocabulary size handed to ``encoder_fn``.
        target_vocab_size: Vocabulary size handed to ``pre_attention_decoder``
            and used as the width of the final ``tl.Dense`` layer.
        d_model: Width shared by the encoder, decoder, attention and LSTM layers.
        n_encoder_layers: Layer count handed to ``encoder_fn``.
        n_decoder_layers: Number of ``tl.LSTM`` layers placed after attention.
        n_attn_heads: ``n_heads`` argument of ``tl.AttentionQKV``.
        dropout: ``dropout`` argument of ``tl.AttentionQKV``.
        mode: Forwarded to ``pre_attention_decoder`` and ``tl.AttentionQKV``.

    Returns:
        A ``tl.Serial`` combinator ending in ``tl.LogSoftmax``.
    """
    encoder = encoder_fn(input_vocab_size, d_model, n_encoder_layers)
    decoder = pre_attention_decoder(mode, target_vocab_size, d_model=d_model)

    layers = [
        tl.Select([0, 1, 0, 1]),
        tl.Parallel(encoder, decoder),
        tl.Fn("CreateAttnInputs", create_attention_inps, n_out=4),
        # Wrapped in a Residual so the attention output is added to the
        # pre-attention decoder activations (i.e. the queries).
        tl.Residual(
            tl.AttentionQKV(d_model, n_heads=n_attn_heads, dropout=dropout, mode=mode)
        ),
        # Drop the mask: of the three inputs (activations, mask, target
        # tokens) only the activations and target tokens are kept.
        tl.Select([0, 2]),
        [tl.LSTM(d_model) for _ in range(n_decoder_layers)],
        tl.Dense(target_vocab_size),
        tl.LogSoftmax(),
    ]
    return tl.Serial(*layers)
def train_fun(train_batch_stream):
    """Build the training task for the translation model.

    Args:
        train_batch_stream: Labeled training data, passed to the task as
            ``labeled_data``.

    Returns:
        A ``training.TrainTask`` configured with a cross-entropy loss layer,
        an Adam optimizer, a warmup/rsqrt-decay learning-rate schedule and a
        checkpoint every 20 steps.
    """
    return training.TrainTask(
        labeled_data=train_batch_stream,
        # TrainTask names these keywords ``loss_layer`` and ``lr_schedule``;
        # passing ``loss=`` / ``lr=`` raises a TypeError on construction.
        loss_layer=tl.CrossEntropyLoss(),
        optimizer=trax.optimizers.Adam(0.01),
        lr_schedule=trax.lr.warmup_and_rsqrt_decay(1000, 0.01),
        n_steps_per_checkpoint=20,
    )
def eval_fun(eval_batch_stream):
    """Build the evaluation task for the translation model.

    Args:
        eval_batch_stream: Labeled evaluation data, passed to the task as
            ``labeled_data``.

    Returns:
        A ``training.EvalTask`` whose metrics are cross-entropy loss and
        accuracy.
    """
    eval_metrics = [tl.CrossEntropyLoss(), tl.Accuracy()]
    task = training.EvalTask(labeled_data=eval_batch_stream, metrics=eval_metrics)
    return task
# Loop reads attributes such as ``n_steps_per_checkpoint`` from the task
# objects, so it must receive the TrainTask/EvalTask instances returned by the
# factory functions, not the functions themselves (passing the bare functions
# raises "AttributeError: 'function' object has no attribute
# 'n_steps_per_checkpoint'").
# NOTE(review): assumes ``train_batch_stream`` and ``eval_batch_stream`` are
# the batch generators defined elsewhere in the script; adjust the names if
# they differ.
# NOTE(review): Loop.__init__ also tries to load an existing checkpoint;
# presumably it looks in ``output_dir``, so use a fresh directory if the model
# or optimizer changed since the last run. Confirm against the Trax docs.
t = training.Loop(
    nmt_attention_model(mode='train'),
    train_fun(train_batch_stream),
    eval_tasks=[eval_fun(eval_batch_stream)],
    output_dir=output_dir,
)
# Error logs:
AttributeError Traceback (most recent call last)
File [~/Documents/MyWorkspace/ML/yes/lib/python3.11/site-packages/trax/supervised/training.py:216](https://file+.vscode-resource.vscode-cdn.net/home/sumit/Documents/MyWorkspace/NLP/AttentionModels/NMT/~/Documents/MyWorkspace/ML/yes/lib/python3.11/site-packages/trax/supervised/training.py:216), in Loop.__init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
213 assert len(tasks) == 1, 'only single task supported for now'
214 self._eval_model = model
--> 216 default_at = _at_step_1_and_every_nth_step(tasks[0].n_steps_per_checkpoint)
217 permanent_default_at = _at_step_1_and_every_nth_step(
218 tasks[0].n_steps_per_permanent_checkpoint)
219 if output_dir is not None:
AttributeError: 'function' object has no attribute 'n_steps_per_checkpoint'
It got resolved: I had not called train_fun and had passed the function itself directly to Loop. However, after resolving this, I am facing a new error:
File ~/Documents/MyWorkspace/ML/yes/lib/python3.11/site-packages/trax/supervised/training.py:294, in Loop.init(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks) 289 layer.weights, layer.state = tl.on_cpu(self._unreplicate( 290 _make_weights_and_state_same_across_hosts( 291 self._for_n_devices(weights_and_state)))) 293 # Load checkpoint if it exists. --> 294 self.load_checkpoint() 296 # Prepare eval components. 297 self._eval_at = eval_at or default_at
File ~/Documents/MyWorkspace/ML/yes/lib/python3.11/site-packages/trax/supervised/training.py:944, in Loop.load_checkpoint(self, directory, filename) 940 for (trainer, slots) in zip(self._trainer_per_task, d['slots_per_task']): 941 matched_flat_slots = _match_by_shape( 942 self._to_bits(_flatten_and_remove_empty(trainer.slots)), 943 _flatten_and_remove_empty(slots)) --> 944 matched_slots, _ = fastmath.tree_unflatten( 945 self._from_bits(matched_flat_slots), 946 trainer.slots, copy_from_tree=[None, ()]) 947 trainer.slots = matched_slots 948 self._step = d['step']
File ~/Documents/MyWorkspace/ML/yes/lib/python3.11/site-packages/trax/fastmath/numpy.py:244, in tree_unflatten(flat, tree, copy_from_tree) 242 new_tree, rest = [], flat 243 for t in tree: --> 244 new_t, rest = tree_unflatten(rest, t, copy_from_tree=copy_from_tree) 245 new_tree.append(new_t) 246 new_tree = tuple(new_tree) if isinstance(tree, tuple) else new_tree
File ~/Documents/MyWorkspace/ML/yes/lib/python3.11/site-packages/trax/fastmath/numpy.py:244, in tree_unflatten(flat, tree, copy_from_tree) 242 new_tree, rest = [], flat 243 for t in tree: --> 244 new_t, rest = tree_unflatten(rest, t, copy_from_tree=copy_from_tree) 245 new_tree.append(new_t) 246 new_tree = tuple(new_tree) if isinstance(tree, tuple) else new_tree
File ~/Documents/MyWorkspace/ML/yes/lib/python3.11/site-packages/trax/fastmath/numpy.py:239, in tree_unflatten(flat, tree, copy_from_tree) 216 def tree_unflatten(flat, tree, copy_from_tree=None): 217 """Unflatten a list into a tree given the tree shape as second argument. 218 219 Args: (...) 237 more were provided than the number of leaves of tree (useful for recursion). 238 """ --> 239 if copy_from_tree is not None and tree in copy_from_tree: 240 return tree, flat 241 if isinstance(tree, (list, tuple)):
File ~/Documents/MyWorkspace/ML/yes/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py:258, in _defer_to_unrecognized_arg.
TypeError: unsupported operand type(s) for ==: 'ArrayImpl' and 'tuple'