`RichProgressCallback` breaks model evaluation and prediction
Hi! It's awesome to have a CLI for `trl`.
However, there seems to be a problem with the newly introduced `RichProgressCallback`. This issue affects both the evaluation and prediction stages.
To reproduce the issue, run:
trl sft --model_name_or_path facebook/opt-125m --dataset_name imdb --output_dir opt-sft-imdb --evaluation_strategy steps --eval_steps 1
This leads to the following error:
Traceback (most recent call last):
  File "**************/lib/python3.11/site-packages/trl/commands/scripts/sft.py", line 148, in <module>
    trainer.train()
  File "**************/lib/python3.11/site-packages/trl/trainer/sft_trainer.py", line 360, in train
    output = super().train(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "**************/lib/python3.11/site-packages/transformers/trainer.py", line 1780, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "**************/lib/python3.11/site-packages/transformers/trainer.py", line 2193, in _inner_training_loop
    self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
  File "**************/lib/python3.11/site-packages/transformers/trainer.py", line 2577, in _maybe_log_save_evaluate
    metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "**************/lib/python3.11/site-packages/transformers/trainer.py", line 3365, in evaluate
    output = eval_loop(
             ^^^^^^^^^^
  File "**************/lib/python3.11/site-packages/transformers/trainer.py", line 3586, in evaluation_loop
    self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "**************/lib/python3.11/site-packages/transformers/trainer_callback.py", line 410, in on_prediction_step
    return self.call_event("on_prediction_step", args, state, control)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "**************/lib/python3.11/site-packages/transformers/trainer_callback.py", line 414, in call_event
    result = getattr(callback, event)(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "**************/lib/python3.11/site-packages/trl/trainer/utils.py", line 783, in on_prediction_step
    self.prediction_bar.update(self.prediction_task_id, advance=1, update=True)
  File "**************/lib/python3.11/site-packages/rich/progress.py", line 1425, in update
    task = self._tasks[task_id]
           ~~~~~~~~~~~^^^^^^^^^
KeyError: None
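
Judging from the last frames, `self.prediction_task_id` is still `None` when `on_prediction_step` fires, so the `self._tasks[task_id]` lookup fails. Here is a minimal sketch of that failure mode and one possible guard. It does not use `rich` or the actual `trl` code: `FakeProgress` and `SketchCallback` are hypothetical stand-ins I made up for illustration, and creating the task lazily is only my guess at a fix, not necessarily how it should be solved in `trl`.

```python
class FakeProgress:
    """Hypothetical stand-in for a progress bar that stores tasks in a dict by id."""

    def __init__(self):
        self._tasks = {}
        self._next_id = 0

    def add_task(self, description, total=None):
        # Register a new task and hand back its id.
        task_id = self._next_id
        self._tasks[task_id] = {"description": description, "total": total, "completed": 0}
        self._next_id += 1
        return task_id

    def update(self, task_id, advance=0):
        # Raises KeyError if task_id was never registered (e.g. it is None).
        task = self._tasks[task_id]
        task["completed"] += advance


class SketchCallback:
    """Hypothetical callback for illustration only, not the trl implementation."""

    def __init__(self):
        self.prediction_bar = FakeProgress()
        self.prediction_task_id = None  # no task has been created yet

    def on_prediction_step_unguarded(self):
        # Mirrors the failing call: the task id is still None here.
        self.prediction_bar.update(self.prediction_task_id, advance=1)

    def on_prediction_step_guarded(self, total=None):
        # Assumed fix: create the task the first time a prediction step is seen.
        if self.prediction_task_id is None:
            self.prediction_task_id = self.prediction_bar.add_task("Predicting", total=total)
        self.prediction_bar.update(self.prediction_task_id, advance=1)


def raises_key_error(fn):
    """Return True if calling fn() raises KeyError, False otherwise."""
    try:
        fn()
    except KeyError:
        return True
    return False
```

With this sketch, `raises_key_error(SketchCallback().on_prediction_step_unguarded)` is `True`, while the guarded variant registers the task on its first call and then advances it.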