
Checkpointer.save, can global_step become a kwarg?

Open · tiamilani opened this issue 2 years ago · 0 comments

Hi, I'm not an expert in TF-Agents; I've only started learning the library recently, so I may be overlooking some implementation detail.

I noticed that the save() method of the utils.common.Checkpointer class requires global_step as a positional argument. The global_step argument is then passed on to the save() function of the _manager object. The piece of code I'm referring to is:

  def save(self, global_step: tf.Tensor,
           options: tf.train.CheckpointOptions = None):
    """Save state to checkpoint."""
    saved_checkpoint = self._manager.save(
        checkpoint_number=global_step, options=options)
    self._checkpoint_exists = True
    logging.info('%s', 'Saved checkpoint: {}'.format(saved_checkpoint))
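For context, here is roughly how I call it today, as a minimal sketch (the directory, max_to_keep value, and the step variable are just illustrative from my own setup; extra kwargs to Checkpointer are forwarded to tf.train.Checkpoint as trackables):

  import tensorflow as tf
  from tf_agents.utils import common

  # Hypothetical step counter maintained by the caller and tracked in the checkpoint.
  global_step = tf.Variable(0, dtype=tf.int64, name='global_step')
  checkpointer = common.Checkpointer(
      ckpt_dir='/tmp/my_agent_ckpt',  # illustrative path
      max_to_keep=3,
      step=global_step)               # trackables passed as kwargs

  # With the current signature, a step value must always be supplied explicitly.
  checkpointer.save(global_step)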

Given that CheckpointManager.save() also accepts None for the checkpoint_number kwarg, wouldn't it be more correct to implement the save function as follows?

  def save(self, global_step: Optional[tf.Tensor] = None,
           options: tf.train.CheckpointOptions = None):
    """Save state to checkpoint."""
    saved_checkpoint = self._manager.save(
        checkpoint_number=global_step, options=options)
    self._checkpoint_exists = True
    logging.info('%s', 'Saved checkpoint: {}'.format(saved_checkpoint))

With this change the user would also be able to rely directly on the checkpoint.save_counter that CheckpointManager falls back to when no checkpoint_number is given.
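As far as I understand, with the proposed default both of the following calls would work (continuing the hypothetical example above):

  # global_step defaults to None, so CheckpointManager numbers the checkpoint
  # from the underlying checkpoint's save_counter.
  checkpointer.save()

  # Explicit numbering would still work exactly as before.
  checkpointer.save(global_step)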

Am I missing a reason why specifying a global_step is mandatory, instead of allowing the default CheckpointManager counter to be used?

If there is positive feedback on this change I can also submit a pull request :) Thanks in advance for your help!

tiamilani · Jan 19 '23 14:01