JAX checkpointing via Orbax

Open tttc3 opened this issue 11 months ago • 4 comments

Split from #1490. Provides very basic checkpointing for Flax-based JAX models via Orbax. It essentially follows the simple example from the Flax documentation, but without the context manager.

As noted in #1490, Orbax is currently imported lazily (on first use rather than at module import time) to avoid making Orbax a required package dependency. This is not PEP 8 compliant. I feel there should be a better way to do this, but I can't currently think of one, so I've marked this as a draft for now.
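For readers unfamiliar with the pattern, here is a minimal sketch of a lazy import for an optional dependency. It is not the code from this PR: the helper name `lazy_import` and its `install_hint` parameter are hypothetical, and the example only assumes that the checkpointing module is importable as `orbax.checkpoint` once `orbax-checkpoint` is installed.

```python
import importlib
from functools import lru_cache


@lru_cache(maxsize=None)
def lazy_import(module_name: str, install_hint: str = ""):
    """Import `module_name` on first call and cache the module object.

    Hypothetical helper for illustration; not part of DeepXDE or Orbax.
    """
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        # Re-raise with a hint so users know which distribution to install.
        message = f"Optional dependency '{module_name}' is not installed."
        if install_hint:
            message += f" Try: pip install {install_hint}"
        raise ImportError(message) from exc


# Demonstration with a standard-library module, so the sketch runs anywhere.
json_mod = lazy_import("json")
encoded = json_mod.dumps({"step": 1})

# A checkpointing function would do the same thing at call time, e.g.:
#     ocp = lazy_import("orbax.checkpoint", install_hint="orbax-checkpoint")
```

The trade-off is the one raised above: the import no longer sits at the top of the module as PEP 8 recommends, but users who never save a checkpoint do not need the extra package.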

Any thoughts or suggestions are greatly appreciated!

tttc3 avatar Sep 25 '23 07:09 tttc3

According to https://orbax.readthedocs.io/en/latest/, Orbax is not a single package; instead, the Orbax namespace provides a separate package for each piece of functionality. For checkpointing, we need orbax-checkpoint.

lululxvi avatar Sep 25 '23 17:09 lululxvi

I've updated the import to only require orbax.checkpoint.

tttc3 avatar Sep 25 '23 17:09 tttc3

As checkpointing is a very basic tool, I think it is OK to add it as a required package dependency.

lululxvi avatar Sep 25 '23 18:09 lululxvi

Added orbax-checkpoint as a top-level import and a package-wide dependency. flax.training is still imported below the top level, though.

tttc3 avatar Sep 25 '23 19:09 tttc3