
[Feature Request] Rename Recorder and LogReward

Open • jannessm opened this issue 1 year ago • 5 comments

Motivation

  1. When dealing with logging, I found it hard to understand how to use the different loggers and classes. The Recorder in particular makes the underlying idea hard to grasp.

  2. I would like to make the LogReward class more general, since it is really just a class for logging numeric values, isn't it?

Solution

  1. Rename Recorder to LogValidationReward, since that is essentially what it does.
  2. Rename LogReward to LogScalar and add examples to the docs showing how to log the reward.

Alternatives

  1. Add a validation step to the trainer, with additional hooks for validation-based actions (e.g., just logging validation metrics).
  2. Add a LogScalar class from which LogReward inherits.
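Alternative 2 could look roughly like the following dependency-free sketch, in which LogReward becomes a thin subclass of a generic LogScalar that only fixes the reward-specific defaults. It is an illustration under stated assumptions, not torchrl code: plain dicts of lists stand in for TensorDict batches, the reduction is hard-coded to the mean, and the default key ("next", "reward") and logname "r_training" are assumed values.

```python
from statistics import fmean
from typing import Any, Dict, Sequence, Tuple, Union

Key = Union[str, Tuple[str, ...]]


class LogScalar:
    """Hypothetical stand-in: logs the mean of one numeric entry of a batch."""

    def __init__(self, key: Key, logname: str, log_pbar: bool = False):
        self.key = key
        self.logname = logname
        self.log_pbar = log_pbar

    def __call__(self, batch: Dict[Key, Sequence[float]]) -> Dict[str, Any]:
        # Batches are plain dicts here instead of TensorDicts.
        return {self.logname: fmean(batch[self.key]), "log_pbar": self.log_pbar}


class LogReward(LogScalar):
    """Keeps the old name; only fixes reward-specific defaults (assumed values)."""

    def __init__(
        self,
        logname: str = "r_training",
        log_pbar: bool = False,
        reward_key: Key = ("next", "reward"),
    ):
        super().__init__(key=reward_key, logname=logname, log_pbar=log_pbar)


batch = {("next", "reward"): [1.0, 2.0, 3.0], "loss": [0.5, 0.25]}
print(LogReward()(batch))                     # {'r_training': 2.0, 'log_pbar': False}
print(LogScalar("loss", "loss_mean")(batch))  # {'loss_mean': 0.375, 'log_pbar': False}
```

With this layout, all logging logic lives in LogScalar, and documenting "how to log the reward" reduces to showing which key to pass.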

Checklist

  • [x] I have checked that there is no similar issue in the repo (required)

jannessm, Nov 26 '24 10:11

Hello! Can I take this issue?

raresdan, Nov 26 '24 20:11

Sure! We want to keep the old names working; just raise a deprecation warning when one of them is instantiated.
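One common way to keep an old name while warning on construction is a subclass alias that emits a DeprecationWarning in its __init__. The sketch below assumes LogScalar is the new class; its body is a placeholder, and the warning text and constructor arguments are illustrative rather than taken from torchrl.

```python
import warnings


class LogScalar:
    """Placeholder for the renamed hook (real logging logic omitted in this sketch)."""

    def __init__(self, key="reward", logname="r_training", log_pbar=False):
        self.key = key
        self.logname = logname
        self.log_pbar = log_pbar


class LogReward(LogScalar):
    """Deprecated alias: behaves like LogScalar but warns when instantiated."""

    def __init__(self, *args, **kwargs):
        warnings.warn(
            "LogReward is deprecated and will be removed in a future release; "
            "use LogScalar instead.",
            category=DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller's line
        )
        super().__init__(*args, **kwargs)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # DeprecationWarning is hidden by default outside __main__
    hook = LogReward(key="reward", logname="r")
print(caught[0].category.__name__)  # DeprecationWarning
```

The same pattern would apply to Recorder and LogValidationReward. If an old constructor takes different arguments from the new one (for example, a reward_key where the new class takes key), the alias should translate them before calling super().__init__ rather than forwarding *args and **kwargs unchanged.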

vmoens, Nov 26 '24 21:11

One more thought:

One should be able to choose the aggregation function, e.g. "mean", "sum", etc. An example of this is the reduce_fx argument of LightningModule.log: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#log
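A configurable aggregation usually accepts either a known name or a callable. The following is a minimal, dependency-free sketch of that resolution step. The name table and the helper resolve_reduce_fx are hypothetical, and the statistics and math functions stand in for the torch reductions a real hook would use.

```python
import math
import statistics
from typing import Callable, Dict, Sequence, Union

Reducer = Callable[[Sequence[float]], float]

# Hypothetical table of supported names; a torch-based hook would map to torch ops.
_REDUCTIONS: Dict[str, Reducer] = {
    "mean": statistics.fmean,
    "sum": math.fsum,
    "max": max,
    "min": min,
}


def resolve_reduce_fx(reduce_fx: Union[str, Reducer]) -> Reducer:
    """Return a callable for either a known reduction name or a user-supplied function."""
    if callable(reduce_fx):
        return reduce_fx
    try:
        return _REDUCTIONS[reduce_fx]
    except KeyError:
        raise ValueError(
            f"unknown reduce_fx {reduce_fx!r}; expected one of "
            f"{sorted(_REDUCTIONS)} or a callable"
        ) from None


values = [1.0, 2.0, 3.0, 6.0]
print(resolve_reduce_fx("mean")(values))              # 3.0
print(resolve_reduce_fx("sum")(values))               # 12.0
print(resolve_reduce_fx(statistics.median)(values))   # 2.5
```

Checking names against an explicit table fails fast on typos. By contrast, a plain getattr(torch, reduce_fx) lookup resolves any attribute of the torch module, not just reductions, so a bad name only surfaces when the hook first runs.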

jannessm, Nov 27 '24 07:11

Maybe something like this?

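from typing import Callable, Dict, Optional, Tuple, Union

import torch
from tensordict import TensorDictBase
from torchrl.trainers import Trainer, TrainerHookBase


class LogScalar(TrainerHookBase):
    """Logs a reduced scalar (mean by default) of any entry in the collected batch."""

    def __init__(
        self,
        key: Union[str, Tuple[str, ...]],
        logname: str,
        log_pbar: bool = False,
        reduce_fx: Union[str, Callable] = "mean",
    ):
        self.logname = logname
        self.log_pbar = log_pbar
        self.key = key
        # Accept a callable or the name of a torch reduction such as "mean" or "sum".
        self.reduce_fx = reduce_fx if callable(reduce_fx) else getattr(torch, reduce_fx)

    def __call__(self, batch: TensorDictBase) -> Dict:
        values = batch.get(self.key)
        # If the batch carries a collector mask, keep only the valid (unmasked) entries.
        if ("collector", "mask") in batch.keys(True):
            values = values[batch.get(("collector", "mask"))]
        value = self.reduce_fx(values.float()).item()
        return {
            self.logname: value,
            "log_pbar": self.log_pbar,
        }

    def register(self, trainer: Trainer, name: Optional[str] = None):
        if name is None:
            name = f"log_{self.logname}"
        trainer.register_op("pre_steps_log", self)
        trainer.register_module(name, self)

    # The hook holds no state, but registered modules are saved and restored with
    # the trainer, so provide no-op state_dict / load_state_dict like other hooks.
    def state_dict(self) -> Dict:
        return {}

    def load_state_dict(self, state_dict: Dict) -> None:
        pass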
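The masking branch in __call__ matters because padded entries would otherwise be included in the reduction. The dependency-free sketch below isolates that step, assuming the collector mask marks valid (non-padded) steps with True. Plain lists stand in for the reward tensor and the mask, and masked_reduce is a hypothetical helper, not part of torchrl.

```python
from statistics import fmean
from typing import Callable, Optional, Sequence


def masked_reduce(
    values: Sequence[float],
    mask: Optional[Sequence[bool]] = None,
    reduce_fx: Callable[[Sequence[float]], float] = fmean,
) -> float:
    """Reduce only the entries whose mask is True (all entries if mask is None)."""
    if mask is not None:
        if len(mask) != len(values):
            raise ValueError("values and mask must have the same length")
        values = [v for v, keep in zip(values, mask) if keep]
    return float(reduce_fx(values))


rewards = [1.0, 1.0, 1.0, 0.0, 0.0]   # assume the last two entries are padding
mask = [True, True, True, False, False]
print(masked_reduce(rewards))         # 0.6, because the padding drags the mean down
print(masked_reduce(rewards, mask))   # 1.0
```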

jannessm, Nov 27 '24 08:11

> one should be able to define the aggregation function, e.g. "mean", "sum", ...

Makes sense, though I'd split these into separate PRs.

vmoens, Nov 27 '24 15:11