transformers
Make evaluation loss accessible inside compute_metrics() in Trainer
Feature request
Hi, I am requesting a feature to make the evaluation loss accessible inside compute_metrics() within the Trainer class. This would enable users to log loss-dependent metrics during training; in my case I want to track perplexity. By making the pre-computed validation loss available, users can accurately compute these metrics without having to recompute them from scratch.
The existing solution computes the perplexity score post-training by running the model on a validation batch via trainer.evaluate(), which only provides a single score on a single validation batch after all training is done.
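For reference, the existing workaround looks roughly like this (a minimal sketch, assuming trainer is an already-configured Trainer instance):

import math

# Run a full evaluation pass once training has finished,
# then derive perplexity from the aggregated eval loss.
metrics = trainer.evaluate()
perplexity = math.exp(metrics["eval_loss"])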
Motivation
Why other solutions to track perplexity during training via compute_metrics do not work effectively:
- Using the evaluate library to load perplexity: when used with Trainer, it reloads the model from the Hub at each evaluation, which is inefficient. Furthermore, the implementation is not very flexible; for instance, it assumes that labels are exactly equal to input_ids, which is not always the case. I'm running instruction finetuning where the labels are the response part of the input_ids.
- Computing perplexity directly from EvalPrediction's logits and labels by recomputing the loss inside the compute_metrics function would be inaccurate, since we do not have access to the attention_mask of the validation batches, so we can't compute a proper loss by masking padding tokens.
Why make the loss accessible in compute_metrics along with logits and labels:
- The evaluation loss is computed by default earlier in the evaluation loop, and compute_metrics is executed right after.
- I'm assured that the loss used is accurate and consistent, and thus the reported perplexity is reliable. Furthermore, the loss can be used to compute other metrics besides perplexity.
- I believe this feature would be useful since there are similar requests in the community like this and this.
Your contribution
I would be happy to implement this and get your feedback on whether it is the right approach before opening a PR 😊.
cc @muellerzr @SunMarc
Would love to see this feature added. Not having access to the loss inside of compute_metrics() makes it rather difficult to calculate metrics like perplexity.
Hey @Manalelaidouni! Thanks for the detailed report. Adding the loss to compute_metrics can indeed make sense! Just to be sure, why is the metric f"{metric_key_prefix}_loss" returned at the end of evaluation_loop not enough in your case? If you want to calculate the perplexity on your entire evaluation set, that should be enough, no? If your goal is to calculate the perplexity for each batch (e.g. using batch_eval_metrics=True), adding the loss to compute_metrics will indeed be needed.
if isinstance(all_losses, list) and all_losses:
    metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item()
cc @muellerzr
Thank you @SunMarc!
I see where you're coming from, but even if I wanted to calculate perplexity for the entire evaluation set (when batch_eval_metrics=False), using f"{metric_key_prefix}_loss" to compute metrics like perplexity (or any loss-dependent metric) would require overriding the evaluation_loop method and adding the logic at the end, which is not very user-friendly.
I believe it would be useful if the evaluation loss were directly accessible within compute_metrics() in both cases, whether computing per-batch metrics or metrics for the entire evaluation set. It would make computing loss-dependent metrics more straightforward and ensure they are accurate.
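To illustrate the point above, today's workaround looks something like this (a sketch only; the subclass name is made up, and it relies on evaluation_loop returning an EvalLoopOutput whose metrics already contain the aggregated loss):

import math

from transformers import Trainer


class PerplexityTrainer(Trainer):
    # Workaround: run the normal evaluation loop, then append perplexity
    # to the returned metrics so it also reaches the loggers/trackers.
    def evaluation_loop(self, *args, metric_key_prefix="eval", **kwargs):
        output = super().evaluation_loop(*args, metric_key_prefix=metric_key_prefix, **kwargs)
        loss_key = f"{metric_key_prefix}_loss"
        if loss_key in output.metrics:
            output.metrics[f"{metric_key_prefix}_perplexity"] = math.exp(output.metrics[loss_key])
        return output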
If you’re interested, here is the solution I'm thinking of:
Users can choose to include the loss inside compute_metrics by passing a bool flag like return_loss:
- If return_loss is set to False or not passed as an argument, the use_loss flag inside evaluation_loop() will be disabled and backward compatibility is ensured.
- If the user sets return_loss to True, the use_loss flag is enabled and the loss is included in the EvalPrediction; we’ll make the loss optional, just like inputs, so that EvalPrediction works as expected with and without the loss:
def compute_metrics(eval_preds, return_loss=True):
    logits, labels, losses = eval_preds
    # user can compute metrics as they wish with the returned values
    pass
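For instance, tracking perplexity would then be straightforward (a sketch assuming the proposed return_loss flag and that losses holds the pre-computed evaluation losses):

import numpy as np


def compute_metrics(eval_preds, return_loss=True):
    # losses are the pre-computed evaluation losses forwarded by the Trainer
    logits, labels, losses = eval_preds
    return {"perplexity": float(np.exp(np.mean(losses)))}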
We can check whether return_loss is passed as an argument and whether it's enabled like this:
import inspect


def is_loss_included(metrics_func):
    # check whether the flag is declared as an argument and enabled by default
    arguments = inspect.signature(metrics_func).parameters
    if "return_loss" in arguments and arguments["return_loss"].default:
        return True
    return False
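A quick check of the helper's behaviour (hypothetical metric functions, for illustration only):

def metrics_with_loss(eval_preds, return_loss=True):
    return {}

def metrics_without_loss(eval_preds):
    return {}

assert is_loss_included(metrics_with_loss)
assert not is_loss_included(metrics_without_loss)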
Inside evaluation_loop():
if (self.compute_metrics is not None and all_preds is not None and all_labels is not None and not self.args.batch_eval_metrics):
    use_loss = is_loss_included(self.compute_metrics)
    if use_loss:
        if self.args.include_inputs_for_metrics:
            metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs, losses=all_losses))
        else:
            metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels, losses=all_losses))
    else:
        if self.args.include_inputs_for_metrics:
            metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs))
        else:
            metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
Making the loss optional, just like inputs, so that EvalPrediction works as expected with and without the loss:
from typing import Optional, Tuple, Union

import numpy as np


class EvalPrediction:
    def __init__(
        self,
        predictions: Union[np.ndarray, Tuple[np.ndarray]],
        label_ids: Union[np.ndarray, Tuple[np.ndarray]],
        inputs: Optional[Union[np.ndarray, Tuple[np.ndarray]]] = None,
        losses: Optional[Union[np.ndarray, Tuple[np.ndarray]]] = None,
    ):
        self.predictions = predictions
        self.label_ids = label_ids
        self.inputs = inputs
        self.losses = losses

    def __iter__(self):
        # Iteration skips the optional fields that were not provided.
        if self.losses is not None:
            if self.inputs is not None:
                return iter((self.predictions, self.label_ids, self.inputs, self.losses))
            else:
                return iter((self.predictions, self.label_ids, self.losses))
        else:
            if self.inputs is not None:
                return iter((self.predictions, self.label_ids, self.inputs))
            else:
                return iter((self.predictions, self.label_ids))

    def __getitem__(self, idx):
        # Indexing keeps fixed positions: 0 predictions, 1 label_ids, 2 inputs, 3 losses.
        if self.losses is not None:
            if idx < 0 or idx > 3:
                raise IndexError("tuple index out of range")
            if idx == 2 and self.inputs is None:
                raise IndexError("tuple index out of range")
            if idx == 3:
                return self.losses
        else:
            if idx < 0 or idx > 2:
                raise IndexError("tuple index out of range")
            if idx == 2 and self.inputs is None:
                raise IndexError("tuple index out of range")
        if idx == 0:
            return self.predictions
        elif idx == 1:
            return self.label_ids
        elif idx == 2:
            return self.inputs
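A quick illustration of the resulting behaviour with dummy arrays (iteration yields a compact tuple, while indexing keeps losses at the fixed position 3):

import numpy as np

preds = np.zeros((4, 10))
labels = np.zeros((4, 10))
losses = np.array([0.5, 0.7, 0.6, 0.4])

ep = EvalPrediction(predictions=preds, label_ids=labels, losses=losses)
predictions, label_ids, batch_losses = ep  # inputs is None, so iteration yields a 3-tuple
assert ep[3] is losses                     # indexing keeps losses at position 3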
Additional options for users to enable this:
- by adding the return_loss flag when defining the compute_metrics function as I described above,
- or by adding include_loss_for_metrics when defining the training arguments (see the sketch below): just like include_inputs_for_metrics, we can simply add include_loss_for_metrics to the TrainingArguments class and use if use_loss or self.args.include_loss_for_metrics: inside evaluation_loop() (in the second code snippet) instead of just if use_loss.
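A minimal sketch of the second option; the subclass exists only to illustrate the proposed field here (in a real PR it would be added to TrainingArguments itself):

from dataclasses import dataclass, field

from transformers import TrainingArguments


@dataclass
class TrainingArgumentsWithLoss(TrainingArguments):
    include_loss_for_metrics: bool = field(
        default=False,
        metadata={"help": "Whether to pass the evaluation losses to compute_metrics via EvalPrediction."},
    )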
I’m looking forward to your feedback 😊
I see where you're coming from, but even if I wanted to calculate perplexity for the entire evaluation set (when batch_eval_metrics=False), using f"{metric_key_prefix}_loss" to compute metrics like perplexity (or any loss-dependent metric) would require overriding the evaluation_loop method and adding the logic at the end, which is not very user-friendly.
Oh yeah, indeed. You can always compute the perplexity at the end of trainer.evaluate() from the metrics, but I guess it can be quite a pain if you also want this on your tracker (wandb, ...).
I think it would be better to follow what was done for include_inputs_for_metrics since it is a similar case. If you have time, would you like to open a PR to add this? You can check this PR for reference. Also, in order to not repeat the same line multiple times, you can just create a dict of kwargs depending on the args and pass it like this: self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels, **kwargs))
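A sketch of that kwargs-based variant inside evaluation_loop(), reusing the variable names from the earlier snippet and the proposed include_loss_for_metrics argument:

if self.compute_metrics is not None and all_preds is not None and all_labels is not None and not self.args.batch_eval_metrics:
    # Build the optional fields once instead of repeating the call for each combination.
    eval_pred_kwargs = {}
    if self.args.include_inputs_for_metrics:
        eval_pred_kwargs["inputs"] = all_inputs
    if self.args.include_loss_for_metrics:
        eval_pred_kwargs["losses"] = all_losses
    metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels, **eval_pred_kwargs))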
Thanks for the quick feedback! I'd be happy to do so; I'll open a PR tomorrow to work on this, and thanks again for the guidance.