Checkpointer.save, can global_step become a kwarg?
Hi, I'm not an expert in TF-Agents and only started learning the library recently, so I may simply be overlooking some implementation detail.
I noticed that the save() method of the utils.common.Checkpointer class asks for global_step as a positional argument. The global_step argument is then passed to the save function of the _manager object.
The piece of code I'm referring to is:
    def save(self, global_step: tf.Tensor,
             options: tf.train.CheckpointOptions = None):
      """Save state to checkpoint."""
      saved_checkpoint = self._manager.save(
          checkpoint_number=global_step, options=options)
      self._checkpoint_exists = True
      logging.info('%s', 'Saved checkpoint: {}'.format(saved_checkpoint))
Given that CheckpointManager.save() also accepts None for the checkpoint_number kwarg, wouldn't it be more correct to implement the save function as follows?
    def save(self, global_step: Optional[tf.Tensor] = None,
             options: tf.train.CheckpointOptions = None):
      """Save state to checkpoint."""
      saved_checkpoint = self._manager.save(
          checkpoint_number=global_step, options=options)
      self._checkpoint_exists = True
      logging.info('%s', 'Saved checkpoint: {}'.format(saved_checkpoint))
With this change, the user could also rely directly on the checkpoint.save_counter that CheckpointManager uses by default to number checkpoints.
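To make the idea concrete, here is a minimal sketch in plain Python. FakeCheckpointManager and SketchCheckpointer are hypothetical stand-ins I wrote for illustration; they are not the TensorFlow or TF-Agents classes. The stand-in only imitates the fallback to an internal save counter when checkpoint_number is None, and its counter handling is an assumption of this sketch.

```python
from typing import Optional


class FakeCheckpointManager:
  """Hypothetical stand-in for tf.train.CheckpointManager (not the real class)."""

  def __init__(self, prefix: str = 'ckpt'):
    self._prefix = prefix
    # Imitates checkpoint.save_counter; assumption of this sketch.
    self.save_counter = 0

  def save(self, checkpoint_number: Optional[int] = None) -> str:
    # In this sketch the counter advances on every save.
    self.save_counter += 1
    if checkpoint_number is None:
      # No explicit number given: fall back to the internal counter.
      checkpoint_number = self.save_counter
    return '{}-{}'.format(self._prefix, checkpoint_number)


class SketchCheckpointer:
  """Hypothetical wrapper mirroring the proposed save() signature."""

  def __init__(self):
    self._manager = FakeCheckpointManager()
    self._checkpoint_exists = False

  def save(self, global_step: Optional[int] = None) -> str:
    # global_step is optional; None is forwarded unchanged to the manager.
    saved_checkpoint = self._manager.save(checkpoint_number=global_step)
    self._checkpoint_exists = True
    return saved_checkpoint


checkpointer = SketchCheckpointer()
first = checkpointer.save()                  # numbered by the internal counter
second = checkpointer.save(global_step=500)  # numbered by the explicit step
```

Callers that already pass global_step would behave as before, while callers that omit it would get counter-based numbering.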
Am I missing a reason why global_step has to be specified, rather than letting the user fall back on the default CheckpointManager counter?
If there is positive feedback on this change, I can also submit a pull request :) Thanks in advance for your help!