
Add `compute_metrics` parameter for `GRPOTrainer`

Open · colinzhaoxp opened this pull request 5 months ago · 3 comments


This is my first open-source PR contribution, and I would greatly appreciate any feedback on areas for improvement. Please don't hesitate to suggest changes; I'm eager to learn and make this contribution as good as possible!

What does this PR do?

This PR adds a `compute_metrics` parameter to `GRPOTrainer`, matching the one already supported by `Trainer`. With it, users can compute accuracy or other downstream evaluation metrics over the evaluation dataset.

Fixes related issues

https://github.com/huggingface/trl/issues/3729
https://github.com/huggingface/trl/issues/2959

Changes Made

Added a `compute_metrics` parameter to `GRPOTrainer`. File: `trl/trainer/grpo_trainer.py`

Added a new optional parameter after num_generations:

from transformers.trainer_utils import seed_worker, EvalPrediction

class GRPOTrainer(BaseTrainer):
    """
    ...
    compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
        The function that will be used to compute metrics at evaluation. Must take an [`EvalPrediction`] and
        return a dictionary mapping metric names to values. *Note*: when passing `TrainingArguments` with
        `batch_eval_metrics` set to `True`, your `compute_metrics` function must take a boolean
        `compute_result` argument. This will be triggered after the last eval batch to signal that the
        function needs to calculate and return the global summary statistics rather than accumulating
        batch-level statistics.
    ...
    """
    def __init__(
        self,
        ...
        compute_metrics: Callable[[EvalPrediction], dict] | None = None,
        ...
    ):
        ...
        super().__init__(
            ...
            compute_metrics=compute_metrics,
            ...
        )

Example Usage

def my_eval_function(eval_predict):
    # compute metrics from eval_predict and return them as a dict
    return {}

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_func,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=my_eval_function,
)
# trainer.train() # evaluation during training
trainer.evaluate() # or directly evaluate your model.
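
For illustration, here is a minimal sketch of what such a function could look like. The accuracy computation is hypothetical: it assumes the trainer passes an `EvalPrediction` whose `predictions` holds class logits and whose `label_ids` holds integer labels, which (as the discussion below shows) `GRPOTrainer` does not currently do:

import numpy as np
from transformers.trainer_utils import EvalPrediction

def my_eval_function(eval_predict: EvalPrediction) -> dict:
    # Hypothetical metric: argmax over logits compared to integer labels.
    # GRPOTrainer does not currently populate these fields.
    preds = np.argmax(eval_predict.predictions, axis=-1)
    accuracy = float((preds == eval_predict.label_ids).mean())
    return {"accuracy": accuracy}

If `batch_eval_metrics=True` is set in the training arguments, the function must additionally accept a boolean `compute_result` argument, as noted in the docstring above.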

More examples are available in this blog

Benefits

  • Flexible: Users can choose their own function to evaluate their model during training.

Who can review?

Any community member is welcome to provide feedback. Since this is my first open-source contribution, I'm excited to learn; please don't hesitate to suggest any enhancements!

colinzhaoxp · Nov 17 '25

@kashif @qgallouedec @burtenshaw

colinzhaoxp · Nov 17 '25

Thanks! I don't think it would work just like this, though. In `Trainer`, `compute_metrics` is called only if `compute_loss` returns logits and labels, and in GRPO it's not clear what the labels would be. Consequently, `compute_loss` doesn't support `return_outputs`:

https://github.com/huggingface/trl/blob/f7ac9741d43528cf5050ba3ecd41cffec34c5a1a/trl/trainer/grpo_trainer.py#L1790-L1792
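
To make the failure concrete, here is a toy, self-contained sketch (not the actual transformers implementation) of the relevant control flow in the evaluation loop. Because `GRPOTrainer.prediction_step` returns only a loss, the prediction and label accumulators stay empty and the `compute_metrics` branch is never taken:

from transformers.trainer_utils import EvalPrediction

def toy_evaluation_loop(prediction_step, batches, compute_metrics):
    # Simplified stand-in for Trainer.evaluation_loop
    all_preds, all_labels = [], []
    for inputs in batches:
        loss, logits, labels = prediction_step(inputs)
        if logits is not None:
            all_preds.append(logits)   # never reached: GRPO returns logits=None
        if labels is not None:
            all_labels.append(labels)  # never reached: GRPO returns labels=None
    # Nothing was accumulated, so compute_metrics is skipped entirely
    if compute_metrics is not None and all_preds and all_labels:
        return compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
    return {}

grpo_style_step = lambda inputs: (0.0, None, None)  # loss only, no logits/labels
print(toy_evaluation_loop(grpo_style_step, [{}], lambda ep: {"accuracy": 1.0}))  # -> {}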

You can run this code and see that `my_eval_function` is never called:

from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig

dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only")

def my_eval_function(eval_predict):
    print(eval_predict)


# Dummy reward function: count the number of unique characters in the completions
def reward_num_unique_chars(completions, **kwargs):
    return [len(set(c[0]["content"])) for c in completions]

trainer = GRPOTrainer(
    model="Qwen/Qwen3-0.6B",
    reward_funcs=reward_num_unique_chars,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    args=GRPOConfig(
        eval_steps=2,
        eval_strategy="steps",
    ),
    compute_metrics=my_eval_function,
)
trainer.train()

qgallouedec · Nov 21 '25

Thanks for your reply!

Yes, the demo I gave above is not a minimal runnable example; I just wanted to show how to add `compute_metrics` to `GRPOTrainer`. As your comment points out, to actually use this function we would need `compute_loss` to return logits and labels, plus some other changes to make everything fit.

So should my next step be to provide a minimal runnable demo? Adding other code, such as changing `compute_loss`, doesn't seem reasonable to me. I'd appreciate your suggestions for improving this PR.

In the meantime, I will start by providing a minimal runnable demo.

Thanks

colinzhaoxp · Nov 22 '25