deepxde
JAX checkpointing via Orbax
Split from #1490. Provides very basic checkpointing functionality for Flax-based JAX models via Orbax. It essentially follows the simple example from the Flax documentation, without using the context manager.
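A minimal sketch of the save/restore pattern from the Flax documentation's single-save example. It assumes the `orbax.checkpoint.PyTreeCheckpointer` and `flax.training.orbax_utils.save_args_from_target` APIs used in that example; both have changed across Orbax and Flax releases, so check them against your installed versions. `save_checkpoint` and `restore_checkpoint` are hypothetical names, not DeepXDE's API. The imports are deferred into the function bodies, as in this draft, so defining the functions does not require Orbax or Flax to be installed.

```python
def save_checkpoint(path, state):
    """Hypothetical helper: save a PyTree (e.g. a dict of Flax params) to `path` with Orbax."""
    # Deferred imports: the module loads even if orbax/flax are not installed.
    import orbax.checkpoint
    from flax.training import orbax_utils

    checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    # Derive per-leaf save arguments from the target tree, as in the Flax docs example.
    save_args = orbax_utils.save_args_from_target(state)
    checkpointer.save(path, state, save_args=save_args)


def restore_checkpoint(path):
    """Hypothetical helper: restore the raw PyTree written by save_checkpoint."""
    import orbax.checkpoint

    checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    # Without a target, Orbax returns the restored tree as plain containers/arrays.
    return checkpointer.restore(path)
```

For example, `save_checkpoint("/tmp/deepxde_ckpt", {"params": params})` followed by `restore_checkpoint("/tmp/deepxde_ckpt")` should round-trip the parameters. Orbax generally expects an absolute directory path and refuses to overwrite an existing checkpoint directory unless explicitly told to.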
As noted in #1490, Orbax is currently imported lazily (only on first call), in a way that does not comply with PEP 8, to avoid making Orbax a package dependency. I feel there should be a better way to do this, but I can't think of one right now, so I've marked this PR as a draft.
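One common way to tidy up an optional dependency is to route the deferred import through a single helper that fails with an install hint. This is a sketch of that general pattern, not what this PR implements; `lazy_import` is a hypothetical name. `importlib.import_module` caches modules in `sys.modules`, so repeated calls after the first are cheap.

```python
import importlib
from types import ModuleType


def lazy_import(module_name: str, pip_name: str) -> ModuleType:
    """Hypothetical helper: import `module_name` on first use, with an install hint if missing."""
    try:
        # Returns the cached module from sys.modules after the first successful import.
        return importlib.import_module(module_name)
    except ImportError as err:
        # ModuleNotFoundError is a subclass of ImportError, so missing packages land here.
        raise ImportError(
            f"{module_name} is required for this feature; "
            f"install it with `pip install {pip_name}`."
        ) from err
```

Inside a checkpoint method this would read `ocp = lazy_import("orbax.checkpoint", "orbax-checkpoint")`, keeping the optional import in one place instead of scattering inline imports.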
Any thoughts or suggestions are greatly appreciated!
According to https://orbax.readthedocs.io/en/latest/, Orbax is not a single package; rather, each functionality in the Orbax namespace is provided by a separate package. For checkpointing, we need orbax-checkpoint.
I've updated the import to require only orbax.checkpoint.
As checkpointing is a very basic tool, I think it is OK to add orbax-checkpoint as a required package dependency.
Added orbax-checkpoint as a top-level import and a package-wide dependency. flax.training is still a non-top-level import, though.