jaxline

How to access InMemory checkpoints and save to disk

Open Alikerin opened this issue 3 years ago • 2 comments

Could you please advise on how to access the in-memory checkpoints created by the InMemoryCheckpointer, so that they can be saved to disk?

Thank you.

Alikerin avatar Mar 12 '21 08:03 Alikerin

+1. For a framework built for training and evaluation, the lack of an out-of-the-box saving mechanism is curious to me.

I suspect the lack of a disk-based model checkpointer may be related to how JAX and Haiku are "intentionally un-opinionated" about reading and writing models. So perhaps this is off-topic. But writing models to disk seems like a crucial part of the experimental cycle, and for newcomers to this project, the lack of support and discussion is a bit of a head-scratcher.

almostimplemented avatar Mar 29 '22 12:03 almostimplemented

And to make myself slightly useful and not purely complain:

@Alikerin, the DeepMind research GitHub has an example of writing a model checkpoint to disk and restoring it. See:

https://github.com/deepmind/deepmind-research/blob/master/adversarial_robustness/jax/experiment.py
https://github.com/deepmind/deepmind-research/blob/master/adversarial_robustness/jax/eval.py
https://github.com/deepmind/deepmind-research/blob/master/adversarial_robustness/jax/model_zoo.py

Essentially, they use model_zoo.py to define their model as a haiku.Module, run the training experiment, grab the parameters via jax.tree_map, and write to disk via np.save(fp, (np_params, np_state)). Then for inference, they load the objects via np.load and couple those with model inference via @hk.transform_with_state.
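
To make the save/load half of that concrete, here is a rough sketch of the np.save / np.load round trip. It is not the code from the linked repository, and it does not show how to pull the state out of jaxline's InMemoryCheckpointer. The helper names (to_numpy, save_checkpoint, load_checkpoint) and the toy params/state dicts are made up for illustration; the nested dicts only stand in for Haiku's params and state. In a real JAX run you would probably convert the leaves with something like jax.tree_util.tree_map(np.asarray, params) rather than the hand-written to_numpy below.

```python
import os
import tempfile

import numpy as np


def to_numpy(tree):
    """Recursively convert every leaf of a nested dict to a NumPy array.

    Hypothetical stand-in for mapping np.asarray over a JAX pytree.
    """
    if isinstance(tree, dict):
        return {name: to_numpy(child) for name, child in tree.items()}
    return np.asarray(tree)


def save_checkpoint(path, params, state):
    """Write (params, state) to disk with np.save."""
    # A tuple of dicts is stored as a pickled object array of length 2.
    with open(path, "wb") as fp:
        np.save(fp, (to_numpy(params), to_numpy(state)), allow_pickle=True)


def load_checkpoint(path):
    """Read (params, state) back from disk with np.load."""
    # allow_pickle=True is required for object arrays; only load files you trust.
    with open(path, "rb") as fp:
        params, state = np.load(fp, allow_pickle=True)
    return params, state


# Toy values shaped like Haiku params/state (assumed, for illustration only).
params = {
    "linear": {
        "w": np.ones((2, 3), dtype=np.float32),
        "b": np.zeros(3, dtype=np.float32),
    }
}
state = {"batch_norm": {"mean": np.zeros(3, dtype=np.float32)}}

checkpoint_path = os.path.join(tempfile.mkdtemp(), "checkpoint.npy")
save_checkpoint(checkpoint_path, params, state)
restored_params, restored_state = load_checkpoint(checkpoint_path)
```

The restored dicts can then be passed to the apply function of a transformed Haiku model, as in the eval.py linked above. Because np.load with allow_pickle=True unpickles arbitrary objects, this is only reasonable for checkpoints you wrote yourself.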

almostimplemented avatar Mar 29 '22 12:03 almostimplemented