Support logging for non-scalar metrics
🚀 The feature, motivation, and pitch
AccelerateRLTrainer.evaluate() logs a table of generated eval outputs and metrics to the metrics tracker.
If I understand correctly, only scalar metrics are currently supported.
This feature would allow non-scalar metrics to be logged.
Use cases:
Allow passthrough logging of non-scalar prompt metadata:
def eval_metrics(samples, prompts, outputs, **kwargs):
    return kwargs
trlx.train(
    model_name,
    config=config,
    samples=train_samples,  # type: ignore
    rewards=train_rewards if training_method == "ilql" else None,  # type: ignore
    eval_prompts=eval_eval_prompts,  # type: ignore
    metric_fn=eval_metrics,
).model
Implementation suggestion:
Modify the mean_metrics calculation (linked below) to compute means only for values that can be successfully cast to float tensors.
https://github.com/CarperAI/trlx/blob/0dce99d96b7d70b6a9114129d8e38bf6c80eb653/trlx/trainer/accelerate_base_trainer.py#L428-L430
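A minimal sketch of that change, assuming metrics maps metric names to lists of per-sample values; the columns/columns_data names for the per-sample eval table are assumptions about the surrounding evaluate() code, not verified trlx internals:

import torch

mean_metrics = {}
for k, xs in metrics.items():
    try:
        # Float-castable metrics keep the current behavior:
        # averaged and logged under metrics/<name>.
        mean_metrics[f"metrics/{k}{sweep_suffix}"] = torch.as_tensor(xs, dtype=torch.float).mean(-1).item()
    except (TypeError, ValueError):
        # Non-scalar metrics (strings, dicts, ragged lists, ...) are
        # appended to the per-sample eval table instead of being averaged.
        columns.append(k)
        columns_data.append(xs)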
hey, looks good. Do you want me to assign you?
Don't have a ton of time to work on it currently :/
I just did a simple try/except to get around it for now; if someone wants to develop it further, that would be great.
mean_metrics = {}
for k, xs in metrics.items():
    try:
        mean_metrics[f"metrics/{k}{sweep_suffix}"] = torch.as_tensor(xs).mean(-1).item()
    except Exception:
        logger.warning(f"Metric {k} is not a scalar, skipping")
        continue
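With that workaround, a metric_fn that mixes types degrades gracefully; in this hypothetical example only accuracy reaches mean_metrics, while question_id triggers the warning and is skipped:

def eval_metrics(samples, prompts, outputs, **kwargs):
    return {
        "accuracy": [1.0, 0.0, 1.0],           # float-castable: averaged to metrics/accuracy
        "question_id": kwargs["question_id"],  # not float-castable: warned about and skipped
    }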