`RichProgressCallback` breaks model evaluation and prediction

Open eggry opened this issue 10 months ago • 1 comments

Hi! It's awesome to have a CLI for trl. However, there seems to be a problem with the newly introduced `RichProgressCallback`. The issue affects both the evaluation and prediction stages.

To reproduce the issue, run:

trl sft --model_name_or_path facebook/opt-125m --dataset_name imdb --output_dir opt-sft-imdb --evaluation_strategy steps --eval_steps 1

This leads to the following error:

Traceback (most recent call last):
  File "**************/lib/python3.11/site-packages/trl/commands/scripts/sft.py", line 148, in <module>
    trainer.train()
  File "**************/lib/python3.11/site-packages/trl/trainer/sft_trainer.py", line 360, in train
    output = super().train(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "**************/lib/python3.11/site-packages/transformers/trainer.py", line 1780, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "**************/lib/python3.11/site-packages/transformers/trainer.py", line 2193, in _inner_training_loop
    self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
  File "**************/lib/python3.11/site-packages/transformers/trainer.py", line 2577, in _maybe_log_save_evaluate
    metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "**************/lib/python3.11/site-packages/transformers/trainer.py", line 3365, in evaluate
    output = eval_loop(
             ^^^^^^^^^^
  File "**************/lib/python3.11/site-packages/transformers/trainer.py", line 3586, in evaluation_loop
    self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "**************/lib/python3.11/site-packages/transformers/trainer_callback.py", line 410, in on_prediction_step
    return self.call_event("on_prediction_step", args, state, control)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "**************/lib/python3.11/site-packages/transformers/trainer_callback.py", line 414, in call_event
    result = getattr(callback, event)(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "**************/lib/python3.11/site-packages/trl/trainer/utils.py", line 783, in on_prediction_step
    self.prediction_bar.update(self.prediction_task_id, advance=1, update=True)
  File "**************/lib/python3.11/site-packages/rich/progress.py", line 1425, in update
    task = self._tasks[task_id]
           ~~~~~~~~~~~^^^^^^^^^
KeyError: None
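
The last two frames suggest that `self.prediction_task_id` was still `None` when `update` was called, so the lookup `self._tasks[task_id]` failed. Below is a minimal sketch of that failure mode and of one possible guard. Everything in it is an assumption for illustration: `StubProgress` is a hypothetical stdlib-only stand-in for `rich.progress.Progress` (it only mimics the dict lookup seen in the traceback), and `GuardedCallback` is a hypothetical simplification, not the actual trl callback or its fix.

```python
class StubProgress:
    """Hypothetical stand-in for rich.progress.Progress: tasks live in a dict keyed by id."""

    def __init__(self):
        self._tasks = {}
        self._next_id = 0

    def add_task(self, description, total=None):
        task_id = self._next_id
        self._tasks[task_id] = {"description": description, "total": total, "completed": 0}
        self._next_id += 1
        return task_id

    def update(self, task_id, advance=0):
        # Same lookup pattern as in the traceback: a None id raises KeyError(None).
        task = self._tasks[task_id]
        task["completed"] += advance

    def remove_task(self, task_id):
        del self._tasks[task_id]


class GuardedCallback:
    """Hypothetical callback that creates the prediction task lazily."""

    def __init__(self, bar):
        self.prediction_bar = bar
        self.prediction_task_id = None

    def on_prediction_step(self, total_steps=None):
        # Guard: create the task on first use instead of assuming it already exists.
        if self.prediction_task_id is None:
            self.prediction_task_id = self.prediction_bar.add_task(
                "Predicting", total=total_steps
            )
        self.prediction_bar.update(self.prediction_task_id, advance=1)

    def on_evaluate(self):
        # Reset after each evaluation so the next one starts with a fresh task.
        if self.prediction_task_id is not None:
            self.prediction_bar.remove_task(self.prediction_task_id)
            self.prediction_task_id = None


if __name__ == "__main__":
    bar = StubProgress()
    try:
        bar.update(None, advance=1)  # reproduces the failure mode from the traceback
    except KeyError as exc:
        print("KeyError:", exc.args[0])

    cb = GuardedCallback(bar)
    cb.on_prediction_step(total_steps=2)  # no KeyError: the task is created on demand
    cb.on_evaluate()
```

With `--eval_steps 1`, evaluation runs after every training step, so a guard like this would need to hold across repeated evaluation rounds, which is why the sketch resets the id in `on_evaluate`.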

eggry, Mar 31 '24 07:03