Improving snapshotting in JAX distributed experiments
Hi,
I have started adopting acme.jax.experiments.make_distributed_experiment for training distributed agents with Launchpad. Right now, there seems to be no way to configure snapshotting, which is a useful utility, and I was wondering whether the Acme developers would be interested in improving this part. I have a few proposals and am happy to make a PR with some improvements. Here they are:
Support snapshotting configuration.
Right now it's not possible to configure anything about snapshotting. It would be great to include options to specify:
- where to store snapshots. Currently they are written directly into the workdir without any subdirectory structure, so you get workdir/20220601-00000/. At the bare minimum, the default should put them in a subdirectory, so that we have workdir/snapshots/20220601-00000/.
- how frequently to snapshot. For long training jobs (I work with pixels, where training on humanoid control may take a day), the current default period of 5 minutes leaves me with hundreds of snapshots that I don't really need.
To ensure backward compatibility, this should be opt-in: if the user does not specify a snapshotting configuration, we fall back to the current behavior.
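A minimal sketch of what such an opt-in configuration could look like. `SnapshotConfig`, `snapshot_dir` and `snapshot_period_seconds` are hypothetical names, not part of Acme; passing `None` reproduces today's behavior (snapshots directly in the workdir every 5 minutes). For scale, a one-day run at the current 5-minute period produces 24 * 60 / 5 = 288 snapshots, while a 60-minute period would produce 24.

```python
import dataclasses
import os
from typing import Optional


@dataclasses.dataclass(frozen=True)
class SnapshotConfig:
    """Hypothetical opt-in snapshot options (not part of Acme)."""

    # Subdirectory of the experiment workdir that holds the snapshots.
    subdirectory: str = "snapshots"
    # Minutes between consecutive snapshots.
    period_minutes: float = 5.0


def snapshot_dir(workdir: str, config: Optional[SnapshotConfig] = None) -> str:
    """Directory for snapshots; config=None keeps the current layout (workdir root)."""
    if config is None:
        return workdir
    return os.path.join(workdir, config.subdirectory)


def snapshot_period_seconds(config: Optional[SnapshotConfig] = None) -> float:
    """Seconds between snapshots; config=None keeps the current 5-minute period."""
    minutes = 5.0 if config is None else config.period_minutes
    return minutes * 60.0
```

A launcher would then use `snapshot_dir(workdir, config)` as the snapshotter's output directory and `snapshot_period_seconds(config)` as its save interval; how these values are threaded through `make_distributed_experiment` is left open here.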
Improve the snapshotting format
Right now, the JAX ModelToSnapshot embeds the parameters as fixed constants in the tf.SavedModel. This is less general, especially if users want to fine-tune the models afterwards. Maybe we can save the parameters as variables in the SavedModel instead? I saw that the TF snapshotter already does this, and we could adopt a similar approach, following the way it is done in https://github.com/google/jax/tree/main/jax/experimental/jax2tf
Add utilities for inference with snapshots.
I am not sure how the snapshots are used internally, perhaps with internal libraries that allow easy loading of saved policies for evaluation/rollouts, but it would be great to have at least a bare-minimum way to do this in the OSS version. To me, snapshots are an excellent way to store models for later analysis, evaluation and data generation. For example, I use the saved model to generate expert trajectories, which I then use for imitation learning. It would be great to have some helper functions that make this easy, or at least some examples of how Acme users can take advantage of the saved snapshots.
Best, Yicheng
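The kind of helper asked for above could start as a rollout loop that accepts any policy callable. A minimal sketch, assuming a dm_env-style environment (`reset()`/`step()` returning timesteps with `.observation`, `.reward` and `.last()`) and a policy that maps one observation to one action; `collect_trajectories` is a hypothetical name. Loading the policy from a snapshot (for example via `tf.saved_model.load(path)`) is not shown, since which callable an Acme snapshot exports is not covered in this thread.

```python
from typing import Any, Callable, Dict, List


def collect_trajectories(
    env: Any,
    policy: Callable[[Any], Any],
    num_episodes: int,
) -> List[Dict[str, list]]:
    """Roll out `policy` in a dm_env-style `env` and record every episode.

    Each episode holds T + 1 observations and T actions/rewards, where T is
    the number of environment steps taken.
    """
    episodes = []
    for _ in range(num_episodes):
        timestep = env.reset()
        episode = {"observations": [timestep.observation], "actions": [], "rewards": []}
        while not timestep.last():
            action = policy(timestep.observation)
            timestep = env.step(action)
            episode["actions"].append(action)
            episode["observations"].append(timestep.observation)
            episode["rewards"].append(timestep.reward)
        episodes.append(episode)
    return episodes
```

Keeping the policy as a plain callable means the same loop works for a restored snapshot, a live actor, or a scripted baseline; converting the lists to arrays and choosing a storage format for imitation-learning data are left to the caller.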
Any idea how to even load these TF-saved parameters back into a JAX agent (e.g., from the DQN example)? I'm running into this issue, and so far it's pretty confusing, as the jax.saver API doesn't work out of the box here.
@rdevon I agree that having some examples would be nice. Fortunately, it's not that difficult to do once you get the hang of it.
The code for loading back the learner's state looks something like the following:
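import jax
import tensorflow as tf
import acme.tf.savers
from acme.testing import fakes

# ... (recreate `environment`, `networks` and `drq_builder` as in the training script)
dataset = fakes.transition_dataset(environment).batch(32).as_numpy_iterator()
learner = drq_builder.make_learner(jax.random.PRNGKey(0), networks, dataset)
print(learner._state.steps)  # step counter of the freshly initialised learner

# SaveableAdapter lets tf.train.Checkpoint restore the learner's JAX state.
ckpt = tf.train.Checkpoint(learner=acme.tf.savers.SaveableAdapter(learner))
ckpt_path = '/tmp/ilax/checkpoints/learner'
mgr = tf.train.CheckpointManager(ckpt, ckpt_path, 1)
ckpt.restore(mgr.latest_checkpoint).assert_consumed()
print(learner._state.steps)  # now the step count stored in the checkpoint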
Essentially, the checkpoints are saved as standard TF checkpoints, so loading the saved state back works the same way as in TensorFlow: first recreate the agent with exactly the same configuration, then restore the checkpoint. The only additional step is to wrap the learner with acme.tf.savers.SaveableAdapter, which handles restoring the state correctly.
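To see why the learner has to be recreated before restoring, here is a self-contained toy version of the same save/restore cycle that does not use Acme. `PickledSaveable` is a hypothetical stand-in modelled on what `acme.tf.savers.SaveableAdapter` appears to do (pickling whatever `save()` returns into a `tf.train.experimental.PythonState`), and `ToyLearner` is a hypothetical class mimicking a learner's `save()`/`restore()` pair. The checkpoint stores only the state; the object receiving it must already exist.

```python
import pickle
import tempfile

import numpy as np
import tensorflow as tf


class ToyLearner:
    """Hypothetical learner exposing Acme-style save()/restore()."""

    def __init__(self, seed: int):
        rng = np.random.default_rng(seed)
        self.state = {"steps": 0, "w": rng.normal(size=(3,))}

    def save(self):
        return self.state

    def restore(self, state):
        self.state = state


class PickledSaveable(tf.train.experimental.PythonState):
    """Stand-in adapter: stores the wrapped object's state as a pickled string."""

    def __init__(self, obj):
        self._obj = obj

    def serialize(self):
        return pickle.dumps(self._obj.save())

    def deserialize(self, string_value):
        self._obj.restore(pickle.loads(string_value))


# Training side: advance the learner and write a checkpoint.
ckpt_dir = tempfile.mkdtemp()
trained = ToyLearner(seed=0)
trained.state["steps"] = 1000
ckpt = tf.train.Checkpoint(learner=PickledSaveable(trained))
mgr = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=1)
mgr.save()

# Evaluation side: rebuild an equivalent learner, then restore into it.
fresh = ToyLearner(seed=1)
ckpt = tf.train.Checkpoint(learner=PickledSaveable(fresh))
mgr = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=1)
ckpt.restore(mgr.latest_checkpoint).assert_consumed()
```

After the restore, `fresh` holds the trained state even though it was initialised with a different seed, which mirrors restoring into a learner rebuilt with `make_learner`.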
Thanks so much! I was very close: after much digging through the API I had discovered that adapter and wrapped the learner with it, but hadn't quite gotten the restore working. I'll try your solution and get back soon.
Thanks a bunch; there are some useful magic functions in your example. Here is a working example for the DQN demo:
import jax
import tensorflow as tf
import acme.tf.savers
from acme.testing import fakes
from acme import specs
from acme.agents.jax import dqn
from acme.agents.jax.dqn import losses
# `helpers` is the helpers.py module from Acme's DQN example directory.
from helpers import make_atari_environment, make_dqn_atari_network

env = make_atari_environment()
env_spec = specs.make_environment_spec(env)

# Recreate the agent with the same configuration that was used for training.
config = dqn.DQNConfig(
    discount=0.99,
    learning_rate=5e-5,
    n_step=1,
    epsilon=0.01,
    target_update_period=2000,
    min_replay_size=20_000,
    max_replay_size=10_000_000,
    samples_per_insert=8,
    batch_size=32)
loss_fn = losses.QLearning(
    discount=config.discount, max_abs_reward=1.)
dqn_builder = dqn.DQNBuilder(config, loss_fn=loss_fn)
network = make_dqn_atari_network(env_spec)
policy = dqn.behavior_policy(network)

# A fake dataset is enough here: it is only used to construct the learner.
dataset = fakes.transition_dataset(env).batch(32).as_numpy_iterator()
learner = dqn_builder.make_learner(jax.random.PRNGKey(0), network, dataset, None, env_spec)
s0 = learner._state.params.copy()  # randomly initialised parameters

# Wrap the learner and restore the latest checkpoint into it.
ckpt = tf.train.Checkpoint(learner=acme.tf.savers.SaveableAdapter(learner))
ckpt_path = '/checkpoints/20220531-124850/checkpoints/learner/'
mgr = tf.train.CheckpointManager(ckpt, ckpt_path, 1)
ckpt.restore(mgr.latest_checkpoint).assert_consumed()
s1 = learner._state.params.copy()  # parameters after restoring

# Every parameter array should now differ from its initial value somewhere.
for k, module_params in s0.items():
    for k_, old in module_params.items():
        assert (old != s1[k][k_]).any(), f'New parameters are the same as old {k}.{k_}'
        print(f'{k}.{k_} parameters successfully updated!')
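The nested loop above assumes exactly two levels of nesting (module name, then parameter name), which matches Haiku's parameter layout. A minimal alternative sketch that walks arbitrarily nested mappings and compares arrays element-wise; `changed_leaves` is a hypothetical helper written with plain NumPy and recursion, not an Acme or JAX utility.

```python
from collections.abc import Mapping
from typing import List

import numpy as np


def changed_leaves(before: Mapping, after: Mapping, prefix: str = "") -> List[str]:
    """Dotted names of array leaves whose values differ between two nested mappings."""
    changed = []
    for name, old in before.items():
        path = f"{prefix}.{name}" if prefix else str(name)
        new = after[name]
        if isinstance(old, Mapping):
            # Recurse into sub-modules of any depth.
            changed.extend(changed_leaves(old, new, path))
        elif not np.array_equal(np.asarray(old), np.asarray(new)):
            # At least one element differs.
            changed.append(path)
    return changed
```

With the variables from the example, `changed_leaves(s0, s1)` lists every restored parameter, and an empty result means the checkpoint did not change anything.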
I think @ethanluoyc's proposals are great! Has there been any interest from the Acme team in going in that direction?
I'm wondering whether Acme is a good library for my projects, but yesterday I had to spend too much time hacking my way into saving parameters, so it doesn't seem very user-friendly. The suggested changes would certainly make things easier.
If you want to restore the checkpointed state, here's a function that may be useful.
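from typing import Any, Optional

import tensorflow as tf
from acme import core
from acme.tf import savers


class _MockSaveable(core.Saveable):
    """Minimal Saveable that simply holds whatever state is restored into it."""

    def __init__(self, state: Optional[Any] = None):
        self._state = state

    def save(self) -> Any:
        return self._state

    def restore(self, state: Any):
        self._state = state


def restore_state_from_checkpoint(checkpoint_dir: str, key: str = 'learner') -> Any:
    """Restore the raw state saved under `key` in the latest checkpoint in `checkpoint_dir`."""
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    if latest_checkpoint is None:
        raise FileNotFoundError(f'No checkpoint found in {checkpoint_dir}')
    saveable = _MockSaveable()
    # The adapter hands the checkpointed state to saveable.restore().
    checkpointable_wrapped = savers.SaveableAdapter(saveable)
    checkpoint = tf.train.Checkpoint(**{key: checkpointable_wrapped})
    checkpoint.restore(latest_checkpoint)
    return saveable.save()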
Adapted from: https://github.com/deepmind/alphastar/blob/main/alphastar/modules/common.py
This lets you restore the state saved during run_experiment without rebuilding the learner.
@alonfnt If you are looking for a library that you can use out of the box, I would say that some parts do not work very well. I personally use Acme for my projects, but I have had to fork some parts of the code to fit my workflow. Some of my code is at https://github.com/ethanluoyc/magi/tree/develop, which you might find useful.
Thanks @ethanluoyc for the response. I wasn't able to get it to work, so I switched to stable-baselines3, which provides a more user-friendly experience (saving and loading are handled by model.save() and model.load(), eliminating the workaround I was attempting with Acme). magi looks promising, though; I'll give it a shot when I'm not as frustrated with Acme as I am now :)