axlearn
Update JAX API usage to latest version
We need to update the JAX API usage across the codebase to use the latest stable APIs.
Changes Required
- Replace `jax.tree_util` with `jax.tree`:
  - Update all imports
  - Replace all usages of `tree_util` functions
- Update pytree registration:
  - Use `register_pytree_with_keys` instead of `register_pytree_node`
  - Update flattening/unflattening functions to use the new key types
  - Update registration calls
- Update tree traversal operations:
  - Replace `tree_map` with its `jax.tree.map` equivalent
  - Update tree manipulation functions
  - Modernize tree-related utilities
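A minimal sketch of the first and third items, assuming a JAX release that ships the `jax.tree` namespace (0.4.25 and later, to our understanding). Each legacy `jax.tree_util.tree_*` function used here has a direct counterpart in `jax.tree`, and both spellings produce the same structure and values on the same pytree. The `params` tree and `double` function are illustrative only.

```python
import jax
import jax.numpy as jnp

# A small nested pytree; dicts are built-in pytree nodes (illustrative data).
params = {"dense": {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}, "step": jnp.array(0)}


def double(x):
    return x * 2


# Legacy spelling: functions from the jax.tree_util module.
old_doubled = jax.tree_util.tree_map(double, params)
old_leaves, old_treedef = jax.tree_util.tree_flatten(params)

# New spelling: the same operations through the jax.tree namespace.
new_doubled = jax.tree.map(double, params)        # was tree_util.tree_map
new_leaves, new_treedef = jax.tree.flatten(params)  # was tree_util.tree_flatten
new_structure = jax.tree.structure(params)          # was tree_util.tree_structure

# Round-trip: unflatten rebuilds a tree with the original structure.
rebuilt = jax.tree.unflatten(new_treedef, new_leaves)
```

Because both namespaces share the same underlying implementation, the migration should be mechanical: the tree definitions compare equal and the mapped leaves are identical.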
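For the pytree registration item, here is a sketch using a hypothetical `Point` class (not an axlearn type). As far as we know, the registration functions, including `register_pytree_with_keys` and key types such as `GetAttrKey`, live in `jax.tree_util` and are not re-exported by `jax.tree`, so those imports would stay on `jax.tree_util` after the migration. The keyed flatten function pairs each child with a key, which is what makes path-aware utilities such as `tree_flatten_with_path` report readable paths.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import GetAttrKey, register_pytree_with_keys


class Point:
    """Hypothetical two-field container, used only for illustration."""

    def __init__(self, x, y):
        self.x = x
        self.y = y


def _flatten_with_keys(p):
    # Each child is paired with a key describing how to reach it (attribute access).
    children = ((GetAttrKey("x"), p.x), (GetAttrKey("y"), p.y))
    return children, None  # No static auxiliary data for this type.


def _unflatten(aux_data, children):
    return Point(*children)


def _flatten(p):
    # Optional key-free fast path; must yield children in the same order as above.
    return (p.x, p.y), None


register_pytree_with_keys(Point, _flatten_with_keys, _unflatten, _flatten)

pt = Point(jnp.array(1.0), jnp.array(2.0))

# Key paths come from _flatten_with_keys; keystr renders them as strings.
paths_and_leaves, _ = jax.tree_util.tree_flatten_with_path(pt)
path_strs = [jax.tree_util.keystr(path) for path, _ in paths_and_leaves]

# Ordinary tree operations through jax.tree work on the registered type.
shifted = jax.tree.map(lambda v: v + 1, pt)
```

Here `path_strs` should read `[".x", ".y"]`, and `shifted` is a new `Point` with each field incremented.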
Files to Modify
Key files that need updates:
- `axlearn/common/struct.py`
- `axlearn/common/utils.py`
- `axlearn/common/metrics.py`
- `axlearn/common/learner.py`
- Other files using JAX tree operations
Implementation Details
- For each file:
  - Scan for JAX tree API usage
  - Update imports
  - Replace deprecated functions
  - Update function signatures
  - Add type hints where missing
- Testing:
  - Run all tests with the latest JAX release
  - Verify there are no regressions
  - Check backward compatibility
  - Add new tests if needed
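The "Scan for JAX tree API usage" step could be automated with a small script. This is a hypothetical helper: `find_legacy_tree_usages`, `scan_tree`, and the pattern list are illustrative, not existing axlearn utilities. Regex matching is only a rough first pass. It flags `jax.tree_util` imports for manual review, since registration helpers legitimately stay there, and it will not catch bare names such as `tree_map(...)` used after a `from jax.tree_util import` line.

```python
import re
from pathlib import Path

# Hypothetical patterns for older tree API spellings; extend as needed.
LEGACY_PATTERNS = {
    "tree_util import": re.compile(
        r"^\s*(?:from\s+jax\s+import\s+tree_util\b"
        r"|import\s+jax\.tree_util\b"
        r"|from\s+jax\.tree_util\s+import\b)"
    ),
    "jax.tree_util.tree_* call": re.compile(r"\bjax\.tree_util\.tree_\w+"),
    "deprecated top-level alias": re.compile(
        r"\bjax\.tree_(?:map|leaves|flatten|unflatten|structure)\b"
    ),
    "register_pytree_node": re.compile(r"\bregister_pytree_node\b"),
}


def find_legacy_tree_usages(source: str) -> list[tuple[int, str, str]]:
    """Return (line_number, pattern_name, stripped_line) for each match in source."""
    hits = []
    for lineno, line in enumerate(source.splitlines(), start=1):
        for name, pattern in LEGACY_PATTERNS.items():
            if pattern.search(line):
                hits.append((lineno, name, line.strip()))
    return hits


def scan_tree(root: str) -> dict[str, list[tuple[int, str, str]]]:
    """Scan every .py file under root; map file path to its flagged lines."""
    results = {}
    for path in sorted(Path(root).rglob("*.py")):
        hits = find_legacy_tree_usages(path.read_text(encoding="utf-8"))
        if hits:
            results[str(path)] = hits
    return results
```

For example, `scan_tree("axlearn/common")` would return a per-file list of lines to revisit, which can double as a checklist for the files listed above.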
Success Criteria
- All tests pass with the latest JAX version
- No functional changes
- No deprecation warnings
- Improved type safety
- Backward-compatible changes only
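To back the "All tests pass" and "No deprecation warnings" criteria, a regression test could check that the new call matches the legacy one and emits no `DeprecationWarning`. This is a sketch assuming pytest-style test functions; `assert_no_deprecation_warnings` and `test_tree_map_matches_legacy` are hypothetical names, not existing axlearn tests.

```python
import warnings

import jax
import jax.numpy as jnp


def assert_no_deprecation_warnings(fn, *args, **kwargs):
    """Call fn(*args, **kwargs); fail if it emits any DeprecationWarning."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # Record every warning, even repeats.
        result = fn(*args, **kwargs)
    deprecations = [w for w in caught if issubclass(w.category, DeprecationWarning)]
    assert not deprecations, [str(w.message) for w in deprecations]
    return result


def test_tree_map_matches_legacy():
    tree = {"a": jnp.arange(3.0), "b": (jnp.ones(2), jnp.zeros(1))}
    new = assert_no_deprecation_warnings(jax.tree.map, lambda x: x + 1, tree)
    old = jax.tree_util.tree_map(lambda x: x + 1, tree)
    # Same structure and identical leaf values: no behavioral change.
    assert jax.tree.structure(new) == jax.tree.structure(old)
    for a, b in zip(jax.tree.leaves(new), jax.tree.leaves(old)):
        assert bool(jnp.array_equal(a, b))
```

Running the suite with `pytest -W error::DeprecationWarning` is a blunter alternative that turns every deprecation warning into a test failure.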
Hi, the workflows need approval to run (GitHub Actions are pending). Can someone with write access approve and run them? @ruomingp, could you please take a look?
Also please resolve any merge conflicts.
This pull request has been automatically marked as stale because it has been inactive for 60 days. It will be closed in 7 days if no further activity occurs. If you would like to continue working on this, please remove the stale label or leave a comment.