
Log with Tensorboard/WandB?

Open drozzy opened this issue 4 years ago • 8 comments

Hi, guys. Love the idea for this. Read the paper, very interesting. Lots of nice design went into this, and I'm eager to try it.

How would I add some custom logging functionality? For example, if I want to log some metrics to tensorboard or wandb? Where would I insert that?

drozzy avatar May 21 '20 07:05 drozzy

Thanks for the kind words!

There was a recent PR for using tensorboard. We try to import it here:

https://github.com/astooke/rlpyt/blob/85d4e018a919118c6e42fac3e897aa346d84b9c5/rlpyt/utils/logging/context.py#L7

Then call the logger context with the kwarg use_summary_writer=True (e.g. here's that line in example_1.py): https://github.com/astooke/rlpyt/blob/85d4e018a919118c6e42fac3e897aa346d84b9c5/examples/example_1.py#L50

I think if you use that, then pretty much everything that would otherwise get logged will also go into tensorboard. Let us know if that does enough!

astooke avatar May 21 '20 17:05 astooke

Nice. So if I wanted to log somewhere else I would just wrap it in this context?

with logger_context(log_dir, run_ID, name, config, snapshot_mode="last"):

drozzy avatar May 22 '20 00:05 drozzy

That's the idea: the logger context starts up the logger (directories and such) and closes it out when the work is done. The runner classes then take care of the actual logging, if you want to look at the next level down. I'm not sure the close-out is even necessary unless you're launching another experiment in the same process.
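
For illustration, here is a minimal sketch of that set-up/close-out pattern (not from the thread's authors). ToyLogger and toy_logger_context are hypothetical stand-ins written for this example; they are not rlpyt's logger or its logger_context, and the directory layout is an assumption.

```python
from contextlib import contextmanager


class ToyLogger:
    """Hypothetical stand-in for a logger; not rlpyt's API."""

    def __init__(self):
        self.log_dir = None  # None means "not open"
        self.rows = []

    def record(self, key, value):
        if self.log_dir is None:
            raise RuntimeError("logger is not open")
        self.rows.append((key, value))


toy_logger = ToyLogger()


@contextmanager
def toy_logger_context(log_dir, run_id):
    """Open the logger for one run and always close it afterwards."""
    toy_logger.log_dir = f"{log_dir}/run_{run_id}"  # set-up step
    try:
        yield toy_logger
    finally:
        toy_logger.log_dir = None  # close-out step, runs even on error


with toy_logger_context("data/example", 0) as log:
    log.record("Return", 1.5)  # in rlpyt the runner classes do this part
```

Resetting the state in the finally clause is what would matter when a second experiment is launched in the same process, since the next run starts from a closed logger.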

astooke avatar May 22 '20 05:05 astooke

Ok, I see that tf summary writer's add_scalar is used here as:

def record_tabular(key, val, *args, **kwargs):
    # if not _disabled and not _tabular_disabled:
    key = _tabular_prefix_str + str(key)
    _tabular.append((key, str(val)))
    if _tf_summary_writer is not None:
        _tf_summary_writer.add_scalar(key, val, _iteration)

So as you said, tensorboard is part of rlpyt now. But if I wanted to use another logger (for example, wandb.log()), would I need to override record_tabular?

It's ok if I can't do it, I'm just asking if there is such a possibility.
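
For illustration, one way such an override could look is a wrapper that calls the original function and then forwards the same scalar to a second sink. This is only a sketch under stated assumptions: toy_logger is a stand-in namespace that mirrors just the names quoted above, forward_scalars is a hypothetical helper, and none of it has been run against rlpyt itself.

```python
import types

# Stand-in for the logger module; only the names quoted above are mirrored.
toy_logger = types.SimpleNamespace(_tabular=[], _iteration=0)


def record_tabular(key, val, *args, **kwargs):
    toy_logger._tabular.append((str(key), str(val)))


toy_logger.record_tabular = record_tabular


def forward_scalars(logger_module, sink):
    """Wrap logger_module.record_tabular so each scalar also goes to sink.

    sink is any callable taking (metrics_dict, step=...).
    """
    original = logger_module.record_tabular

    def wrapped(key, val, *args, **kwargs):
        original(key, val, *args, **kwargs)  # keep the existing behaviour
        sink({str(key): val}, step=logger_module._iteration)

    logger_module.record_tabular = wrapped
    return original  # lets the caller restore the original later


sent = []
forward_scalars(toy_logger, lambda metrics, step: sent.append((metrics, step)))
toy_logger.record_tabular("Return", 2.0)
```

Since wandb.log accepts a dict plus a step keyword, it could plausibly be passed as the sink. Code that imported record_tabular by name before the wrapper was installed would still call the unwrapped function.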

Thanks.

drozzy avatar May 22 '20 12:05 drozzy

To give a more concrete example, here is how one framework (PyTorch Lightning) does it: https://pytorch-lightning.readthedocs.io/en/stable/loggers.html

from pytorch_lightning import Trainer
from pytorch_lightning import loggers
tb_logger = loggers.TensorBoardLogger('logs/')
trainer = Trainer(logger=tb_logger)

# or

comet_logger = loggers.CometLogger(save_dir='logs/')
trainer = Trainer(logger=comet_logger)

Here is how they allow you to implement a custom logger:

from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.loggers import LightningLoggerBase
class MyLogger(LightningLoggerBase):

    @rank_zero_only
    def log_hyperparams(self, params):
        # params is an argparse.Namespace
        # your code to record hyperparameters goes here
        pass

    @rank_zero_only
    def log_metrics(self, metrics, step):
        # metrics is a dictionary of metric names and values
        # your code to record metrics goes here
        pass

    def save(self):
        # Optional. Any code necessary to save logger data goes here
        pass

    @rank_zero_only
    def finalize(self, status):
        # Optional. Any code that needs to be run after training
        # finishes goes here
        pass

drozzy avatar May 22 '20 12:05 drozzy

I haven't used wandb, but I think some other people on the issues board have used wandb with rlpyt, so hopefully they can chime in?

The logger is the one piece that I really don't have much insight into--it's a direct copy from the old rllab. Hopefully it's straightforward to find the functionality and edit, but I'm afraid I can't be of much help on this one :/

astooke avatar May 22 '20 17:05 astooke

This is what we use for logging to wandb, basically creating a new class that inherits from existing runners:

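# Imports this snippet relies on (module paths as in the rlpyt repository).
import numpy as np
import wandb
from rlpyt.runners.minibatch_rl import MinibatchRlEval
from rlpyt.utils.logging import logger

class MinibatchRlEvalWandb(MinibatchRlEval):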
    def log_diagnostics(self, itr, eval_traj_infos, eval_time):
        cum_steps = (itr + 1) * self.sampler.batch_size * self.world_size
        self.wandb_info = {'cum_steps': cum_steps}
        super().log_diagnostics(itr, eval_traj_infos, eval_time)
        wandb.log(self.wandb_info)

    def _log_infos(self, traj_infos=None):
        if traj_infos is None:
            traj_infos = self._traj_infos
        if traj_infos:
            for k in traj_infos[0]:
                if not k.startswith("_"):
                    values = [info[k] for info in traj_infos]
                    logger.record_tabular_misc_stat(k,
                                                    values)
                    self.wandb_info[k + "Average"] = np.average(values)
                    self.wandb_info[k + "Median"] = np.median(values)

        if self._opt_infos:
            for k, v in self._opt_infos.items():
                logger.record_tabular_misc_stat(k, v)
                self.wandb_info[k] = np.average(v)
        self._opt_infos = {k: list() for k in self._opt_infos}  # (reset)
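
For illustration, the payload that ends up in wandb_info can be reproduced without rlpyt or wandb installed. The sketch below is not from the thread's authors: cumulative_steps and summarize_traj_infos are hypothetical helper names, the sample trajectory infos are made up, and the standard-library statistics module stands in for the np.average and np.median calls above.

```python
from statistics import mean, median


def cumulative_steps(itr, batch_size, world_size=1):
    """Same formula as cum_steps in log_diagnostics above."""
    return (itr + 1) * batch_size * world_size


def summarize_traj_infos(traj_infos):
    """Average and median of each non-underscore key across trajectory infos."""
    summary = {}
    if not traj_infos:
        return summary
    for k in traj_infos[0]:
        if k.startswith("_"):
            continue  # keys starting with "_" are skipped, as in _log_infos
        values = [info[k] for info in traj_infos]
        summary[k + "Average"] = mean(values)
        summary[k + "Median"] = median(values)
    return summary


# Made-up trajectory infos, for illustration only.
infos = [
    {"Return": 1.0, "_hidden": 9},
    {"Return": 3.0, "_hidden": 9},
    {"Return": 8.0, "_hidden": 9},
]
payload = {"cum_steps": cumulative_steps(itr=4, batch_size=256, world_size=2)}
payload.update(summarize_traj_infos(infos))
```

A dict of this shape is what the runner subclass above hands to wandb.log once per logging iteration.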

ankeshanand avatar Jun 02 '20 06:06 ankeshanand

Thanks. Good to have this as a reference!

P.S.: You guys should merge that into rlpyt.

drozzy avatar Jun 02 '20 08:06 drozzy