Add `compute_metrics` parameter for `GRPOTrainer`
This is my first open-source PR contribution, so I would greatly appreciate any feedback on areas for improvement. Please don't hesitate to suggest changes - I'm eager to learn and make this contribution as good as possible!
What does this PR do?
This PR adds a `compute_metrics` parameter to `GRPOTrainer`, mirroring the one already supported by `Trainer`. This lets users compute accuracy or other downstream eval metrics over the evaluation dataset.
Fixes related issues
https://github.com/huggingface/trl/issues/3729 https://github.com/huggingface/trl/issues/2959
Changes Made
Added `compute_metrics` parameter to `GRPOTrainer`
File: `trl/trainer/grpo_trainer.py`
Added a new optional parameter after `num_generations`:
```python
from transformers.trainer_utils import EvalPrediction

class GRPOTrainer(BaseTrainer):
    """
    ...
    compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
        The function that will be used to compute metrics at evaluation. Must take an [`EvalPrediction`] and
        return a dictionary mapping metric names to values. *Note:* When passing `TrainingArguments` with
        `batch_eval_metrics` set to `True`, your `compute_metrics` function must take a boolean
        `compute_result` argument. It is set to `True` on the last eval batch to signal that the function
        needs to calculate and return the global summary statistics rather than accumulating the batch-level
        statistics.
    ...
    """

    def __init__(
        self,
        ...
        compute_metrics: Callable[[EvalPrediction], dict] | None = None,
        ...
    ):
        ...
        super().__init__(
            ...
            compute_metrics=compute_metrics,
            ...
        )
```
Example Usage
```python
def my_eval_function(eval_predict):
    pass

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_func,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=my_eval_function,
)
# trainer.train()  # evaluation during training
trainer.evaluate()  # or directly evaluate your model
```
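As a concrete (non-batched) illustration of what such a function might compute, here is a sketch of an argmax-accuracy metric; the body is my own assumption about a typical use, and a namedtuple stands in for `transformers.EvalPrediction` (the real class exposes the same `.predictions` and `.label_ids` fields) so it runs without transformers installed:

```python
from collections import namedtuple

# Stand-in for transformers.EvalPrediction (same field names as the real class).
EvalPrediction = namedtuple("EvalPrediction", ["predictions", "label_ids"])

def my_eval_function(eval_predict):
    """Return a dict mapping metric names to values, e.g. argmax accuracy."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in eval_predict.predictions]
    correct = sum(p == l for p, l in zip(preds, eval_predict.label_ids))
    return {"accuracy": correct / len(eval_predict.label_ids)}
```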
More examples are available in this blog
Benefits
- Flexible: users can choose their own function to evaluate the model during training.
Who can review?
Any community member is welcome to provide feedback. As my first open-source contribution, I'm excited to learn - please don't hesitate to suggest any enhancements!
@kashif @qgallouedec @burtenshaw
thanks! I don't think it would work just like this, though. In `Trainer`, `compute_metrics` is called only if `compute_loss` returns logits and labels, and in GRPO it's not clear what the labels would be. Consequently, `compute_loss` doesn't support `return_outputs`:
https://github.com/huggingface/trl/blob/f7ac9741d43528cf5050ba3ecd41cffec34c5a1a/trl/trainer/grpo_trainer.py#L1790-L1792
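The gating the comment describes can be sketched as follows. This is a simplified illustration of the control flow, not the actual transformers source: `compute_metrics` only fires when the evaluation loop has gathered both predictions and labels, which never happens when `compute_loss` returns neither.

```python
# Simplified sketch of the gating in Trainer's evaluation loop (illustrative,
# not the real transformers code): compute_metrics runs only when both
# predictions and labels were gathered during evaluation.
def run_eval_loop(gathered_preds, gathered_labels, compute_metrics):
    metrics = {}
    if (
        compute_metrics is not None
        and gathered_preds is not None
        and gathered_labels is not None
    ):
        metrics = compute_metrics((gathered_preds, gathered_labels))
    return metrics

# GRPO's compute_loss returns no logits/labels, so nothing is gathered
# and the metrics function is silently skipped:
run_eval_loop(None, None, lambda ep: {"accuracy": 1.0})
```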
You can run this code and see that `my_eval_function` is never called:
```python
from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig

dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only")

def my_eval_function(eval_predict):
    print(eval_predict)

# Dummy reward function: count the number of unique characters in the completions
def reward_num_unique_chars(completions, **kwargs):
    return [len(set(c[0]["content"])) for c in completions]

trainer = GRPOTrainer(
    model="Qwen/Qwen3-0.6B",
    reward_funcs=reward_num_unique_chars,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    args=GRPOConfig(
        eval_steps=2,
        eval_strategy="steps",
    ),
    compute_metrics=my_eval_function,
)
trainer.train()
```
Thanks for your reply!
Yes, the demo I gave above is not minimally runnable; I just wanted to show how to add `compute_metrics` to `GRPOTrainer`. As you commented, to actually use this feature, `compute_loss` needs to return logits and labels, along with some other changes.
So should my next step be to provide a minimal runnable demo? I think adding other changes, such as modifying `compute_loss`, is not reasonable, so I need your suggestions to improve this PR.
In any case, I will first provide a minimal runnable demo.
Thanks!