
Support logging for non-scalar metrics

Open g-simmons opened this issue 2 years ago • 2 comments

🚀 The feature, motivation, and pitch

AccelerateRLTrainer.evaluate() logs a table of generated eval outputs and metrics to the metrics tracker.

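If I understand correctly, only scalar metrics are currently supported: each metric's values are converted to a tensor and averaged, which fails for non-numeric values such as strings.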

This feature would allow non-scalar metrics to be logged.

Use cases:

Allow passthrough logging of non-scalar prompt metadata:

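    import trlx

    def eval_metrics(samples, prompts, outputs, **kwargs):
        # Return the prompt metadata (received as kwargs) unchanged,
        # so it is logged alongside the generated eval outputs.
        return kwargs

    # model_name, config, train_samples, train_rewards, training_method and
    # eval_eval_prompts are defined elsewhere in the calling script.
    model = trlx.train(
        model_name,
        config=config,
        samples=train_samples,  # type: ignore
        rewards=train_rewards if training_method == "ilql" else None,  # type: ignore
        eval_prompts=eval_eval_prompts,  # type: ignore
        metric_fn=eval_metrics,
    ).model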
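To see why this passthrough trips up the averaging step, here is a dependency-free simulation of the metric function above. The metadata keys (`category`, `difficulty`) and their values are made up for illustration, and plain `float()` conversion stands in for the torch-based averaging in trlx. It assumes each metadata key arrives as a list with one value per eval sample.

```python
import statistics


def eval_metrics(samples, prompts, outputs, **kwargs):
    # Pass the prompt metadata straight through as "metrics".
    return kwargs


# Hypothetical metadata for two eval prompts: one value per sample per key.
metrics = eval_metrics(
    samples=["Q1 A1", "Q2 A2"],
    prompts=["Q1", "Q2"],
    outputs=["A1", "A2"],
    category=["math", "code"],
    difficulty=[1, 3],
)


def try_mean(values):
    # Stand-in for "average every metric": None if a value is not numeric.
    try:
        return statistics.fmean(float(v) for v in values)
    except (TypeError, ValueError):
        return None


print({k: try_mean(v) for k, v in metrics.items()})
# {'category': None, 'difficulty': 2.0}
```

The string-valued `category` column cannot be averaged, while the numeric `difficulty` column can.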

Implementation suggestion:

Modify the mean_metrics calculation (linked below) to calculate means only for values that can be successfully cast to float tensors.

https://github.com/CarperAI/trlx/blob/0dce99d96b7d70b6a9114129d8e38bf6c80eb653/trlx/trainer/accelerate_base_trainer.py#L428-L430
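The suggested change can be sketched as a pure-Python helper that averages only the metrics whose values are all real numbers and returns the rest untouched, so a caller could log them some other way. `split_metrics` is a hypothetical name, not part of trlx, and it checks `numbers.Real` instead of attempting a torch cast, so it assumes metric values are Python or NumPy numbers rather than tensors.

```python
import numbers
from statistics import fmean
from typing import Any, Dict, List, Tuple


def split_metrics(
    metrics: Dict[str, List[Any]], sweep_suffix: str = ""
) -> Tuple[Dict[str, float], Dict[str, List[Any]]]:
    """Average numeric metrics; return non-numeric ones unchanged.

    Hypothetical helper, not part of trlx. Assumes each metric maps to a
    list with one value per eval sample.
    """
    mean_metrics: Dict[str, float] = {}
    non_scalar: Dict[str, List[Any]] = {}
    for k, xs in metrics.items():
        xs = list(xs)
        # bool, int, float and NumPy number types all count as numbers.Real.
        if xs and all(isinstance(x, numbers.Real) for x in xs):
            mean_metrics[f"metrics/{k}{sweep_suffix}"] = fmean(float(x) for x in xs)
        else:
            # Strings, dicts, nested lists, None, empty lists: pass through.
            non_scalar[k] = xs
    return mean_metrics, non_scalar


means, passthrough = split_metrics(
    {"reward": [1, 2, 3], "category": ["math", "code", "math"], "flag": [True, False]}
)
print(means)        # {'metrics/reward': 2.0, 'metrics/flag': 0.5}
print(passthrough)  # {'category': ['math', 'code', 'math']}
```

Returning the non-scalar metrics, instead of dropping them, leaves the decision of where to log them (for example, as extra columns in the eval table) to the caller.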

Alternatives

No response

Additional context

No response

g-simmons, Jun 02 '23 22:06

Hey, looks good. Do you want me to assign you?

LouisCastricato, Jun 03 '23 19:06

Don't have a ton of time to work on it currently :/

I just did a simple try/except to get around it for now. If someone wants to develop it further, that would be great.

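    import torch

    # Inside evaluate(): metrics, sweep_suffix and logger come from the trainer.
    mean_metrics = {}

    # Average each metric over the eval samples; skip metrics that cannot be
    # cast to a float tensor (e.g. strings or dicts of prompt metadata).
    for k, xs in metrics.items():
        try:
            # dtype=torch.float lets integer/bool metrics be averaged too,
            # since .mean() raises on integer tensors.
            mean_metrics[f"metrics/{k}{sweep_suffix}"] = torch.as_tensor(xs, dtype=torch.float).mean(-1).item()
        except Exception:
            logger.warning(f"Metric {k} cannot be averaged as a float tensor, skipping")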
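One way to develop this further, rather than skipping non-scalar metrics, would be to keep them per sample in the eval table. The sketch below assumes the table is built as one row per eval sample with `prompt` and `output` columns plus one column per metric; `build_eval_rows` and the column names are hypothetical, not trlx's actual table code. Numeric values stay numeric, and anything else is converted with `str()` so a tracker that supports text columns can still display it.

```python
import numbers
from typing import Any, Dict, List


def build_eval_rows(
    prompts: List[str], outputs: List[str], metrics: Dict[str, List[Any]]
) -> List[Dict[str, Any]]:
    """Hypothetical helper: one row per eval sample, one column per metric."""
    # Assumption: every metric has exactly one value per eval sample.
    for k, xs in metrics.items():
        if len(xs) != len(prompts):
            raise ValueError(f"Metric {k} has {len(xs)} values for {len(prompts)} prompts")
    rows = []
    for i, (prompt, output) in enumerate(zip(prompts, outputs)):
        row: Dict[str, Any] = {"prompt": prompt, "output": output}
        for k, xs in metrics.items():
            x = xs[i]
            # Keep numbers as numbers; store anything else as text.
            row[k] = x if isinstance(x, numbers.Real) else str(x)
        rows.append(row)
    return rows


rows = build_eval_rows(
    ["Q1", "Q2"],
    ["A1", "A2"],
    {"reward": [0.5, 1.0], "meta": [{"id": 1}, ["x"]]},
)
print(rows[0])  # {'prompt': 'Q1', 'output': 'A1', 'reward': 0.5, 'meta': "{'id': 1}"}
```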

g-simmons, Jun 03 '23 22:06