Deprecation Warnings with orbax 0.5.3
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Debian 12
- Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): flax 0.8.1, orbax-checkpoint 0.5.3
- Python version: 3.11.4
Problem you have encountered:
I followed the save and load checkpoints tutorial and got deprecation warnings. Although the checkpoints were saved correctly, it would be great if the tutorial documented the current recommended way of saving and loading a Flax TrainState.
What you expected to happen:
No warnings.
Logs, error messages, etc:
WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by May 1st, 2024.
WARNING:absl:SaveArgs.aggregate is deprecated, please use custom TypeHandler (https://orbax.readthedocs.io/en/latest/custom_handlers.html#typehandler) or contact Orbax team to migrate before May 1st, 2024.
Steps to reproduce:
Follow the save and load checkpoints tutorial.
Any thoughts on how to checkpoint a Flax TrainState with the new CheckpointManager API? I gave it a go on StackOverflow, but without success.
Hi, one of the Orbax team members replied to your StackOverflow thread; please take a look.
Regarding the "SaveArgs.aggregate is deprecated" warning: it seems that Orbax internally sets aggregate=True whenever it saves an empty node such as optax.EmptyState(), which is part of the TrainState's optimizer state.
The Orbax team will work on a refactoring that removes the internal use of aggregate. Meanwhile, the whole TrainState will still be saved correctly despite the deprecation warning, so it probably will not affect your use.
I will make another PR that addresses the deprecated CheckpointManager API warning in the Flax Orbax guide.
Hi,
I am still receiving this warning:
WARNING:absl:SaveArgs.aggregate is deprecated, please use custom TypeHandler (https://orbax.readthedocs.io/en/latest/custom_handlers.html#typehandler) or contact Orbax team to migrate before May 1st, 2024. If your Pytree has empty ([], {}, None) values then use PyTreeCheckpointHandler(..., write_tree_metadata=True, ...) or use StandardCheckpointHandler to avoid TypeHandler Registry error. Please note that PyTreeCheckpointHandler.write_tree_metadata default value is already set to T
Has the fix been implemented?
Below is my code:
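import os
from pathlib import Path

import orbax.checkpoint as ocp

# Checkpointing (out_dir, iter_num, state, model_args, best_val_loss, losses
# and config are defined earlier in the training script).
# Keep at most 5 checkpoints; create the directory if it does not exist.
check_options = ocp.CheckpointManagerOptions(max_to_keep=5, create=True)
check_path = Path(os.getcwd(), out_dir, 'checkpoint')
checkpoint_manager = ocp.CheckpointManager(
    check_path, options=check_options, item_names=('state', 'metadata'))
# Save the TrainState with StandardSave and the run metadata as JSON.
checkpoint_manager.save(
    step=iter_num,
    args=ocp.args.Composite(
        state=ocp.args.StandardSave(state),
        metadata=ocp.args.JsonSave(
            (model_args, iter_num, best_val_loss, losses['val'].item(), config))))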
orbax-checkpoint version: 0.5.9