jaxline
How to access InMemory checkpoints and save to disk
Please advise on how to access the in-memory checkpoints created by the InMemoryCheckpointer so that they can be saved to disk.
Thank you.
+1. As a framework for training and evaluation, the lack of an out-of-the-box saving mechanism is curious to me.
I suspect the lack of a disk-based model checkpointer is related to how JAX / Haiku are "intentionally un-opinionated" about reading and writing models, so perhaps this is off-topic. But writing models to disk seems like a crucial part of the experimental cycle, and for newcomers to this project the lack of support or discussion is a bit of a head-scratcher.
And to make myself slightly useful and not purely complain:
@Alikerin, the DeepMind research GitHub has an example of writing a model checkpoint to disk and restoring it. See:
https://github.com/deepmind/deepmind-research/blob/master/adversarial_robustness/jax/experiment.py
https://github.com/deepmind/deepmind-research/blob/master/adversarial_robustness/jax/eval.py
https://github.com/deepmind/deepmind-research/blob/master/adversarial_robustness/jax/model_zoo.py
Essentially, they use model_zoo.py to define their model as a haiku.Module, run the training experiment, grab the parameters via jax.tree_map, and write to disk via np.save(fp, (np_params, np_state)). Then for inference, they load the objects via np.load and couple those with model inference via @hk.transform_with_state.
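For anyone who wants the gist without reading those files, here is a minimal sketch of the save/restore round trip described above. It is not the code from the linked repository: the helper names save_checkpoint and load_checkpoint, the toy params and state dictionaries, and the temporary file path are all made up for illustration. It uses jax.tree_util.tree_map, which should behave the same as jax.tree_map for this purpose, and it packs the two pytrees into an explicit object array before calling np.save rather than passing a tuple.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import numpy as np


def save_checkpoint(path, params, state):
    """Copy both pytrees to host NumPy arrays and pickle them into one .npy file."""
    np_params = jax.tree_util.tree_map(np.asarray, params)
    np_state = jax.tree_util.tree_map(np.asarray, state)
    # An explicit object array keeps the two pytrees as two opaque entries.
    ckpt = np.empty(2, dtype=object)
    ckpt[0], ckpt[1] = np_params, np_state
    np.save(path, ckpt, allow_pickle=True)


def load_checkpoint(path):
    """Return (params, state) as nested dicts of NumPy arrays."""
    # allow_pickle=True is required for object arrays; only load trusted files.
    ckpt = np.load(path, allow_pickle=True)
    return ckpt[0], ckpt[1]


# Toy stand-ins for the params/state pytrees a Haiku model would produce.
params = {"linear": {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))}}
state = {"batch_norm": {"mean": jnp.full((2,), 0.5)}}

ckpt_path = os.path.join(tempfile.mkdtemp(), "checkpoint.npy")
save_checkpoint(ckpt_path, params, state)
restored_params, restored_state = load_checkpoint(ckpt_path)
```

The restored pytrees can then be passed as the params and state arguments to the apply function returned by hk.transform_with_state. If the parameters were replicated across devices during training, you would probably want to keep only one copy per leaf before saving; that step is left out here.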