
Deprecation Warnings with orbax 0.5.3

Status: Open. Opened by lucidfrontier45 1 year ago; 3 comments.

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Debian 12
  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): flax 0.8.1, orbax-checkpoint 0.5.3
  • Python version: 3.11.4

Problem you have encountered:

I followed the Save and load checkpoints tutorial and got deprecation warnings. Although the checkpoints were saved correctly, it would be great if the tutorial documented the current recommended way to save and load a Flax TrainState.

What you expected to happen:

No warnings.

Logs, error messages, etc:

WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by May 1st, 2024.
WARNING:absl:SaveArgs.aggregate is deprecated, please use custom TypeHandler (https://orbax.readthedocs.io/en/latest/custom_handlers.html#typehandler) or contact Orbax team to migrate before May 1st, 2024.

Steps to reproduce:

Follow the Save and load checkpoints tutorial.

lucidfrontier45, Feb 22 '24 03:02

Any thoughts on how to checkpoint a Flax TrainState with the new CheckpointManager API? I asked on StackOverflow but have not found a working solution yet.

hylkedonker, Feb 22 '24 16:02

Hi, one of the Orbax team members replied to your StackOverflow thread; please take a look.

Regarding the SaveArgs.aggregate is deprecated warning: it seems that Orbax internally sets aggregate=True whenever it saves an empty node such as optax.EmptyState(), which appears in a TrainState as part of the optimizer state. The Orbax team will work on a refactoring that removes the internal use of aggregate. Meanwhile, the whole TrainState is still saved correctly despite the deprecation warning, so the warning probably does not affect your use.

I will make another PR that addresses the deprecated CheckpointManager API warning in the Flax Orbax guide.

IvyZX, Feb 28 '24 21:02
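To see which parts of a checkpointed tree count as empty nodes of the kind described above, a small helper can walk a nested structure and list the paths of None values and empty containers. This is a hypothetical, stdlib-only sketch, not an Orbax or JAX API; EmptyState below is a local stand-in that mimics a field-less NamedTuple like optax.EmptyState, and the toy tree is invented for illustration.

```python
from typing import Any, List, NamedTuple, Tuple


class EmptyState(NamedTuple):
    """Local stand-in for a field-less optimizer state such as optax.EmptyState."""


def find_empty_nodes(tree: Any, path: Tuple = ()) -> List[Tuple]:
    """Return the paths of None values and empty dicts/lists/tuples in a nested tree."""
    if tree is None:
        return [path]
    if isinstance(tree, dict):
        children = list(tree.items())
    elif isinstance(tree, (list, tuple)):  # NamedTuples are tuples too
        children = list(enumerate(tree))
    else:
        return []  # a leaf such as an array or a number
    if not children:
        return [path]
    found = []
    for key, child in children:
        found.extend(find_empty_nodes(child, path + (key,)))
    return found


# Hypothetical tree shaped loosely like a TrainState: parameters plus an
# optimizer state that chains a non-empty sub-state with an empty one.
toy_state = {
    "step": 10,
    "params": {"w": [1.0, 2.0]},
    "opt_state": ({"count": 10}, EmptyState()),
    "extra": None,
}
print(find_empty_nodes(toy_state))  # [('opt_state', 1), ('extra',)]
```

For a real TrainState the nodes are registered JAX pytrees rather than plain dicts, so a similar walk would more likely be built on jax.tree_util.tree_flatten_with_path; the plain-Python version only illustrates the idea.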

Hi,

I am still receiving this warning:

WARNING:absl:SaveArgs.aggregate is deprecated, please use custom TypeHandler (https://orbax.readthedocs.io/en/latest/custom_handlers.html#typehandler) or contact Orbax team to migrate before May 1st, 2024. If your Pytree has empty ([], {}, None) values then use PyTreeCheckpointHandler(..., write_tree_metadata=True, ...) or use StandardCheckpointHandler to avoid TypeHandler Registry error. Please note that PyTreeCheckpointHandler.write_tree_metadata default value is already set to T .

Has the fix been implemented?

Below is my code:

import os
from pathlib import Path

import orbax.checkpoint as ocp

# checkpointing: keep the 5 most recent checkpoints, each with two named items
check_options = ocp.CheckpointManagerOptions(max_to_keep=5, create=True)
check_path = Path(os.getcwd(), out_dir, 'checkpoint')
checkpoint_manager = ocp.CheckpointManager(check_path, options=check_options, item_names=('state', 'metadata'))
checkpoint_manager.save(
    step=iter_num,
    args=ocp.args.Composite(
        state=ocp.args.StandardSave(state),
        metadata=ocp.args.JsonSave((model_args, iter_num, best_val_loss, losses['val'].item(), config)),
    ),
)

orbax-checkpoint version: 0.5.9

jiagaoxiang, Apr 24 '24 00:04