Loading checkpoint does not work
The add_uid variable is not passed to the agent and learner, and defaults to True in the Snapshotter and Checkpointer.
Therefore, when passing a checkpoint path to the agent, say:
checkpoint_subpath = "<cwd>/log/fed68b00-b8cb-11eb-812c-bba6f8f69ddd"
a new folder with a new identifier is created inside it, and the previous checkpoints are not loaded.
To reproduce, run test_dqn.py twice; the second time, set checkpoint_subpath to the previous log path and pass it to the agent.
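To make the failure mode concrete, here is a small sketch of the two behaviours. It assumes the Checkpointer signature in acme.tf.savers (objects_to_save / directory / subdirectory / add_uid); the variable being saved is just a placeholder, and the paths are the ones from the example above:

import tensorflow as tf
from acme.tf import savers as tf2_savers

# Hypothetical object to checkpoint, purely for illustration.
counter = tf.Variable(0)

# With add_uid=True (the default), the Checkpointer appends a fresh UUID
# folder under `directory`, so a second run never sees the first run's files.
ckpt_with_uid = tf2_savers.Checkpointer(
    objects_to_save={'counter': counter},
    directory='<cwd>/log',
    subdirectory='dqn_learner',
    add_uid=True)

# With add_uid=False it writes to, and restores from, the same path each run.
ckpt_fixed = tf2_savers.Checkpointer(
    objects_to_save={'counter': counter},
    directory='<cwd>/log/fed68b00-b8cb-11eb-812c-bba6f8f69ddd',
    subdirectory='dqn_learner',
    add_uid=False)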
Not a real solution, but the (hacky) workaround I'm using to restore checkpoints is this:
agent._learner._checkpointer._checkpoint.restore(ckpt)
where agent is an instance of, say, acme.agents.tf.dmpo.DistributionalMPO, and ckpt looks like:
ckpt = '~/acme/9ae2305e-eab2-11eb-8007-b8ca3a99639d/checkpoints/dmpo_learner/ckpt-25'
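In case it helps, a slightly more self-contained version of that workaround. tf.train.latest_checkpoint is standard TensorFlow and returns the newest ckpt-N prefix in a directory; the private-attribute chain is the same as above, so no guarantees that it survives acme version changes:

import os
import tensorflow as tf

def restore_learner(agent, ckpt_dir):
    # Find the newest 'ckpt-N' prefix in the learner's checkpoint directory.
    ckpt = tf.train.latest_checkpoint(os.path.expanduser(ckpt_dir))
    # Reaches into private attributes, so this may break across acme versions.
    return agent._learner._checkpointer._checkpoint.restore(ckpt)

# Example usage (the path is a placeholder for your own run's UID folder):
# restore_learner(agent, '~/acme/<uid>/checkpoints/dmpo_learner')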
Following up on this. I'm trying to do this for an R2D2-like agent but it doesn't seem to work. Maybe it's because I have a custom learner?
As a note, I'm trying to load parameters from a distributed agent into a local agent for analysis.
I have the following files in my directory
checkpoint
ckpt-36.data-00000-of-00002
ckpt-36.data-00001-of-00002
ckpt-36.index
ckpt-37.data-00000-of-00002
ckpt-37.data-00001-of-00002
ckpt-37.index
and I'm trying to load ckpt-37.index directly as:
ckpt_path = 'ckpt-37.index'
agent._checkpointer._checkpoint.restore(ckpt_path)
I get the following print-out:
WARNING:tensorflow:From /home/wcarvalh/miniconda3/envs/acmejax/lib/python3.9/site-packages/tensorflow/python/training/tracking/util.py:1365: NameBasedSaverStatus.__init__ (from tensorflow.python.training.tracking.util) is deprecated and will be removed in a future version.
Instructions for updating:
Restoring a name-based tf.train.Saver checkpoint using the object-based restore API. This mode uses global names to match variables, and so is somewhat fragile. It also adds new restore ops to the graph each time it is called when graph building. Prefer re-encoding training checkpoints in the object-based format: run save() on the object-based saver (the same one this message is coming from) and use that checkpoint in the future.
W0503 13:44:47.846328 140400815281984 deprecation.py:337] From /home/wcarvalh/miniconda3/envs/acmejax/lib/python3.9/site-packages/tensorflow/python/training/tracking/util.py:1365: NameBasedSaverStatus.__init__ (from tensorflow.python.training.tracking.util) is deprecated and will be removed in a future version.
Instructions for updating:
Restoring a name-based tf.train.Saver checkpoint using the object-based restore API. This mode uses global names to match variables, and so is somewhat fragile. It also adds new restore ops to the graph each time it is called when graph building. Prefer re-encoding training checkpoints in the object-based format: run save() on the object-based saver (the same one this message is coming from) and use that checkpoint in the future.
I might be wrong, but can you try ckpt_path = "ckpt-37" instead (i.e., drop the .index part)?
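TensorFlow's object-based restore expects the ckpt-N prefix rather than an individual shard file, so something along these lines (reusing agent from your message; tf.train.latest_checkpoint is standard TensorFlow, and the directory is just an example) should give the right path:

import tensorflow as tf

# Directory containing the `checkpoint` file and the ckpt-* shards listed above.
ckpt_dir = '.'  # example path
ckpt_path = tf.train.latest_checkpoint(ckpt_dir)  # e.g. './ckpt-37'
agent._checkpointer._checkpoint.restore(ckpt_path)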
Thank you for the suggestion! I tried this and it didn't work :(
Any news on how to properly load?
Worked for me by just turning off add_uid, e.g. with a small change to the run_r2d2.py example:
return experiments.ExperimentConfig(
    builder=r2d2.R2D2Builder(config),
    network_factory=r2d2.make_atari_networks,
    environment_factory=environment_factory,
    seed=FLAGS.seed,
    checkpointing=experiments.CheckpointingConfig(add_uid=False),  # turn off the uid
    max_num_actor_steps=FLAGS.num_steps)
New checkpoints are then saved to the checkpoints subpath directly, and if a checkpoint is already there it is restored automatically prior to training. I'm guessing this is more or less the way it was intended to be done.
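If you also want the checkpoints in a fixed, known location rather than the default under ~/acme, a sketch along these lines should work. It assumes the CheckpointingConfig dataclass also exposes a directory field (check acme/jax/experiments/config.py in your checkout), and the path is just an example:

from acme.jax import experiments

# Assumption: CheckpointingConfig exposes `directory` and `add_uid` fields.
checkpointing = experiments.CheckpointingConfig(
    directory='/tmp/acme_r2d2',  # example path; any persistent location works
    add_uid=False,               # reuse the same folder on every run
)

Pass this checkpointing object to ExperimentConfig as in the snippet above, and every run will save to and restore from the same folder.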