transformers
Use metrics that consider the input as well as the (predicted, reference) tuple in the Trainer
Feature request
Allow the compute_metrics() in the Trainer to take into account the original input in addition to the predictions and labels.
Motivation
It is currently possible to pass a custom compute_metrics() to the Trainer for evaluation. An example is:
import evaluate
import numpy as np

def compute_metrics(eval_preds):
    # load the GLUE/MRPC metric from the evaluate library
    metric = evaluate.load("glue", "mrpc")
    logits, labels = eval_preds
    # convert logits to class predictions
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
However, compute_metrics is constrained to receive only a (logits, labels) tuple.
This is insufficient for metrics that also depend on the original input sentence. An example is SARI, which is currently implemented in the evaluate library and scores a system output against both the source sentence and the references.
Because the original input is not available during evaluation, the Trainer cannot currently be used with such metrics for some seq2seq tasks, e.g. text simplification.
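To make the request concrete, here is a minimal sketch of what an input-aware compute_metrics could look like. The EvalPredictionWithInputs container and the extra inputs field are hypothetical, chosen here to mirror the shape of the existing (predictions, label_ids) tuple; the copy_rate statistic is just a toy stand-in for an input-dependent metric like SARI:

```python
from dataclasses import dataclass
from typing import Optional
import numpy as np

# Hypothetical container: the existing (predictions, label_ids) pair,
# extended with the tokenized inputs the predictions were made from.
@dataclass
class EvalPredictionWithInputs:
    predictions: np.ndarray
    label_ids: np.ndarray
    inputs: Optional[np.ndarray] = None

def compute_metrics(eval_pred: EvalPredictionWithInputs) -> dict:
    # argmax over logits, as in the example above
    preds = np.argmax(eval_pred.predictions, axis=-1)
    accuracy = float((preds == eval_pred.label_ids).mean())
    # an input-dependent statistic: fraction of (toy, single-token)
    # predictions that merely copy the input -- a stand-in for metrics
    # such as SARI that need to see the source sentence
    copy_rate = float((preds == eval_pred.inputs).mean())
    return {"accuracy": accuracy, "copy_rate": copy_rate}
```

The only change relative to the current API is the extra inputs field; a metric that ignores it would keep working unchanged.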
Your contribution
If the request is accepted, I will try to contribute with a PR.