
Update JAX API usage to latest version

Open ReNothingg opened this issue 5 months ago • 3 comments

We need to update JAX tree (pytree) API usage across the codebase to the current stable JAX APIs.

Changes Required

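  1. Replace jax.tree_util function calls with their jax.tree equivalents:
  • Update imports (registration helpers and key types such as GetAttrKey stay in jax.tree_util)
  • Replace tree_util functions (tree_map, tree_leaves, tree_flatten, tree_unflatten, tree_structure) with jax.tree.map, jax.tree.leaves, jax.tree.flatten, jax.tree.unflatten, jax.tree.structure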
  2. Update pytree registration:
  • Use register_pytree_with_keys instead of register_pytree_node
  • Update flatten functions to return (key, child) pairs using key types such as GetAttrKey and DictKey
  • Update registration calls
  3. Update tree traversal operations:
  • Replace jax.tree_util.tree_map with jax.tree.map
  • Update tree manipulation functions
  • Modernize tree-related utilities
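The three steps above can be sketched together on a small example. This is a minimal sketch that assumes a JAX version recent enough to provide the jax.tree module. `Point` and the `_point_*` helpers are hypothetical and are not part of axlearn. The sketch registers a class with register_pytree_with_keys so that each child carries a GetAttrKey, and it calls jax.tree.map where older code would call jax.tree_util.tree_map. The registration helper, the key types, and the key-path functions are still imported from jax.tree_util.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import (
    GetAttrKey,
    keystr,
    register_pytree_with_keys,
    tree_flatten_with_path,
    tree_map_with_path,
)


class Point:
    """Hypothetical container type, used only to illustrate keyed registration."""

    def __init__(self, x, y):
        self.x = x
        self.y = y


def _point_flatten_with_keys(p):
    # Pair every child with a key entry describing how it is reached from the node.
    # The second element (None) is the static aux data; Point has none.
    return ((GetAttrKey("x"), p.x), (GetAttrKey("y"), p.y)), None


def _point_flatten(p):
    # Optional key-free fast path; must yield children in the same order as above.
    return (p.x, p.y), None


def _point_unflatten(aux_data, children):
    return Point(*children)


# Previously: jax.tree_util.register_pytree_node(Point, _point_flatten, _point_unflatten)
register_pytree_with_keys(
    Point, _point_flatten_with_keys, _point_unflatten, flatten_func=_point_flatten
)

tree = {"a": Point(jnp.ones(2), jnp.zeros(2))}

# Previously: jax.tree_util.tree_map(lambda leaf: leaf * 2, tree)
doubled = jax.tree.map(lambda leaf: leaf * 2, tree)

# With keyed registration, paths read as attribute accesses.
paths = [keystr(path) for path, _ in tree_flatten_with_path(tree)[0]]
labels = tree_map_with_path(lambda path, leaf: keystr(path), tree)
```

With this registration, `paths` is `["['a'].x", "['a'].y"]`. A node registered with plain register_pytree_node falls back to flattened-index keys, which are harder to read in error messages and debugging output.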

Files to Modify

Key files that need updates:

  • axlearn/common/struct.py
  • axlearn/common/utils.py
  • axlearn/common/metrics.py
  • axlearn/common/learner.py
  • Other files using JAX tree operations

Implementation Details

  1. For each file:
  • Scan for JAX tree API usage
  • Update imports
  • Replace deprecated functions
  • Update function signatures
  • Add type hints where missing
  2. Testing:
  • Run all tests with latest JAX
  • Verify no regressions
  • Check backward compatibility
  • Add new tests if needed
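A plain-text search is usually enough to build the worklist for the "Scan for JAX tree API usage" step. The sketch below uses hypothetical helpers (`find_tree_util_usages` and `scan_tree` are illustrative names, not existing axlearn utilities). It flags three patterns: jax.tree_util.tree_* calls, deprecated top-level aliases such as jax.tree_map, and single-line from-imports of those functions. Because it is regex-based, it can miss aliased or multi-line imports, and it may flag matches inside strings or comments.

```python
import re
from pathlib import Path
from typing import Dict, List, Tuple

# tree_util functions that have a jax.tree.* counterpart (tree_map -> jax.tree.map, ...).
# Registration helpers and key types are not matched, since they stay in jax.tree_util.
# The trailing \b keeps names like tree_map_with_path from matching.
_FUNCS = r"tree_(?:map|leaves|flatten|unflatten|structure|transpose|reduce|all)"
_PATTERN = re.compile(
    rf"\bjax\.tree_util\.{_FUNCS}\b"  # jax.tree_util.tree_map(...)
    rf"|\bjax\.{_FUNCS}\b"  # deprecated top-level alias, e.g. jax.tree_map(...)
    rf"|\bfrom\s+jax\.tree_util\s+import\b.*\b{_FUNCS}\b"  # from jax.tree_util import tree_map
)


def find_tree_util_usages(source: str) -> List[Tuple[int, str]]:
    """Return (1-based line number, stripped line) for each line with a migratable call."""
    hits = []
    for lineno, line in enumerate(source.splitlines(), start=1):
        if _PATTERN.search(line):
            hits.append((lineno, line.strip()))
    return hits


def scan_tree(root: str) -> Dict[str, List[Tuple[int, str]]]:
    """Map each .py file under root to its hits; files without hits are omitted."""
    results = {}
    for path in sorted(Path(root).rglob("*.py")):
        hits = find_tree_util_usages(path.read_text(encoding="utf-8"))
        if hits:
            results[str(path)] = hits
    return results
```

Run from the repository root, `scan_tree("axlearn")` returns a mapping from each file path to its (line number, line) pairs. That mapping can double as a per-file checklist.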

Success Criteria

  • All tests pass with latest JAX version
  • No functionality changes
  • No deprecation warnings from JAX tree APIs
  • Improved type safety
  • Backward compatible changes
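One way to check the deprecation-warning criterion is to run the test suite with deprecation warnings escalated to errors, for example `pytest -W error::DeprecationWarning`. For a targeted check inside a single test, a small standard-library helper can record any DeprecationWarning or FutureWarning raised while a callable runs. `collect_deprecations` below is a hypothetical name used only for this sketch.

```python
import warnings
from typing import Any, Callable, List


def collect_deprecations(fn: Callable[[], Any]) -> List[str]:
    """Run fn and return the messages of any DeprecationWarning/FutureWarning it emits."""
    with warnings.catch_warnings(record=True) as caught:
        # "always" ensures repeated warnings are recorded rather than deduplicated.
        warnings.simplefilter("always")
        fn()
    return [
        str(w.message)
        for w in caught
        if issubclass(w.category, (DeprecationWarning, FutureWarning))
    ]
```

In a test, `assert collect_deprecations(lambda: migrated_fn(tree)) == []` would then fail if the migrated code path still triggers a deprecation. Here `migrated_fn` is a placeholder for the function under test.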

ReNothingg · Jul 25 '25 18:07

Hi, the workflows need approval to run (GitHub Actions are pending). Can someone with write access approve and run them? @ruomingp, please.

ReNothingg · Jul 25 '25 18:07

Also please resolve any merge conflicts.

apghml · Jul 30 '25 02:07

This pull request has been automatically marked as stale because it has been inactive for 60 days. It will be closed in 7 days if no further activity occurs. If you would like to continue working on this, please remove the stale label or leave a comment.

github-actions[bot] · Oct 19 '25 02:10